Coverage for cli/inference.py: 99%
155 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
1"""
2Inference endpoint management for GCO CLI.
4Provides functionality to deploy, manage, and monitor inference endpoints
5across multi-region EKS clusters via the DynamoDB-backed reconciliation
6pattern (inference_monitor).
7"""
9from __future__ import annotations
11import logging
12from typing import TYPE_CHECKING, Any
14from .aws_client import get_aws_client
15from .config import GCOConfig, get_config
17# <pyflowchart-code-diagram> BEGIN - auto-inserted, do not edit
18# Flowchart(s) generated from this file:
19# * ``InferenceManager.deploy`` -> ``diagrams/code_diagrams/cli/inference.InferenceManager_deploy.html``
20# (PNG: ``diagrams/code_diagrams/cli/inference.InferenceManager_deploy.png``)
21# * ``InferenceManager.canary_deploy`` -> ``diagrams/code_diagrams/cli/inference.InferenceManager_canary_deploy.html``
22# (PNG: ``diagrams/code_diagrams/cli/inference.InferenceManager_canary_deploy.png``)
23# Regenerate with ``python diagrams/code_diagrams/generate.py``.
24# <pyflowchart-code-diagram> END
if TYPE_CHECKING:
    # Only needed for annotations; the runtime import is done lazily inside
    # InferenceManager._get_store.
    from gco.services.inference_store import InferenceEndpointStore

# Module-level logger; used for best-effort DynamoDB update failures.
logger = logging.getLogger(__name__)
33class InferenceManager:
34 """Manages inference endpoints via the DynamoDB store."""
36 def __init__(self, config: GCOConfig | None = None):
37 self.config = config or get_config()
38 self._aws_client = get_aws_client(config)
40 def _get_store(self, region: str | None = None) -> InferenceEndpointStore:
41 """Get an InferenceEndpointStore for the global region."""
42 from gco.services.inference_store import InferenceEndpointStore
44 # Use the global region for DynamoDB (same as job store)
45 store_region = region or self.config.global_region
46 return InferenceEndpointStore(region=store_region)
48 def deploy(
49 self,
50 endpoint_name: str,
51 image: str,
52 target_regions: list[str] | None = None,
53 replicas: int = 1,
54 gpu_count: int = 1,
55 gpu_type: str | None = None,
56 port: int = 8000,
57 model_path: str | None = None,
58 model_source: str | None = None,
59 health_check_path: str = "/health",
60 env: dict[str, str] | None = None,
61 namespace: str = "gco-inference",
62 labels: dict[str, str] | None = None,
63 autoscaling: dict[str, Any] | None = None,
64 capacity_type: str | None = None,
65 extra_args: list[str] | None = None,
66 accelerator: str = "nvidia",
67 node_selector: dict[str, str] | None = None,
68 rewrite_image: bool = True,
69 ) -> dict[str, Any]:
70 """
71 Deploy an inference endpoint to one or more regions.
73 The endpoint spec is written to DynamoDB. The inference_monitor
74 in each target region picks it up and creates the K8s resources.
76 Args:
77 endpoint_name: Unique name for the endpoint
78 image: Container image (e.g. vllm/vllm-openai:v0.8.0)
79 target_regions: Regions to deploy to (default: all deployed regions)
80 replicas: Number of replicas per region
81 gpu_count: GPUs per replica
82 gpu_type: GPU instance type hint for node selector
83 port: Container port
84 model_path: EFS path for model weights
85 health_check_path: Health check endpoint path
86 env: Environment variables
87 namespace: Kubernetes namespace
88 labels: Labels for the endpoint
89 rewrite_image: When True (the default), rewrite ECR URIs in
90 ``image`` to target each region's local replica. Non-ECR
91 refs (Docker Hub, GHCR, etc.) are left unchanged. When
92 False, the URI is written verbatim to every region's
93 spec — the operator is responsible for cross-region
94 pulls. Per-region rewrites are stored under a
95 ``region_overrides`` map on the spec keyed by region.
97 Returns:
98 Created endpoint record
99 """
100 if not target_regions:
101 stacks = self._aws_client.discover_regional_stacks()
102 target_regions = list(stacks.keys())
103 if not target_regions:
104 raise ValueError("No deployed regions found. Deploy infrastructure first.")
106 # Per-region image-URI rewrites for ECR refs. Each target region
107 # gets the local replica's URI on its own spec, so the
108 # inference_monitor's pod-spec materialiser pulls in-region
109 # rather than across the WAN. Non-ECR URIs come back unchanged
110 # from the helper, so this is a no-op for Docker Hub / GHCR refs.
111 #
112 # The helper lives in ``cli._image_uri`` rather than ``cli.images``
113 # so this import doesn't create a module-level cycle:
114 # ``cli.images`` itself imports the same helper. ``cli._image_uri``
115 # is a leaf module with no project-side dependencies.
116 region_image_map: dict[str, str] = {}
117 if rewrite_image:
118 from ._image_uri import rewrite_image_uri_for_region
120 for region in target_regions:
121 region_image_map[region] = rewrite_image_uri_for_region(image, region)
123 spec = {
124 "image": image,
125 "port": port,
126 "replicas": replicas,
127 "gpu_count": gpu_count,
128 "health_check_path": health_check_path,
129 }
130 # Preserve the rewrite map on the spec so the inference_monitor
131 # service can pick the right URI per region when materialising
132 # pods. When ``rewrite_image=False`` no map is set and the flat
133 # ``image`` field is the only source.
134 if region_image_map and any(uri != image for uri in region_image_map.values()):
135 spec["region_image_uris"] = region_image_map
136 if gpu_type:
137 spec["gpu_type"] = gpu_type
138 if model_path:
139 spec["model_path"] = model_path
140 if model_source:
141 spec["model_source"] = model_source
142 if env:
143 spec["env"] = env
144 if autoscaling:
145 spec["autoscaling"] = autoscaling
146 if capacity_type:
147 spec["capacity_type"] = capacity_type
148 if extra_args:
149 spec["args"] = extra_args
150 if accelerator != "nvidia":
151 spec["accelerator"] = accelerator
152 if node_selector:
153 spec["node_selector"] = node_selector
155 store = self._get_store()
156 result: dict[str, Any] = store.create_endpoint(
157 endpoint_name=endpoint_name,
158 spec=spec,
159 target_regions=target_regions,
160 namespace=namespace,
161 labels=labels,
162 )
163 return result
165 def list_endpoints(
166 self,
167 desired_state: str | None = None,
168 region: str | None = None,
169 ) -> list[dict[str, Any]]:
170 """List all inference endpoints."""
171 store = self._get_store()
172 result: list[dict[str, Any]] = store.list_endpoints(
173 desired_state=desired_state,
174 target_region=region,
175 )
176 return result
178 def get_endpoint(self, endpoint_name: str) -> dict[str, Any] | None:
179 """Get details of a specific endpoint."""
180 store = self._get_store()
181 result: dict[str, Any] | None = store.get_endpoint(endpoint_name)
182 return result
184 def scale(self, endpoint_name: str, replicas: int) -> dict[str, Any] | None:
185 """Scale an endpoint to a new replica count."""
186 store = self._get_store()
187 result: dict[str, Any] | None = store.scale_endpoint(endpoint_name, replicas)
188 return result
190 def stop(self, endpoint_name: str) -> dict[str, Any] | None:
191 """Stop an endpoint (scale to zero, keep resources)."""
192 store = self._get_store()
193 result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "stopped")
194 return result
196 def start(self, endpoint_name: str) -> dict[str, Any] | None:
197 """Start a stopped endpoint."""
198 store = self._get_store()
199 result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "running")
200 return result
202 def delete(self, endpoint_name: str) -> dict[str, Any] | None:
203 """Mark an endpoint for deletion (inference_monitor cleans up)."""
204 store = self._get_store()
205 result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "deleted")
206 return result
208 def update_image(self, endpoint_name: str, image: str) -> dict[str, Any] | None:
209 """Update the container image for an endpoint."""
210 store = self._get_store()
211 endpoint = store.get_endpoint(endpoint_name)
212 if not endpoint:
213 return None
214 spec = endpoint.get("spec", {})
215 spec["image"] = image
216 result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
217 return result
219 def add_region(self, endpoint_name: str, region: str) -> dict[str, Any] | None:
220 """Add a region to an existing endpoint."""
221 from datetime import UTC, datetime
223 store = self._get_store()
224 endpoint = store.get_endpoint(endpoint_name)
225 if not endpoint:
226 return None
227 regions = endpoint.get("target_regions", [])
228 if region not in regions:
229 regions.append(region)
230 # Update via raw DynamoDB update
231 try:
232 response = store._table.update_item(
233 Key={"endpoint_name": endpoint_name},
234 UpdateExpression="SET target_regions = :r, updated_at = :u",
235 ExpressionAttributeValues={
236 ":r": regions,
237 ":u": datetime.now(UTC).isoformat(),
238 },
239 ReturnValues="ALL_NEW",
240 )
241 result: dict[str, Any] | None = response.get("Attributes")
242 return result
243 except Exception as e:
244 logger.error("Failed to add region: %s", e)
245 return None
247 def remove_region(self, endpoint_name: str, region: str) -> dict[str, Any] | None:
248 """Remove a region from an existing endpoint."""
249 store = self._get_store()
250 endpoint = store.get_endpoint(endpoint_name)
251 if not endpoint:
252 return None
253 regions = endpoint.get("target_regions", [])
254 if region in regions:
255 regions.remove(region)
256 try:
257 from datetime import UTC, datetime
259 response = store._table.update_item(
260 Key={"endpoint_name": endpoint_name},
261 UpdateExpression="SET target_regions = :r, updated_at = :u",
262 ExpressionAttributeValues={
263 ":r": regions,
264 ":u": datetime.now(UTC).isoformat(),
265 },
266 ReturnValues="ALL_NEW",
267 )
268 result: dict[str, Any] | None = response.get("Attributes")
269 return result
270 except Exception as e:
271 logger.error("Failed to remove region: %s", e)
272 return None
274 def canary_deploy(
275 self,
276 endpoint_name: str,
277 image: str,
278 weight: int = 10,
279 replicas: int = 1,
280 ) -> dict[str, Any] | None:
281 """Start a canary deployment for an existing endpoint.
283 Creates a canary variant with the new image receiving `weight`%
284 of traffic. The primary deployment continues serving the rest.
286 Args:
287 endpoint_name: Existing endpoint to canary
288 image: New container image for the canary
289 weight: Percentage of traffic to route to canary (1-99)
290 replicas: Number of canary replicas
292 Returns:
293 Updated endpoint record, or None if endpoint not found
294 """
295 if not 1 <= weight <= 99:
296 raise ValueError("Canary weight must be between 1 and 99")
298 store = self._get_store()
299 endpoint = store.get_endpoint(endpoint_name)
300 if not endpoint:
301 return None
303 if endpoint.get("desired_state") not in ("running", "deploying"):
304 raise ValueError(
305 f"Cannot canary an endpoint in '{endpoint.get('desired_state')}' state. "
306 "Endpoint must be running."
307 )
309 spec = endpoint.get("spec", {})
310 spec["canary"] = {
311 "image": image,
312 "weight": weight,
313 "replicas": replicas,
314 }
316 result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
317 return result
319 def promote_canary(self, endpoint_name: str) -> dict[str, Any] | None:
320 """Promote the canary to primary, removing the canary deployment.
322 The primary image is replaced with the canary image, and the
323 canary config is removed. All traffic goes to the new image.
325 Returns:
326 Updated endpoint record, or None if endpoint not found
327 """
328 store = self._get_store()
329 endpoint = store.get_endpoint(endpoint_name)
330 if not endpoint:
331 return None
333 spec = endpoint.get("spec", {})
334 canary = spec.get("canary")
335 if not canary:
336 raise ValueError(f"Endpoint '{endpoint_name}' has no active canary deployment")
338 if "image" not in canary:
339 raise ValueError(
340 f"Canary deployment for '{endpoint_name}' is missing the 'image' field"
341 )
343 # Swap primary image to canary image
344 spec["image"] = canary["image"]
345 # Remove canary config
346 del spec["canary"]
348 result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
349 return result
351 def rollback_canary(self, endpoint_name: str) -> dict[str, Any] | None:
352 """Remove the canary deployment, keeping the primary unchanged.
354 All traffic returns to the primary deployment.
356 Returns:
357 Updated endpoint record, or None if endpoint not found
358 """
359 store = self._get_store()
360 endpoint = store.get_endpoint(endpoint_name)
361 if not endpoint:
362 return None
364 spec = endpoint.get("spec", {})
365 if "canary" not in spec:
366 raise ValueError(f"Endpoint '{endpoint_name}' has no active canary deployment")
368 del spec["canary"]
369 result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
370 return result
def get_inference_manager(config: GCOConfig | None = None) -> InferenceManager:
    """Construct an :class:`InferenceManager`.

    Args:
        config: Optional CLI configuration, passed straight through to the
            manager's constructor.

    Returns:
        A newly created manager instance.
    """
    manager = InferenceManager(config)
    return manager