Coverage for gco / services / inference_store.py: 94%

108 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 21:47 +0000

1""" 

2DynamoDB-backed store for inference endpoint state. 

3 

4Provides CRUD operations for inference endpoints. The inference_monitor 

5in each regional cluster polls this table to reconcile desired state 

6with actual Kubernetes resources. 

7""" 

8 

9from __future__ import annotations 

10 

11import logging 

12import os 

13from datetime import UTC, datetime 

14from typing import Any 

15 

16import boto3 

17from botocore.exceptions import ClientError 

18 

# Logger named after this module.
logger = logging.getLogger(__name__)

# Table name used when neither a constructor argument nor the
# INFERENCE_ENDPOINTS_TABLE_NAME environment variable supplies one.
DEFAULT_TABLE_NAME = "gco-inference-endpoints"

22 

23 

24def _utc_now_iso() -> str: 

25 return datetime.now(UTC).isoformat() 

26 

27 

class InferenceEndpointStore:
    """DynamoDB store for inference endpoint desired state.

    Items are keyed by ``endpoint_name``. Each item holds the desired state,
    target regions, namespace, spec, ingress path, timestamps, and a
    ``region_status`` map that per-region status updates write into.
    """

    def __init__(self, table_name: str | None = None, region: str | None = None):
        """Bind the store to a DynamoDB table.

        Args:
            table_name: Table to use. Falls back to the
                ``INFERENCE_ENDPOINTS_TABLE_NAME`` environment variable, then
                to ``DEFAULT_TABLE_NAME``.
            region: AWS region of the table. Falls back to the
                ``DYNAMODB_REGION`` environment variable, then ``REGION``,
                then ``"us-east-1"``.
        """
        self.table_name = table_name or os.getenv(
            "INFERENCE_ENDPOINTS_TABLE_NAME", DEFAULT_TABLE_NAME
        )
        self._region = region or os.getenv("DYNAMODB_REGION") or os.getenv("REGION", "us-east-1")
        self._dynamodb = boto3.resource("dynamodb", region_name=self._region)
        self._table = self._dynamodb.Table(self.table_name)

    @staticmethod
    def _is_conditional_check_failure(error: ClientError) -> bool:
        """Return True if *error* is DynamoDB's ConditionalCheckFailedException.

        Uses ``.get`` so a ClientError with an unexpected payload is re-raised
        by the caller instead of being masked by a KeyError.
        """
        return error.response.get("Error", {}).get("Code") == "ConditionalCheckFailedException"

    def create_endpoint(
        self,
        endpoint_name: str,
        spec: dict[str, Any],
        target_regions: list[str],
        namespace: str = "gco-inference",
        labels: dict[str, str] | None = None,
        created_by: str | None = None,
    ) -> dict[str, Any]:
        """Create a new inference endpoint entry in the ``deploying`` state.

        Args:
            endpoint_name: Unique key of the endpoint; also used to build the
                ingress path ``/inference/<endpoint_name>``.
            spec: Endpoint spec; stored after ``_serialize_for_dynamo``.
            target_regions: Regions the endpoint should be deployed to.
            namespace: Kubernetes namespace recorded on the item.
            labels: Optional labels; only stored when non-empty.
            created_by: Optional creator identity; only stored when non-empty.

        Returns:
            The item exactly as written (``spec`` in its serialized form).

        Raises:
            ValueError: An endpoint with this name already exists.
            ClientError: Any other DynamoDB failure.
        """
        now = _utc_now_iso()
        ingress_path = f"/inference/{endpoint_name}"

        item: dict[str, Any] = {
            "endpoint_name": endpoint_name,
            "desired_state": "deploying",
            "target_regions": target_regions,
            "namespace": namespace,
            "spec": _serialize_for_dynamo(spec),
            "ingress_path": ingress_path,
            "created_at": now,
            "updated_at": now,
            # Pre-create the map so update_region_status can SET nested keys.
            "region_status": {},
        }
        if labels:
            item["labels"] = labels
        if created_by:
            item["created_by"] = created_by

        try:
            self._table.put_item(
                Item=item,
                # Reject the write instead of silently overwriting an existing endpoint.
                ConditionExpression="attribute_not_exists(endpoint_name)",
            )
        except ClientError as e:
            if self._is_conditional_check_failure(e):
                raise ValueError(f"Endpoint '{endpoint_name}' already exists") from e
            raise

        return item

    def get_endpoint(self, endpoint_name: str) -> dict[str, Any] | None:
        """Return the endpoint named *endpoint_name*, or None if it does not exist.

        DynamoDB numbers (Decimal) are converted back to int/float.
        """
        response = self._table.get_item(Key={"endpoint_name": endpoint_name})
        item = response.get("Item")
        if item:
            return _deserialize_from_dynamo(item)
        return None

    def list_endpoints(
        self,
        desired_state: str | None = None,
        target_region: str | None = None,
    ) -> list[dict[str, Any]]:
        """List all endpoints, optionally filtered, newest first.

        Performs a full table scan and follows ``LastEvaluatedKey`` pagination.
        A single Scan call returns at most 1 MB of data, so without this loop
        larger tables would be silently truncated.

        Args:
            desired_state: If given, keep only endpoints with this desired state.
            target_region: If given, keep only endpoints targeting this region.

        Returns:
            Matching endpoints sorted by ``created_at`` descending.
        """
        raw_items: list[dict[str, Any]] = []
        scan_kwargs: dict[str, Any] = {}
        while True:
            response = self._table.scan(**scan_kwargs)
            raw_items.extend(response.get("Items", []))
            last_key = response.get("LastEvaluatedKey")
            if not last_key:
                break
            scan_kwargs["ExclusiveStartKey"] = last_key

        items = [_deserialize_from_dynamo(i) for i in raw_items]

        if desired_state:
            items = [i for i in items if i.get("desired_state") == desired_state]
        if target_region:
            items = [i for i in items if target_region in i.get("target_regions", [])]

        # created_at values all come from _utc_now_iso (same UTC ISO format),
        # so lexicographic order matches chronological order.
        return sorted(items, key=lambda x: x.get("created_at", ""), reverse=True)

    def update_desired_state(self, endpoint_name: str, desired_state: str) -> dict[str, Any] | None:
        """Set the desired state of an existing endpoint.

        Returns:
            The updated item, or None if the endpoint does not exist.

        Raises:
            ClientError: Any DynamoDB failure other than the missing-item condition.
        """
        try:
            response = self._table.update_item(
                Key={"endpoint_name": endpoint_name},
                UpdateExpression="SET desired_state = :s, updated_at = :u",
                ExpressionAttributeValues={
                    ":s": desired_state,
                    ":u": _utc_now_iso(),
                },
                # Prevent update_item from upserting a new, incomplete item.
                ConditionExpression="attribute_exists(endpoint_name)",
                ReturnValues="ALL_NEW",
            )
            return _deserialize_from_dynamo(response.get("Attributes", {}))
        except ClientError as e:
            if self._is_conditional_check_failure(e):
                return None
            raise

    def update_spec(self, endpoint_name: str, spec: dict[str, Any]) -> dict[str, Any] | None:
        """Replace the spec of an existing endpoint and reset it to ``deploying``.

        Resetting the desired state is what triggers re-reconciliation.

        Returns:
            The updated item, or None if the endpoint does not exist.

        Raises:
            ClientError: Any DynamoDB failure other than the missing-item condition.
        """
        try:
            response = self._table.update_item(
                Key={"endpoint_name": endpoint_name},
                UpdateExpression="SET spec = :s, updated_at = :u, desired_state = :ds",
                ExpressionAttributeValues={
                    ":s": _serialize_for_dynamo(spec),
                    ":u": _utc_now_iso(),
                    ":ds": "deploying",
                },
                ConditionExpression="attribute_exists(endpoint_name)",
                ReturnValues="ALL_NEW",
            )
            return _deserialize_from_dynamo(response.get("Attributes", {}))
        except ClientError as e:
            if self._is_conditional_check_failure(e):
                return None
            raise

    def update_region_status(
        self,
        endpoint_name: str,
        region: str,
        state: str,
        replicas_ready: int = 0,
        replicas_desired: int = 0,
        error: str | None = None,
    ) -> None:
        """Record the sync status reported for one region of an endpoint.

        Writes ``region_status[<region>]`` and bumps ``updated_at``. This is
        best-effort: DynamoDB errors are logged and not raised.

        Args:
            endpoint_name: Endpoint to update.
            region: Region key inside ``region_status``.
            state: Reported state string.
            replicas_ready: Number of ready replicas.
            replicas_desired: Number of desired replicas.
            error: Optional error message; only stored when non-empty.
        """
        status_value: dict[str, Any] = {
            "state": state,
            "replicas_ready": replicas_ready,
            "replicas_desired": replicas_desired,
            "last_sync": _utc_now_iso(),
        }
        if error:
            status_value["error"] = error

        try:
            self._table.update_item(
                Key={"endpoint_name": endpoint_name},
                # The region goes through a name placeholder so region ids
                # containing '-' are valid in the document path.
                UpdateExpression="SET region_status.#r = :s, updated_at = :u",
                ExpressionAttributeNames={"#r": region},
                ExpressionAttributeValues={
                    ":s": status_value,
                    ":u": _utc_now_iso(),
                },
            )
        except ClientError as e:
            logger.error(
                "Failed to update region status for %s/%s: %s",
                endpoint_name,
                region,
                e,
            )

    def delete_endpoint(self, endpoint_name: str) -> bool:
        """Delete an endpoint record entirely.

        Returns:
            True if the item existed and was deleted, False if it did not exist.

        Raises:
            ClientError: Any DynamoDB failure other than the missing-item condition.
        """
        try:
            self._table.delete_item(
                Key={"endpoint_name": endpoint_name},
                ConditionExpression="attribute_exists(endpoint_name)",
            )
            return True
        except ClientError as e:
            if self._is_conditional_check_failure(e):
                return False
            raise

    def scale_endpoint(self, endpoint_name: str, replicas: int) -> dict[str, Any] | None:
        """Set ``spec.replicas`` on an existing endpoint.

        Unlike update_spec, this does not change ``desired_state``.

        Returns:
            The updated item, or None if the endpoint does not exist.

        Raises:
            ClientError: Any DynamoDB failure other than the missing-item condition.
        """
        try:
            response = self._table.update_item(
                Key={"endpoint_name": endpoint_name},
                UpdateExpression="SET spec.replicas = :r, updated_at = :u",
                ExpressionAttributeValues={
                    ":r": replicas,
                    ":u": _utc_now_iso(),
                },
                ConditionExpression="attribute_exists(endpoint_name)",
                ReturnValues="ALL_NEW",
            )
            return _deserialize_from_dynamo(response.get("Attributes", {}))
        except ClientError as e:
            if self._is_conditional_check_failure(e):
                return None
            raise

211 

212 

213def _serialize_for_dynamo(obj: Any) -> Any: 

214 """Convert Python objects to DynamoDB-compatible types.""" 

215 if isinstance(obj, dict): 

216 return {k: _serialize_for_dynamo(v) for k, v in obj.items()} 

217 if isinstance(obj, list): 

218 return [_serialize_for_dynamo(i) for i in obj] 

219 if isinstance(obj, (int, float)): 

220 return str(obj) if isinstance(obj, float) else obj 

221 return obj 

222 

223 

224def _deserialize_from_dynamo(item: dict[str, Any]) -> dict[str, Any]: 

225 """Convert DynamoDB item back to Python types.""" 

226 from decimal import Decimal 

227 

228 def convert(v: Any) -> Any: 

229 if isinstance(v, Decimal): 

230 return int(v) if v == int(v) else float(v) 

231 if isinstance(v, dict): 

232 return {k: convert(val) for k, val in v.items()} 

233 if isinstance(v, list): 

234 return [convert(i) for i in v] 

235 return v 

236 

237 result: dict[str, Any] = convert(item) 

238 return result 

239 

240 

def get_inference_endpoint_store() -> InferenceEndpointStore:
    """Build an InferenceEndpointStore with no explicit arguments.

    The table name and region therefore come from the store's own
    environment-variable fallbacks.
    """
    store = InferenceEndpointStore()
    return store