Coverage for gco / services / inference_store.py: 94%
108 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
1"""
2DynamoDB-backed store for inference endpoint state.
4Provides CRUD operations for inference endpoints. The inference_monitor
5in each regional cluster polls this table to reconcile desired state
6with actual Kubernetes resources.
7"""
9from __future__ import annotations
11import logging
12import os
13from datetime import UTC, datetime
14from typing import Any
16import boto3
17from botocore.exceptions import ClientError
logger = logging.getLogger(__name__)

# Fallback table name, used when no explicit table_name is passed to
# InferenceEndpointStore and INFERENCE_ENDPOINTS_TABLE_NAME is not set.
DEFAULT_TABLE_NAME = "gco-inference-endpoints"
24def _utc_now_iso() -> str:
25 return datetime.now(UTC).isoformat()
class InferenceEndpointStore:
    """DynamoDB store for inference endpoint desired state.

    Items are keyed by ``endpoint_name``. Alongside the desired state
    (``desired_state``, ``spec``, ``target_regions``, ``namespace``), each item
    carries a ``region_status`` map that :meth:`update_region_status` fills in
    one region at a time.
    """

    def __init__(self, table_name: str | None = None, region: str | None = None):
        """Bind the store to a DynamoDB table.

        Args:
            table_name: Table to use. Falls back to the
                ``INFERENCE_ENDPOINTS_TABLE_NAME`` environment variable, then
                to ``DEFAULT_TABLE_NAME``.
            region: AWS region of the table. Falls back to ``DYNAMODB_REGION``,
                then ``REGION``, then ``"us-east-1"``.
        """
        self.table_name = table_name or os.getenv(
            "INFERENCE_ENDPOINTS_TABLE_NAME", DEFAULT_TABLE_NAME
        )
        self._region = region or os.getenv("DYNAMODB_REGION") or os.getenv("REGION", "us-east-1")
        self._dynamodb = boto3.resource("dynamodb", region_name=self._region)
        self._table = self._dynamodb.Table(self.table_name)

    def create_endpoint(
        self,
        endpoint_name: str,
        spec: dict[str, Any],
        target_regions: list[str],
        namespace: str = "gco-inference",
        labels: dict[str, str] | None = None,
        created_by: str | None = None,
    ) -> dict[str, Any]:
        """Create a new inference endpoint entry.

        The item starts with ``desired_state="deploying"``, an empty
        ``region_status`` map and ``ingress_path="/inference/<endpoint_name>"``.

        Args:
            endpoint_name: Unique endpoint name (the table's key).
            spec: Endpoint spec; stored via ``_serialize_for_dynamo``.
            target_regions: Regions the endpoint should be deployed to.
            namespace: Namespace recorded on the item.
            labels: Optional labels; omitted from the item when empty.
            created_by: Optional creator; omitted from the item when empty.

        Returns:
            The item exactly as written (``spec`` in its serialized form).

        Raises:
            ValueError: If an endpoint with this name already exists.
        """
        now = _utc_now_iso()
        ingress_path = f"/inference/{endpoint_name}"

        item: dict[str, Any] = {
            "endpoint_name": endpoint_name,
            "desired_state": "deploying",
            "target_regions": target_regions,
            "namespace": namespace,
            "spec": _serialize_for_dynamo(spec),
            "ingress_path": ingress_path,
            "created_at": now,
            "updated_at": now,
            "region_status": {},
        }
        if labels:
            item["labels"] = labels
        if created_by:
            item["created_by"] = created_by

        try:
            self._table.put_item(
                Item=item,
                # Reject the write instead of overwriting an existing endpoint.
                ConditionExpression="attribute_not_exists(endpoint_name)",
            )
        except ClientError as e:
            if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
                raise ValueError(f"Endpoint '{endpoint_name}' already exists") from e
            raise

        return item

    def get_endpoint(self, endpoint_name: str) -> dict[str, Any] | None:
        """Return the endpoint named ``endpoint_name``, or ``None`` if absent."""
        response = self._table.get_item(Key={"endpoint_name": endpoint_name})
        item = response.get("Item")
        if item:
            return _deserialize_from_dynamo(item)
        return None

    def list_endpoints(
        self,
        desired_state: str | None = None,
        target_region: str | None = None,
    ) -> list[dict[str, Any]]:
        """List all endpoints, optionally filtered, newest ``created_at`` first.

        The whole table is scanned, following ``LastEvaluatedKey`` across
        pages, and filters are applied client-side afterwards.

        Args:
            desired_state: If given, keep only endpoints in this desired state.
            target_region: If given, keep only endpoints targeting this region.
        """
        raw_items: list[dict[str, Any]] = []
        scan_kwargs: dict[str, Any] = {}
        # A single Scan call returns at most 1 MB of data; keep requesting
        # pages until DynamoDB stops returning a LastEvaluatedKey.
        while True:
            response = self._table.scan(**scan_kwargs)
            raw_items.extend(response.get("Items", []))
            last_key = response.get("LastEvaluatedKey")
            if not last_key:
                break
            scan_kwargs["ExclusiveStartKey"] = last_key

        items = [_deserialize_from_dynamo(i) for i in raw_items]

        if desired_state:
            items = [i for i in items if i.get("desired_state") == desired_state]
        if target_region:
            items = [i for i in items if target_region in i.get("target_regions", [])]

        return sorted(items, key=lambda x: x.get("created_at", ""), reverse=True)

    def _update_existing(
        self,
        endpoint_name: str,
        update_expression: str,
        values: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Apply ``update_expression`` to an existing endpoint item.

        Returns the deserialized item after the update, or ``None`` when no
        item named ``endpoint_name`` exists. Any other ``ClientError`` is
        re-raised.
        """
        try:
            response = self._table.update_item(
                Key={"endpoint_name": endpoint_name},
                UpdateExpression=update_expression,
                ExpressionAttributeValues=values,
                # UpdateItem would otherwise create a brand-new item for an
                # unknown key; require the endpoint to exist already.
                ConditionExpression="attribute_exists(endpoint_name)",
                ReturnValues="ALL_NEW",
            )
        except ClientError as e:
            if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
                return None
            raise
        return _deserialize_from_dynamo(response.get("Attributes", {}))

    def update_desired_state(self, endpoint_name: str, desired_state: str) -> dict[str, Any] | None:
        """Set ``desired_state`` on an endpoint.

        Returns the updated endpoint, or ``None`` if it does not exist.
        """
        return self._update_existing(
            endpoint_name,
            "SET desired_state = :s, updated_at = :u",
            {
                ":s": desired_state,
                ":u": _utc_now_iso(),
            },
        )

    def update_spec(self, endpoint_name: str, spec: dict[str, Any]) -> dict[str, Any] | None:
        """Replace an endpoint's spec and reset ``desired_state`` to ``"deploying"``.

        Resetting the desired state is what triggers re-reconciliation.
        Returns the updated endpoint, or ``None`` if it does not exist.
        """
        return self._update_existing(
            endpoint_name,
            "SET spec = :s, updated_at = :u, desired_state = :ds",
            {
                ":s": _serialize_for_dynamo(spec),
                ":u": _utc_now_iso(),
                ":ds": "deploying",
            },
        )

    def update_region_status(
        self,
        endpoint_name: str,
        region: str,
        state: str,
        replicas_ready: int = 0,
        replicas_desired: int = 0,
        error: str | None = None,
    ) -> None:
        """Record the sync status for one region under ``region_status``.

        Failures are logged rather than raised, so callers can treat this as
        best-effort.

        NOTE: the nested ``SET region_status.#r`` only succeeds when the item
        already has a ``region_status`` map (``create_endpoint`` initializes
        it to ``{}``).
        """
        status_value: dict[str, Any] = {
            "state": state,
            "replicas_ready": replicas_ready,
            "replicas_desired": replicas_desired,
            "last_sync": _utc_now_iso(),
        }
        if error:
            status_value["error"] = error

        try:
            self._table.update_item(
                Key={"endpoint_name": endpoint_name},
                UpdateExpression="SET region_status.#r = :s, updated_at = :u",
                # Region names contain '-', so pass them as an attribute name
                # placeholder rather than inline in the expression.
                ExpressionAttributeNames={"#r": region},
                ExpressionAttributeValues={
                    ":s": status_value,
                    ":u": _utc_now_iso(),
                },
            )
        except ClientError as e:
            logger.error(
                "Failed to update region status for %s/%s: %s",
                endpoint_name,
                region,
                e,
            )

    def delete_endpoint(self, endpoint_name: str) -> bool:
        """Delete an endpoint record entirely.

        Returns ``True`` if it was deleted, ``False`` if it did not exist.
        """
        try:
            self._table.delete_item(
                Key={"endpoint_name": endpoint_name},
                ConditionExpression="attribute_exists(endpoint_name)",
            )
        except ClientError as e:
            if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
                return False
            raise
        return True

    def scale_endpoint(self, endpoint_name: str, replicas: int) -> dict[str, Any] | None:
        """Set ``spec.replicas`` on an endpoint.

        Returns the updated endpoint, or ``None`` if it does not exist.
        """
        return self._update_existing(
            endpoint_name,
            "SET spec.replicas = :r, updated_at = :u",
            {
                ":r": replicas,
                ":u": _utc_now_iso(),
            },
        )
213def _serialize_for_dynamo(obj: Any) -> Any:
214 """Convert Python objects to DynamoDB-compatible types."""
215 if isinstance(obj, dict):
216 return {k: _serialize_for_dynamo(v) for k, v in obj.items()}
217 if isinstance(obj, list):
218 return [_serialize_for_dynamo(i) for i in obj]
219 if isinstance(obj, (int, float)):
220 return str(obj) if isinstance(obj, float) else obj
221 return obj
224def _deserialize_from_dynamo(item: dict[str, Any]) -> dict[str, Any]:
225 """Convert DynamoDB item back to Python types."""
226 from decimal import Decimal
228 def convert(v: Any) -> Any:
229 if isinstance(v, Decimal):
230 return int(v) if v == int(v) else float(v)
231 if isinstance(v, dict):
232 return {k: convert(val) for k, val in v.items()}
233 if isinstance(v, list):
234 return [convert(i) for i in v]
235 return v
237 result: dict[str, Any] = convert(item)
238 return result
def get_inference_endpoint_store() -> InferenceEndpointStore:
    """Build an InferenceEndpointStore using its environment-driven defaults."""
    store = InferenceEndpointStore()
    return store