Coverage for cli/inference.py: 99%

155 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 15:07 +0000

1""" 

2Inference endpoint management for GCO CLI. 

3 

4Provides functionality to deploy, manage, and monitor inference endpoints 

5across multi-region EKS clusters via the DynamoDB-backed reconciliation 

6pattern (inference_monitor). 

7""" 

8 

9from __future__ import annotations 

10 

11import logging 

12from typing import TYPE_CHECKING, Any 

13 

14from .aws_client import get_aws_client 

15from .config import GCOConfig, get_config 

16 

17# <pyflowchart-code-diagram> BEGIN - auto-inserted, do not edit 

18# Flowchart(s) generated from this file: 

19# * ``InferenceManager.deploy`` -> ``diagrams/code_diagrams/cli/inference.InferenceManager_deploy.html`` 

20# (PNG: ``diagrams/code_diagrams/cli/inference.InferenceManager_deploy.png``) 

21# * ``InferenceManager.canary_deploy`` -> ``diagrams/code_diagrams/cli/inference.InferenceManager_canary_deploy.html`` 

22# (PNG: ``diagrams/code_diagrams/cli/inference.InferenceManager_canary_deploy.png``) 

23# Regenerate with ``python diagrams/code_diagrams/generate.py``. 

24# <pyflowchart-code-diagram> END 

25 

26 

27if TYPE_CHECKING: 

28 from gco.services.inference_store import InferenceEndpointStore 

29 

# Module-level logger named after this module; used to report failed
# DynamoDB writes.
logger = logging.getLogger(name=__name__)

31 

32 

33class InferenceManager: 

34 """Manages inference endpoints via the DynamoDB store.""" 

35 

36 def __init__(self, config: GCOConfig | None = None): 

37 self.config = config or get_config() 

38 self._aws_client = get_aws_client(config) 

39 

40 def _get_store(self, region: str | None = None) -> InferenceEndpointStore: 

41 """Get an InferenceEndpointStore for the global region.""" 

42 from gco.services.inference_store import InferenceEndpointStore 

43 

44 # Use the global region for DynamoDB (same as job store) 

45 store_region = region or self.config.global_region 

46 return InferenceEndpointStore(region=store_region) 

47 

48 def deploy( 

49 self, 

50 endpoint_name: str, 

51 image: str, 

52 target_regions: list[str] | None = None, 

53 replicas: int = 1, 

54 gpu_count: int = 1, 

55 gpu_type: str | None = None, 

56 port: int = 8000, 

57 model_path: str | None = None, 

58 model_source: str | None = None, 

59 health_check_path: str = "/health", 

60 env: dict[str, str] | None = None, 

61 namespace: str = "gco-inference", 

62 labels: dict[str, str] | None = None, 

63 autoscaling: dict[str, Any] | None = None, 

64 capacity_type: str | None = None, 

65 extra_args: list[str] | None = None, 

66 accelerator: str = "nvidia", 

67 node_selector: dict[str, str] | None = None, 

68 rewrite_image: bool = True, 

69 ) -> dict[str, Any]: 

70 """ 

71 Deploy an inference endpoint to one or more regions. 

72 

73 The endpoint spec is written to DynamoDB. The inference_monitor 

74 in each target region picks it up and creates the K8s resources. 

75 

76 Args: 

77 endpoint_name: Unique name for the endpoint 

78 image: Container image (e.g. vllm/vllm-openai:v0.8.0) 

79 target_regions: Regions to deploy to (default: all deployed regions) 

80 replicas: Number of replicas per region 

81 gpu_count: GPUs per replica 

82 gpu_type: GPU instance type hint for node selector 

83 port: Container port 

84 model_path: EFS path for model weights 

85 health_check_path: Health check endpoint path 

86 env: Environment variables 

87 namespace: Kubernetes namespace 

88 labels: Labels for the endpoint 

89 rewrite_image: When True (the default), rewrite ECR URIs in 

90 ``image`` to target each region's local replica. Non-ECR 

91 refs (Docker Hub, GHCR, etc.) are left unchanged. When 

92 False, the URI is written verbatim to every region's 

93 spec — the operator is responsible for cross-region 

94 pulls. Per-region rewrites are stored under a 

95 ``region_overrides`` map on the spec keyed by region. 

96 

97 Returns: 

98 Created endpoint record 

99 """ 

100 if not target_regions: 

101 stacks = self._aws_client.discover_regional_stacks() 

102 target_regions = list(stacks.keys()) 

103 if not target_regions: 

104 raise ValueError("No deployed regions found. Deploy infrastructure first.") 

105 

106 # Per-region image-URI rewrites for ECR refs. Each target region 

107 # gets the local replica's URI on its own spec, so the 

108 # inference_monitor's pod-spec materialiser pulls in-region 

109 # rather than across the WAN. Non-ECR URIs come back unchanged 

110 # from the helper, so this is a no-op for Docker Hub / GHCR refs. 

111 # 

112 # The helper lives in ``cli._image_uri`` rather than ``cli.images`` 

113 # so this import doesn't create a module-level cycle: 

114 # ``cli.images`` itself imports the same helper. ``cli._image_uri`` 

115 # is a leaf module with no project-side dependencies. 

116 region_image_map: dict[str, str] = {} 

117 if rewrite_image: 

118 from ._image_uri import rewrite_image_uri_for_region 

119 

120 for region in target_regions: 

121 region_image_map[region] = rewrite_image_uri_for_region(image, region) 

122 

123 spec = { 

124 "image": image, 

125 "port": port, 

126 "replicas": replicas, 

127 "gpu_count": gpu_count, 

128 "health_check_path": health_check_path, 

129 } 

130 # Preserve the rewrite map on the spec so the inference_monitor 

131 # service can pick the right URI per region when materialising 

132 # pods. When ``rewrite_image=False`` no map is set and the flat 

133 # ``image`` field is the only source. 

134 if region_image_map and any(uri != image for uri in region_image_map.values()): 

135 spec["region_image_uris"] = region_image_map 

136 if gpu_type: 

137 spec["gpu_type"] = gpu_type 

138 if model_path: 

139 spec["model_path"] = model_path 

140 if model_source: 

141 spec["model_source"] = model_source 

142 if env: 

143 spec["env"] = env 

144 if autoscaling: 

145 spec["autoscaling"] = autoscaling 

146 if capacity_type: 

147 spec["capacity_type"] = capacity_type 

148 if extra_args: 

149 spec["args"] = extra_args 

150 if accelerator != "nvidia": 

151 spec["accelerator"] = accelerator 

152 if node_selector: 

153 spec["node_selector"] = node_selector 

154 

155 store = self._get_store() 

156 result: dict[str, Any] = store.create_endpoint( 

157 endpoint_name=endpoint_name, 

158 spec=spec, 

159 target_regions=target_regions, 

160 namespace=namespace, 

161 labels=labels, 

162 ) 

163 return result 

164 

165 def list_endpoints( 

166 self, 

167 desired_state: str | None = None, 

168 region: str | None = None, 

169 ) -> list[dict[str, Any]]: 

170 """List all inference endpoints.""" 

171 store = self._get_store() 

172 result: list[dict[str, Any]] = store.list_endpoints( 

173 desired_state=desired_state, 

174 target_region=region, 

175 ) 

176 return result 

177 

178 def get_endpoint(self, endpoint_name: str) -> dict[str, Any] | None: 

179 """Get details of a specific endpoint.""" 

180 store = self._get_store() 

181 result: dict[str, Any] | None = store.get_endpoint(endpoint_name) 

182 return result 

183 

184 def scale(self, endpoint_name: str, replicas: int) -> dict[str, Any] | None: 

185 """Scale an endpoint to a new replica count.""" 

186 store = self._get_store() 

187 result: dict[str, Any] | None = store.scale_endpoint(endpoint_name, replicas) 

188 return result 

189 

190 def stop(self, endpoint_name: str) -> dict[str, Any] | None: 

191 """Stop an endpoint (scale to zero, keep resources).""" 

192 store = self._get_store() 

193 result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "stopped") 

194 return result 

195 

196 def start(self, endpoint_name: str) -> dict[str, Any] | None: 

197 """Start a stopped endpoint.""" 

198 store = self._get_store() 

199 result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "running") 

200 return result 

201 

202 def delete(self, endpoint_name: str) -> dict[str, Any] | None: 

203 """Mark an endpoint for deletion (inference_monitor cleans up).""" 

204 store = self._get_store() 

205 result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "deleted") 

206 return result 

207 

208 def update_image(self, endpoint_name: str, image: str) -> dict[str, Any] | None: 

209 """Update the container image for an endpoint.""" 

210 store = self._get_store() 

211 endpoint = store.get_endpoint(endpoint_name) 

212 if not endpoint: 

213 return None 

214 spec = endpoint.get("spec", {}) 

215 spec["image"] = image 

216 result: dict[str, Any] | None = store.update_spec(endpoint_name, spec) 

217 return result 

218 

219 def add_region(self, endpoint_name: str, region: str) -> dict[str, Any] | None: 

220 """Add a region to an existing endpoint.""" 

221 from datetime import UTC, datetime 

222 

223 store = self._get_store() 

224 endpoint = store.get_endpoint(endpoint_name) 

225 if not endpoint: 

226 return None 

227 regions = endpoint.get("target_regions", []) 

228 if region not in regions: 

229 regions.append(region) 

230 # Update via raw DynamoDB update 

231 try: 

232 response = store._table.update_item( 

233 Key={"endpoint_name": endpoint_name}, 

234 UpdateExpression="SET target_regions = :r, updated_at = :u", 

235 ExpressionAttributeValues={ 

236 ":r": regions, 

237 ":u": datetime.now(UTC).isoformat(), 

238 }, 

239 ReturnValues="ALL_NEW", 

240 ) 

241 result: dict[str, Any] | None = response.get("Attributes") 

242 return result 

243 except Exception as e: 

244 logger.error("Failed to add region: %s", e) 

245 return None 

246 

247 def remove_region(self, endpoint_name: str, region: str) -> dict[str, Any] | None: 

248 """Remove a region from an existing endpoint.""" 

249 store = self._get_store() 

250 endpoint = store.get_endpoint(endpoint_name) 

251 if not endpoint: 

252 return None 

253 regions = endpoint.get("target_regions", []) 

254 if region in regions: 

255 regions.remove(region) 

256 try: 

257 from datetime import UTC, datetime 

258 

259 response = store._table.update_item( 

260 Key={"endpoint_name": endpoint_name}, 

261 UpdateExpression="SET target_regions = :r, updated_at = :u", 

262 ExpressionAttributeValues={ 

263 ":r": regions, 

264 ":u": datetime.now(UTC).isoformat(), 

265 }, 

266 ReturnValues="ALL_NEW", 

267 ) 

268 result: dict[str, Any] | None = response.get("Attributes") 

269 return result 

270 except Exception as e: 

271 logger.error("Failed to remove region: %s", e) 

272 return None 

273 

274 def canary_deploy( 

275 self, 

276 endpoint_name: str, 

277 image: str, 

278 weight: int = 10, 

279 replicas: int = 1, 

280 ) -> dict[str, Any] | None: 

281 """Start a canary deployment for an existing endpoint. 

282 

283 Creates a canary variant with the new image receiving `weight`% 

284 of traffic. The primary deployment continues serving the rest. 

285 

286 Args: 

287 endpoint_name: Existing endpoint to canary 

288 image: New container image for the canary 

289 weight: Percentage of traffic to route to canary (1-99) 

290 replicas: Number of canary replicas 

291 

292 Returns: 

293 Updated endpoint record, or None if endpoint not found 

294 """ 

295 if not 1 <= weight <= 99: 

296 raise ValueError("Canary weight must be between 1 and 99") 

297 

298 store = self._get_store() 

299 endpoint = store.get_endpoint(endpoint_name) 

300 if not endpoint: 

301 return None 

302 

303 if endpoint.get("desired_state") not in ("running", "deploying"): 

304 raise ValueError( 

305 f"Cannot canary an endpoint in '{endpoint.get('desired_state')}' state. " 

306 "Endpoint must be running." 

307 ) 

308 

309 spec = endpoint.get("spec", {}) 

310 spec["canary"] = { 

311 "image": image, 

312 "weight": weight, 

313 "replicas": replicas, 

314 } 

315 

316 result: dict[str, Any] | None = store.update_spec(endpoint_name, spec) 

317 return result 

318 

319 def promote_canary(self, endpoint_name: str) -> dict[str, Any] | None: 

320 """Promote the canary to primary, removing the canary deployment. 

321 

322 The primary image is replaced with the canary image, and the 

323 canary config is removed. All traffic goes to the new image. 

324 

325 Returns: 

326 Updated endpoint record, or None if endpoint not found 

327 """ 

328 store = self._get_store() 

329 endpoint = store.get_endpoint(endpoint_name) 

330 if not endpoint: 

331 return None 

332 

333 spec = endpoint.get("spec", {}) 

334 canary = spec.get("canary") 

335 if not canary: 

336 raise ValueError(f"Endpoint '{endpoint_name}' has no active canary deployment") 

337 

338 if "image" not in canary: 

339 raise ValueError( 

340 f"Canary deployment for '{endpoint_name}' is missing the 'image' field" 

341 ) 

342 

343 # Swap primary image to canary image 

344 spec["image"] = canary["image"] 

345 # Remove canary config 

346 del spec["canary"] 

347 

348 result: dict[str, Any] | None = store.update_spec(endpoint_name, spec) 

349 return result 

350 

351 def rollback_canary(self, endpoint_name: str) -> dict[str, Any] | None: 

352 """Remove the canary deployment, keeping the primary unchanged. 

353 

354 All traffic returns to the primary deployment. 

355 

356 Returns: 

357 Updated endpoint record, or None if endpoint not found 

358 """ 

359 store = self._get_store() 

360 endpoint = store.get_endpoint(endpoint_name) 

361 if not endpoint: 

362 return None 

363 

364 spec = endpoint.get("spec", {}) 

365 if "canary" not in spec: 

366 raise ValueError(f"Endpoint '{endpoint_name}' has no active canary deployment") 

367 

368 del spec["canary"] 

369 result: dict[str, Any] | None = store.update_spec(endpoint_name, spec) 

370 return result 

371 

372 

def get_inference_manager(config: GCOConfig | None = None) -> InferenceManager:
    """Build and return a new :class:`InferenceManager`.

    Args:
        config: Optional CLI configuration handed straight to the manager.
            When omitted, the manager resolves its own default.
    """
    manager = InferenceManager(config=config)
    return manager