Coverage for cli / inference.py: 98%
148 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
1"""
2Inference endpoint management for GCO CLI.
4Provides functionality to deploy, manage, and monitor inference endpoints
5across multi-region EKS clusters via the DynamoDB-backed reconciliation
6pattern (inference_monitor).
7"""
9from __future__ import annotations
11import logging
12from typing import TYPE_CHECKING, Any
14from .aws_client import get_aws_client
15from .config import GCOConfig, get_config
17if TYPE_CHECKING:
18 from gco.services.inference_store import InferenceEndpointStore
20logger = logging.getLogger(__name__)
class InferenceManager:
    """Manages inference endpoints via the DynamoDB store.

    Endpoint specs are written to a DynamoDB table; the inference_monitor
    running in each target region reconciles them into Kubernetes resources.
    """

    def __init__(self, config: GCOConfig | None = None):
        """
        Args:
            config: CLI configuration; falls back to the global default
                from ``get_config()`` when omitted.
        """
        self.config = config or get_config()
        # Fix: pass the *resolved* config. Previously the raw (possibly
        # None) argument was forwarded, bypassing the get_config() default.
        self._aws_client = get_aws_client(self.config)

    def _get_store(self, region: str | None = None) -> InferenceEndpointStore:
        """Return an InferenceEndpointStore (defaults to the global region)."""
        # Imported lazily to keep CLI startup cheap / avoid import cycles.
        from gco.services.inference_store import InferenceEndpointStore

        # Use the global region for DynamoDB (same as job store).
        store_region = region or self.config.global_region
        return InferenceEndpointStore(region=store_region)

    def deploy(
        self,
        endpoint_name: str,
        image: str,
        target_regions: list[str] | None = None,
        replicas: int = 1,
        gpu_count: int = 1,
        gpu_type: str | None = None,
        port: int = 8000,
        model_path: str | None = None,
        model_source: str | None = None,
        health_check_path: str = "/health",
        env: dict[str, str] | None = None,
        namespace: str = "gco-inference",
        labels: dict[str, str] | None = None,
        autoscaling: dict[str, Any] | None = None,
        capacity_type: str | None = None,
        extra_args: list[str] | None = None,
        accelerator: str = "nvidia",
        node_selector: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """
        Deploy an inference endpoint to one or more regions.

        The endpoint spec is written to DynamoDB. The inference_monitor
        in each target region picks it up and creates the K8s resources.

        Args:
            endpoint_name: Unique name for the endpoint
            image: Container image (e.g. vllm/vllm-openai:v0.8.0)
            target_regions: Regions to deploy to (default: all deployed regions)
            replicas: Number of replicas per region
            gpu_count: GPUs per replica
            gpu_type: GPU instance type hint for node selector
            port: Container port
            model_path: EFS path for model weights
            model_source: Source location for the model weights
            health_check_path: Health check endpoint path
            env: Environment variables
            namespace: Kubernetes namespace
            labels: Labels for the endpoint
            autoscaling: Autoscaling settings, stored verbatim in the spec
            capacity_type: Capacity type hint (recorded only when set)
            extra_args: Extra container args (stored under the "args" key)
            accelerator: Accelerator kind; recorded only when not "nvidia"
            node_selector: Extra node selector labels

        Returns:
            Created endpoint record

        Raises:
            ValueError: If no target regions are given and no deployed
                regional stacks can be discovered.
        """
        if not target_regions:
            stacks = self._aws_client.discover_regional_stacks()
            target_regions = list(stacks.keys())
            if not target_regions:
                raise ValueError("No deployed regions found. Deploy infrastructure first.")

        spec: dict[str, Any] = {
            "image": image,
            "port": port,
            "replicas": replicas,
            "gpu_count": gpu_count,
            "health_check_path": health_check_path,
        }
        # Optional fields are recorded only when truthy (absent key means
        # "use the reconciler default"), matching the original if-ladder.
        optional_fields: dict[str, Any] = {
            "gpu_type": gpu_type,
            "model_path": model_path,
            "model_source": model_source,
            "env": env,
            "autoscaling": autoscaling,
            "capacity_type": capacity_type,
            "args": extra_args,
            "node_selector": node_selector,
        }
        spec.update({key: value for key, value in optional_fields.items() if value})
        if accelerator != "nvidia":
            # "nvidia" is the implicit default; only record overrides.
            spec["accelerator"] = accelerator

        store = self._get_store()
        result: dict[str, Any] = store.create_endpoint(
            endpoint_name=endpoint_name,
            spec=spec,
            target_regions=target_regions,
            namespace=namespace,
            labels=labels,
        )
        return result

    def list_endpoints(
        self,
        desired_state: str | None = None,
        region: str | None = None,
    ) -> list[dict[str, Any]]:
        """List all inference endpoints, optionally filtered by state/region."""
        store = self._get_store()
        result: list[dict[str, Any]] = store.list_endpoints(
            desired_state=desired_state,
            target_region=region,
        )
        return result

    def get_endpoint(self, endpoint_name: str) -> dict[str, Any] | None:
        """Get details of a specific endpoint, or None if it does not exist."""
        store = self._get_store()
        result: dict[str, Any] | None = store.get_endpoint(endpoint_name)
        return result

    def scale(self, endpoint_name: str, replicas: int) -> dict[str, Any] | None:
        """Scale an endpoint to a new replica count."""
        store = self._get_store()
        result: dict[str, Any] | None = store.scale_endpoint(endpoint_name, replicas)
        return result

    def stop(self, endpoint_name: str) -> dict[str, Any] | None:
        """Stop an endpoint (scale to zero, keep resources)."""
        store = self._get_store()
        result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "stopped")
        return result

    def start(self, endpoint_name: str) -> dict[str, Any] | None:
        """Start a stopped endpoint."""
        store = self._get_store()
        result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "running")
        return result

    def delete(self, endpoint_name: str) -> dict[str, Any] | None:
        """Mark an endpoint for deletion (inference_monitor cleans up)."""
        store = self._get_store()
        result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "deleted")
        return result

    def update_image(self, endpoint_name: str, image: str) -> dict[str, Any] | None:
        """Update the container image for an endpoint.

        Returns the updated record, or None if the endpoint does not exist.
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None
        spec = endpoint.get("spec", {})
        spec["image"] = image
        result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
        return result

    def _set_target_regions(
        self,
        store: InferenceEndpointStore,
        endpoint_name: str,
        regions: list[str],
        action: str,
    ) -> dict[str, Any] | None:
        """Write an endpoint's target_regions via a raw DynamoDB update.

        Shared by add_region/remove_region (was duplicated verbatim).
        `action` ("add"/"remove") appears only in the failure log message,
        preserving the original per-caller wording.

        Returns:
            The updated record, or None on error (best-effort).
        """
        from datetime import UTC, datetime

        try:
            response = store._table.update_item(
                Key={"endpoint_name": endpoint_name},
                UpdateExpression="SET target_regions = :r, updated_at = :u",
                ExpressionAttributeValues={
                    ":r": regions,
                    ":u": datetime.now(UTC).isoformat(),
                },
                ReturnValues="ALL_NEW",
            )
        except Exception as e:
            # Renders identically to the original messages, e.g.
            # "Failed to add region: ..." / "Failed to remove region: ...".
            logger.error("Failed to %s region: %s", action, e)
            return None
        result: dict[str, Any] | None = response.get("Attributes")
        return result

    def add_region(self, endpoint_name: str, region: str) -> dict[str, Any] | None:
        """Add a region to an existing endpoint.

        Returns None if the endpoint does not exist or the update fails.
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None
        regions = endpoint.get("target_regions", [])
        if region not in regions:
            regions.append(region)
        # NOTE: the write runs even when the region was already present,
        # refreshing updated_at (matches original behavior).
        return self._set_target_regions(store, endpoint_name, regions, "add")

    def remove_region(self, endpoint_name: str, region: str) -> dict[str, Any] | None:
        """Remove a region from an existing endpoint.

        Returns None if the endpoint does not exist or the update fails.
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None
        regions = endpoint.get("target_regions", [])
        if region in regions:
            regions.remove(region)
        return self._set_target_regions(store, endpoint_name, regions, "remove")

    def canary_deploy(
        self,
        endpoint_name: str,
        image: str,
        weight: int = 10,
        replicas: int = 1,
    ) -> dict[str, Any] | None:
        """Start a canary deployment for an existing endpoint.

        Creates a canary variant with the new image receiving `weight`%
        of traffic. The primary deployment continues serving the rest.

        Args:
            endpoint_name: Existing endpoint to canary
            image: New container image for the canary
            weight: Percentage of traffic to route to canary (1-99)
            replicas: Number of canary replicas

        Returns:
            Updated endpoint record, or None if endpoint not found

        Raises:
            ValueError: If `weight` is out of range or the endpoint is not
                in a running/deploying state.
        """
        if not 1 <= weight <= 99:
            raise ValueError("Canary weight must be between 1 and 99")

        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None

        if endpoint.get("desired_state") not in ("running", "deploying"):
            raise ValueError(
                f"Cannot canary an endpoint in '{endpoint.get('desired_state')}' state. "
                "Endpoint must be running."
            )

        spec = endpoint.get("spec", {})
        spec["canary"] = {
            "image": image,
            "weight": weight,
            "replicas": replicas,
        }

        result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
        return result

    def promote_canary(self, endpoint_name: str) -> dict[str, Any] | None:
        """Promote the canary to primary, removing the canary deployment.

        The primary image is replaced with the canary image, and the
        canary config is removed. All traffic goes to the new image.

        Returns:
            Updated endpoint record, or None if endpoint not found

        Raises:
            ValueError: If there is no active canary, or the canary record
                is malformed (missing its image).
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None

        spec = endpoint.get("spec", {})
        canary = spec.get("canary")
        if not canary:
            raise ValueError(f"Endpoint '{endpoint_name}' has no active canary deployment")
        if "image" not in canary:
            raise ValueError(
                f"Canary deployment for '{endpoint_name}' is missing the 'image' field"
            )

        # Swap primary image to canary image, then drop the canary config.
        spec["image"] = canary["image"]
        del spec["canary"]

        result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
        return result

    def rollback_canary(self, endpoint_name: str) -> dict[str, Any] | None:
        """Remove the canary deployment, keeping the primary unchanged.

        All traffic returns to the primary deployment.

        Returns:
            Updated endpoint record, or None if endpoint not found

        Raises:
            ValueError: If there is no active canary deployment.
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None

        spec = endpoint.get("spec", {})
        if "canary" not in spec:
            raise ValueError(f"Endpoint '{endpoint_name}' has no active canary deployment")

        del spec["canary"]
        result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
        return result
def get_inference_manager(config: GCOConfig | None = None) -> InferenceManager:
    """Build and return an InferenceManager bound to *config* (or defaults)."""
    manager = InferenceManager(config)
    return manager