Coverage for gco / services / inference_monitor.py: 87%

518 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 21:47 +0000

1""" 

2Inference Monitor — reconciliation controller for inference endpoints. 

3 

4Runs in each regional EKS cluster and polls the global DynamoDB table 

5(gco-inference-endpoints) to reconcile desired state with actual 

6Kubernetes resources. Follows a GitOps-style reconciliation pattern: 

7 

8 DynamoDB (desired state) → inference_monitor → Kubernetes (actual state) 

9 

10The monitor: 

11- Creates Deployments, Services, and Ingress rules for new endpoints 

12- Updates existing deployments when spec changes 

13- Scales deployments up/down 

14- Tears down resources when endpoints are deleted 

15- Reports per-region status back to DynamoDB 

16 

17Environment Variables: 

18 CLUSTER_NAME: Name of the EKS cluster 

19 REGION: AWS region this monitor runs in 

20 INFERENCE_ENDPOINTS_TABLE_NAME: DynamoDB table name 

21 RECONCILE_INTERVAL_SECONDS: Seconds between reconciliation loops (default: 15) 

22 INFERENCE_NAMESPACE: Namespace for inference workloads (default: gco-inference) 

23""" 

24 

25import asyncio 

26import logging 

27import os 

28from datetime import UTC, datetime 

29from typing import Any 

30 

31from kubernetes import client, config 

32from kubernetes.client.models import V1Deployment 

33from kubernetes.client.rest import ApiException 

34 

35from gco.services.inference_store import InferenceEndpointStore 

36from gco.services.structured_logging import configure_structured_logging 

37 

# Process-wide logging configuration.
# NOTE(review): configure_structured_logging is imported above but not called
# here — presumably the entrypoint applies the structured variant; confirm
# before changing this plain basicConfig setup.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

43 

44 

45class InferenceMonitor: 

46 """ 

47 Reconciliation controller for inference endpoints. 

48 

49 Polls DynamoDB for desired endpoint state and reconciles with 

50 the actual Kubernetes resources in the local cluster. 

51 """ 

52 

    def __init__(
        self,
        cluster_id: str,
        region: str,
        store: InferenceEndpointStore,
        namespace: str = "gco-inference",
        reconcile_interval: int = 15,
    ):
        """Initialize the monitor and its Kubernetes API clients.

        Args:
            cluster_id: Name of the EKS cluster this monitor runs in.
            region: AWS region identifier used when reporting per-region status.
            store: DynamoDB-backed store holding desired endpoint state.
            namespace: Namespace for inference workloads.
            reconcile_interval: Seconds between reconciliation cycles.

        Raises:
            kubernetes.config.ConfigException: if neither in-cluster nor
                local kubeconfig can be loaded.
        """
        self.cluster_id = cluster_id
        self.region = region
        self.store = store
        self.namespace = namespace
        self.reconcile_interval = reconcile_interval
        # True while the reconciliation loop in start() should keep running.
        self._running = False

        # Initialize Kubernetes clients: prefer in-cluster config (service
        # account token), falling back to local kubeconfig for development.
        try:
            config.load_incluster_config()
            logger.info("Loaded in-cluster Kubernetes configuration")
        except config.ConfigException:
            try:
                config.load_kube_config()
                logger.info("Loaded local Kubernetes configuration")
            except config.ConfigException as e:
                logger.error("Failed to load Kubernetes configuration: %s", e)
                raise

        self.apps_v1 = client.AppsV1Api()
        self.core_v1 = client.CoreV1Api()
        self.networking_v1 = client.NetworkingV1Api()

        # Timeout for Kubernetes API calls (seconds)
        self._k8s_timeout = int(os.environ.get("K8S_API_TIMEOUT", "30"))

        # Health watchdog: tracks when each endpoint first became unready.
        # If an endpoint stays unready for longer than _ingress_removal_threshold,
        # the watchdog removes its Ingress to protect the shared ALB from
        # having an unhealthy target group (which would make GA mark the
        # entire ALB as unhealthy, blocking all inference in the region).
        self._unready_since: dict[str, datetime] = {}
        self._ingress_removal_threshold = int(
            os.environ.get("INFERENCE_UNHEALTHY_THRESHOLD_SECONDS", "300")
        )  # 5 minutes default

        # Metrics: simple counters incremented by the reconciliation loop.
        self._reconcile_count = 0
        self._errors_count = 0

100 

101 # ------------------------------------------------------------------ 

102 # Reconciliation loop 

103 # ------------------------------------------------------------------ 

104 

105 async def start(self) -> None: 

106 """Start the reconciliation loop with leader election. 

107 

108 Uses a Kubernetes Lease object for leader election so that only 

109 one replica reconciles at a time. Other replicas stay on standby 

110 and take over if the leader dies. 

111 """ 

112 if self._running: 

113 logger.warning("Inference monitor already running") 

114 return 

115 self._running = True 

116 logger.info( 

117 "Starting inference monitor for %s in %s (interval=%ds)", 

118 self.cluster_id, 

119 self.region, 

120 self.reconcile_interval, 

121 ) 

122 

123 # Namespace and ServiceAccount are pre-created by the kubectl-applier 

124 # at deploy time (00-namespaces.yaml, 01-serviceaccounts.yaml). The 

125 # inference-monitor SA has namespace-scoped RBAC only — it cannot 

126 # read_namespace/create_namespace, so we don't try. If the namespace 

127 # is ever missing, deployments below will fail with a clear 404. 

128 

129 # Get pod identity for leader election 

130 pod_name = os.environ.get("HOSTNAME", f"monitor-{id(self)}") 

131 lease_name = "inference-monitor-leader" 

132 

133 while self._running: 

134 try: 

135 if self._try_acquire_lease(lease_name, pod_name): 135 ↛ 138line 135 didn't jump to line 138 because the condition on line 135 was always true

136 await self.reconcile() 

137 else: 

138 logger.debug("Not the leader, waiting...") 

139 except Exception as e: 

140 logger.error("Reconciliation error: %s", e, exc_info=True) 

141 self._errors_count += 1 

142 try: 

143 await asyncio.sleep(self.reconcile_interval) 

144 except Exception as e: 

145 logger.error("Sleep interrupted: %s", e) 

146 break 

147 

    def _try_acquire_lease(self, lease_name: str, holder: str) -> bool:
        """Try to acquire or renew a Kubernetes Lease for leader election.

        Uses optimistic concurrency via resourceVersion — if two monitors
        race to update the same lease, K8s returns 409 Conflict for the
        loser, preventing split-brain.

        Args:
            lease_name: Name of the Lease object in self.namespace.
            holder: Identity string for this instance (typically the pod name).

        Returns True if this instance is the leader.
        """

        coordination_v1 = client.CoordinationV1Api()
        now = datetime.now(UTC)

        try:
            lease = coordination_v1.read_namespaced_lease(lease_name, self.namespace)
            current_holder = lease.spec.holder_identity
            renew_time = lease.spec.renew_time

            # Check if lease is expired (holder hasn't renewed in 3x interval)
            if renew_time:
                # NOTE(review): replace(tzinfo=UTC) assumes renew_time is naive
                # or already in UTC — confirm against the kubernetes client's
                # datetime parsing before changing this.
                elapsed = (now - renew_time.replace(tzinfo=UTC)).total_seconds()
                if elapsed > self.reconcile_interval * 3:
                    # Lease expired — take over
                    logger.info("Lease expired (held by %s), taking over", current_holder)
                    current_holder = None

            if current_holder == holder:
                # We're the leader — renew
                lease.spec.renew_time = now
                try:
                    coordination_v1.replace_namespaced_lease(lease_name, self.namespace, lease)
                except ApiException as conflict:
                    if conflict.status == 409:
                        logger.debug("Lease renew conflict (another writer), retrying next cycle")
                        return False
                    raise
                return True
            if current_holder is None or current_holder == "":
                # No leader — claim it
                lease.spec.holder_identity = holder
                lease.spec.renew_time = now
                try:
                    coordination_v1.replace_namespaced_lease(lease_name, self.namespace, lease)
                except ApiException as conflict:
                    if conflict.status == 409:
                        logger.info("Lost lease race to another monitor")
                        return False
                    raise
                logger.info("Acquired leader lease as %s", holder)
                return True
            # Someone else is the leader
            return False

        except ApiException as e:
            if e.status == 404:
                # Lease doesn't exist — create it
                lease = client.V1Lease(
                    metadata=client.V1ObjectMeta(
                        name=lease_name,
                        namespace=self.namespace,
                    ),
                    spec=client.V1LeaseSpec(
                        holder_identity=holder,
                        lease_duration_seconds=self.reconcile_interval * 3,
                        renew_time=now,
                    ),
                )
                try:
                    coordination_v1.create_namespaced_lease(self.namespace, lease)
                    logger.info("Created leader lease as %s", holder)
                    return True
                except ApiException:
                    # Lost the create race to another monitor.
                    return False
            logger.warning("Lease check failed: %s", e.reason)
            return False

223 

    def stop(self) -> None:
        """Stop the reconciliation loop.

        Only flips the running flag; an in-flight cycle in start() finishes
        and the loop exits after its current sleep.
        """
        self._running = False
        logger.info("Inference monitor stopped")

228 

229 async def reconcile(self) -> list[dict[str, Any]]: 

230 """ 

231 Run one reconciliation cycle. 

232 

233 Returns a list of actions taken (for logging/testing). 

234 """ 

235 self._reconcile_count += 1 

236 actions: list[dict[str, Any]] = [] 

237 

238 # Get all endpoints from DynamoDB 

239 try: 

240 endpoints = self.store.list_endpoints() 

241 except Exception as e: 

242 logger.error("Failed to list endpoints from DynamoDB: %s", e) 

243 return actions 

244 

245 for endpoint in endpoints: 

246 try: 

247 action = await self._reconcile_endpoint(endpoint) 

248 if action: 248 ↛ 245line 248 didn't jump to line 245 because the condition on line 248 was always true

249 actions.append(action) 

250 except Exception as e: 

251 name = endpoint.get("endpoint_name", "unknown") 

252 logger.error("Failed to reconcile endpoint %s: %s", name, e) 

253 self._errors_count += 1 

254 self.store.update_region_status( 

255 name, 

256 self.region, 

257 "error", 

258 error=str(e), 

259 ) 

260 

261 # Purge fully-deleted endpoints from DynamoDB to prevent unbounded growth. 

262 # An endpoint is fully deleted when desired_state is "deleted" and all 

263 # target regions report "deleted" status. 

264 for endpoint in endpoints: 

265 if endpoint.get("desired_state") != "deleted": 265 ↛ 267line 265 didn't jump to line 267 because the condition on line 265 was always true

266 continue 

267 region_status = endpoint.get("region_status", {}) 

268 target_regions = endpoint.get("target_regions", []) 

269 if not target_regions: 

270 continue 

271 all_deleted = all( 

272 isinstance(region_status.get(r), dict) 

273 and region_status.get(r, {}).get("state") == "deleted" 

274 for r in target_regions 

275 ) 

276 if all_deleted: 

277 ep_name = endpoint["endpoint_name"] 

278 try: 

279 self.store.delete_endpoint(ep_name) 

280 logger.info("Purged fully-deleted endpoint %s from DynamoDB", ep_name) 

281 actions.append({"action": "purge", "endpoint": ep_name}) 

282 except Exception as e: 

283 logger.warning("Failed to purge endpoint %s: %s", ep_name, e) 

284 

285 return actions 

286 

287 async def _reconcile_endpoint(self, endpoint: dict[str, Any]) -> dict[str, Any] | None: 

288 """Reconcile a single endpoint.""" 

289 name = endpoint["endpoint_name"] 

290 desired_state = endpoint.get("desired_state", "deploying") 

291 target_regions = endpoint.get("target_regions", []) 

292 spec = endpoint.get("spec", {}) 

293 ns = endpoint.get("namespace", self.namespace) 

294 

295 # Am I a target region? 

296 if self.region not in target_regions: 

297 # If I have resources for this endpoint, clean them up 

298 if self._deployment_exists(name, ns): 

299 logger.info( 

300 "Endpoint %s no longer targets %s, cleaning up", 

301 name, 

302 self.region, 

303 ) 

304 self._delete_resources(name, ns) 

305 self.store.update_region_status( 

306 name, 

307 self.region, 

308 "deleted", 

309 ) 

310 return {"action": "cleanup", "endpoint": name, "reason": "region_removed"} 

311 return None 

312 

313 # Reconcile based on desired state 

314 if desired_state in ("deploying", "running"): 

315 return await self._reconcile_running(name, ns, spec, endpoint) 

316 if desired_state == "stopped": 

317 return self._reconcile_stopped(name, ns) 

318 if desired_state == "deleted": 

319 return self._reconcile_deleted(name, ns) 

320 

321 return None 

322 

    async def _reconcile_running(
        self,
        name: str,
        namespace: str,
        spec: dict[str, Any],
        endpoint: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Ensure the endpoint is running with the correct spec.

        Handles, in order: initial creation of all resources, repair of a
        missing Service/Ingress, the health watchdog, replica scaling,
        image updates, steady-state status reporting, canary handling, and
        promotion of the endpoint from "deploying" to "running".

        Returns:
            A dict describing the action taken this cycle, or None when the
            endpoint was already in sync (only status was reported).
        """
        deployment = self._get_deployment(name, namespace)

        if deployment is None:
            # Create everything
            logger.info("Creating endpoint %s in %s", name, self.region)
            self._create_deployment(name, namespace, spec)
            self._create_service(name, namespace, spec)
            self._update_ingress_rule(name, namespace, spec, endpoint)
            if spec.get("autoscaling", {}).get("enabled"):
                self._create_or_update_hpa(name, namespace, spec)
            self.store.update_region_status(
                name,
                self.region,
                "creating",
                replicas_desired=spec.get("replicas", 1),
            )
            return {"action": "create", "endpoint": name}

        # Deployment exists — ensure Service and Ingress also exist
        # (they may have been manually deleted or lost during a rollout)
        self._ensure_service(name, namespace, spec)

        # Check readiness before ensuring Ingress — the health watchdog may
        # remove the Ingress if the endpoint has been unready too long
        desired_replicas = spec.get("replicas", 1)
        current_replicas = deployment.spec.replicas or 1
        ready_replicas = deployment.status.ready_replicas or 0

        ingress_removed = self._check_health_watchdog(
            name, namespace, ready_replicas, desired_replicas, spec, endpoint
        )
        if not ingress_removed:
            self._ensure_ingress(name, namespace, spec, endpoint)

        # Replica-count drift takes priority over image drift; at most one
        # mutation is applied per cycle.
        if current_replicas != desired_replicas:
            logger.info(
                "Scaling endpoint %s: %d → %d replicas",
                name,
                current_replicas,
                desired_replicas,
            )
            self._scale_deployment(name, namespace, desired_replicas)
            self.store.update_region_status(
                name,
                self.region,
                "updating",
                replicas_ready=ready_replicas,
                replicas_desired=desired_replicas,
            )
            return {"action": "scale", "endpoint": name, "replicas": desired_replicas}

        # Check if image changed
        current_image = self._get_deployment_image(deployment)
        desired_image = spec.get("image", "")
        if current_image and desired_image and current_image != desired_image:
            logger.info("Updating endpoint %s image: %s → %s", name, current_image, desired_image)
            self._update_deployment_image(name, namespace, desired_image)
            self.store.update_region_status(
                name,
                self.region,
                "updating",
                replicas_ready=ready_replicas,
                replicas_desired=desired_replicas,
            )
            return {"action": "update_image", "endpoint": name, "image": desired_image}

        # Everything is in sync — report status
        state = "running" if ready_replicas >= desired_replicas else "creating"
        self.store.update_region_status(
            name,
            self.region,
            state,
            replicas_ready=ready_replicas,
            replicas_desired=desired_replicas,
        )

        # Reconcile canary deployment if present
        canary = spec.get("canary")
        if canary:
            self._reconcile_canary(name, namespace, spec, canary, endpoint)
        else:
            # No canary — clean up canary resources if they exist
            self._cleanup_canary(name, namespace)

        # If all replicas are ready and desired_state is "deploying", promote to "running"
        if state == "running" and endpoint.get("desired_state") == "deploying":
            # Check if all target regions are running — effectively, the last
            # region to become ready performs the promotion.
            all_running = True
            for r_status in endpoint.get("region_status", {}).values():
                if isinstance(r_status, dict) and r_status.get("state") != "running":
                    all_running = False
                    break
            if all_running:
                self.store.update_desired_state(name, "running")

        return None

427 

428 def _reconcile_stopped(self, name: str, namespace: str) -> dict[str, Any] | None: 

429 """Scale deployment to zero.""" 

430 deployment = self._get_deployment(name, namespace) 

431 if deployment is None: 

432 return None 

433 

434 current_replicas = deployment.spec.replicas or 0 

435 if current_replicas > 0: 

436 logger.info("Stopping endpoint %s (scaling to 0)", name) 

437 self._scale_deployment(name, namespace, 0) 

438 self.store.update_region_status( 

439 name, 

440 self.region, 

441 "stopped", 

442 replicas_ready=0, 

443 replicas_desired=0, 

444 ) 

445 return {"action": "stop", "endpoint": name} 

446 

447 self.store.update_region_status( 

448 name, 

449 self.region, 

450 "stopped", 

451 replicas_ready=0, 

452 replicas_desired=0, 

453 ) 

454 return None 

455 

456 def _reconcile_deleted(self, name: str, namespace: str) -> dict[str, Any] | None: 

457 """Delete all resources for the endpoint.""" 

458 # Clean up health watchdog tracker 

459 self._unready_since.pop(name, None) 

460 

461 if self._deployment_exists(name, namespace): 

462 logger.info("Deleting endpoint %s from %s", name, self.region) 

463 self._delete_resources(name, namespace) 

464 self.store.update_region_status(name, self.region, "deleted") 

465 return {"action": "delete", "endpoint": name} 

466 

467 self.store.update_region_status(name, self.region, "deleted") 

468 return None 

469 

470 # ------------------------------------------------------------------ 

471 # Kubernetes resource management 

472 # ------------------------------------------------------------------ 

473 

474 def _deployment_exists(self, name: str, namespace: str) -> bool: 

475 try: 

476 self.apps_v1.read_namespaced_deployment( 

477 name, namespace, _request_timeout=self._k8s_timeout 

478 ) 

479 return True 

480 except ApiException as e: 

481 if e.status == 404: 

482 return False 

483 raise 

484 

    def _get_deployment(self, name: str, namespace: str) -> V1Deployment | None:
        """Fetch the Deployment, or None if it does not exist (404).

        Any other ApiException (auth failure, timeout, server error)
        propagates to the caller.
        """
        try:
            return self.apps_v1.read_namespaced_deployment(
                name, namespace, _request_timeout=self._k8s_timeout
            )
        except ApiException as e:
            if e.status == 404:
                return None
            raise

494 

495 def _get_deployment_image(self, deployment: V1Deployment) -> str | None: 

496 """Get the image of the first container in a deployment.""" 

497 containers = deployment.spec.template.spec.containers 

498 if containers: 

499 image: str = containers[0].image 

500 return image 

501 return None 

502 

503 def _create_deployment(self, name: str, namespace: str, spec: dict[str, Any]) -> None: 

504 """Create a Kubernetes Deployment for an inference endpoint.""" 

505 replicas = spec.get("replicas", 1) 

506 image = spec["image"] 

507 port = spec.get("port", 8000) 

508 gpu_count = spec.get("gpu_count", 1) 

509 health_path = spec.get("health_check_path", "/health") 

510 env_vars = spec.get("env", {}) 

511 resources = spec.get("resources", {}) 

512 model_path = spec.get("model_path") 

513 command = spec.get("command") 

514 args = spec.get("args") 

515 

516 # Build container 

517 container_env = [client.V1EnvVar(name=k, value=str(v)) for k, v in env_vars.items()] 

518 

519 # Inject --root-path for servers that support it (vLLM, TGI). 

520 # This tells the server to mount its API at /inference/{name}. 

521 # We append to existing args (from --extra-args) rather than replacing them. 

522 ingress_prefix = f"/inference/{name}" 

523 root_path_images = ("vllm", "text-generation-inference", "tgi") 

524 image_lower = image.lower() 

525 if not command and any(tag in image_lower for tag in root_path_images): 

526 if args: 526 ↛ 528line 526 didn't jump to line 528 because the condition on line 526 was never true

527 # Append --root-path to user-provided args if not already present 

528 if "--root-path" not in args: 

529 args = list(args) + ["--root-path", ingress_prefix] 

530 else: 

531 args = ["--root-path", ingress_prefix] 

532 

533 resource_reqs = client.V1ResourceRequirements( 

534 requests=resources.get("requests", {"cpu": "1", "memory": "4Gi"}), 

535 limits=resources.get("limits", {"cpu": "4", "memory": "16Gi"}), 

536 ) 

537 # Add accelerator resources (GPU or Neuron) 

538 accelerator = spec.get("accelerator", "nvidia") 

539 if gpu_count > 0: 

540 if accelerator == "neuron": 

541 # AWS Trainium/Inferentia — request Neuron devices 

542 if resource_reqs.limits is None: 542 ↛ 543line 542 didn't jump to line 543 because the condition on line 542 was never true

543 resource_reqs.limits = {} 

544 resource_reqs.limits["aws.amazon.com/neuron"] = str(gpu_count) 

545 if resource_reqs.requests is None: 545 ↛ 546line 545 didn't jump to line 546 because the condition on line 545 was never true

546 resource_reqs.requests = {} 

547 resource_reqs.requests["aws.amazon.com/neuron"] = str(gpu_count) 

548 else: 

549 # NVIDIA GPU (default) 

550 if resource_reqs.limits is None: 550 ↛ 551line 550 didn't jump to line 551 because the condition on line 550 was never true

551 resource_reqs.limits = {} 

552 resource_reqs.limits["nvidia.com/gpu"] = str(gpu_count) 

553 if resource_reqs.requests is None: 553 ↛ 554line 553 didn't jump to line 554 because the condition on line 553 was never true

554 resource_reqs.requests = {} 

555 resource_reqs.requests["nvidia.com/gpu"] = str(gpu_count) 

556 

557 volume_mounts = [] 

558 volumes = [] 

559 init_containers = [] 

560 model_source = spec.get("model_source") 

561 

562 if model_path or model_source: 

563 volume_mounts.append( 

564 client.V1VolumeMount( 

565 name="model-storage", 

566 mount_path="/models", 

567 ) 

568 ) 

569 volumes.append( 

570 client.V1Volume( 

571 name="model-storage", 

572 persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource( 

573 claim_name="efs-claim", 

574 ), 

575 ) 

576 ) 

577 

578 # Add init container to sync model from S3 if model_source is set 

579 if model_source and model_source.startswith("s3://"): 

580 model_dest = f"/models/{name}" 

581 init_containers.append( 

582 client.V1Container( 

583 name="model-sync", 

584 image="amazon/aws-cli:latest", 

585 command=["sh", "-c"], 

586 args=[ 

587 f"if [ -d '{model_dest}' ] && [ \"$(ls -A '{model_dest}')\" ]; then " 

588 f"echo 'Model already cached at {model_dest}, skipping sync'; " 

589 f"else echo 'Syncing model from {model_source}...'; " 

590 f"aws s3 sync {model_source} {model_dest} --quiet; " 

591 f"echo 'Model sync complete'; fi" 

592 ], 

593 volume_mounts=[ 

594 client.V1VolumeMount( 

595 name="model-storage", 

596 mount_path="/models", 

597 ) 

598 ], 

599 resources=client.V1ResourceRequirements( 

600 requests={"cpu": "1", "memory": "2Gi"}, 

601 limits={"cpu": "4", "memory": "8Gi"}, 

602 ), 

603 ) 

604 ) 

605 

606 # Probe path depends on whether the server handles the prefix 

607 uses_root_path = args is not None and "--root-path" in args 

608 probe_health = f"{ingress_prefix}{health_path}" if uses_root_path else health_path 

609 

610 container = client.V1Container( 

611 name="inference", 

612 image=image, 

613 ports=[client.V1ContainerPort(container_port=port)], 

614 env=container_env if container_env else None, 

615 resources=resource_reqs, 

616 volume_mounts=volume_mounts if volume_mounts else None, 

617 command=command, 

618 args=args, 

619 liveness_probe=client.V1Probe( 

620 http_get=client.V1HTTPGetAction(path=probe_health, port=port), 

621 initial_delay_seconds=120, 

622 period_seconds=15, 

623 failure_threshold=5, 

624 ), 

625 readiness_probe=client.V1Probe( 

626 http_get=client.V1HTTPGetAction(path=probe_health, port=port), 

627 initial_delay_seconds=30, 

628 period_seconds=10, 

629 ), 

630 ) 

631 

632 # Build tolerations based on accelerator type 

633 if accelerator == "neuron": 

634 tolerations = [ 

635 client.V1Toleration( 

636 key="aws.amazon.com/neuron", 

637 operator="Equal", 

638 value="true", 

639 effect="NoSchedule", 

640 ) 

641 ] 

642 else: 

643 tolerations = [ 

644 client.V1Toleration( 

645 key="nvidia.com/gpu", 

646 operator="Equal", 

647 value="true", 

648 effect="NoSchedule", 

649 ) 

650 ] 

651 

652 # Node selector based on accelerator type 

653 node_selector = spec.get("node_selector", {}) 

654 if gpu_count > 0 and not node_selector: 

655 if accelerator == "neuron": 

656 node_selector = {"accelerator": "neuron"} 

657 else: 

658 node_selector = {"eks.amazonaws.com/instance-gpu-manufacturer": "nvidia"} 

659 

660 # Apply capacity type preference (spot/on-demand) 

661 capacity_type = spec.get("capacity_type") 

662 if capacity_type in ("spot", "on-demand"): 

663 node_selector["karpenter.sh/capacity-type"] = capacity_type 

664 

665 deployment = client.V1Deployment( 

666 metadata=client.V1ObjectMeta( 

667 name=name, 

668 namespace=namespace, 

669 labels={ 

670 "app": name, 

671 "project": "gco", 

672 "gco.io/type": "inference", 

673 }, 

674 ), 

675 spec=client.V1DeploymentSpec( 

676 replicas=replicas, 

677 selector=client.V1LabelSelector( 

678 match_labels={"app": name}, 

679 ), 

680 template=client.V1PodTemplateSpec( 

681 metadata=client.V1ObjectMeta( 

682 labels={ 

683 "app": name, 

684 "project": "gco", 

685 "gco.io/type": "inference", 

686 }, 

687 ), 

688 spec=client.V1PodSpec( 

689 service_account_name="gco-service-account", 

690 containers=[container], 

691 init_containers=init_containers if init_containers else None, 

692 tolerations=tolerations, 

693 node_selector=node_selector if node_selector else None, 

694 volumes=volumes if volumes else None, 

695 ), 

696 ), 

697 ), 

698 ) 

699 

700 self.apps_v1.create_namespaced_deployment( 

701 namespace, deployment, _request_timeout=self._k8s_timeout 

702 ) 

703 logger.info("Created deployment %s/%s", namespace, name) 

704 

705 def _create_service(self, name: str, namespace: str, spec: dict[str, Any]) -> None: 

706 """Create a Kubernetes Service for an inference endpoint.""" 

707 port = spec.get("port", 8000) 

708 

709 service = client.V1Service( 

710 metadata=client.V1ObjectMeta( 

711 name=name, 

712 namespace=namespace, 

713 labels={ 

714 "app": name, 

715 "project": "gco", 

716 "gco.io/type": "inference", 

717 }, 

718 ), 

719 spec=client.V1ServiceSpec( 

720 selector={"app": name}, 

721 ports=[ 

722 client.V1ServicePort( 

723 port=80, 

724 target_port=port, 

725 protocol="TCP", 

726 ) 

727 ], 

728 type="ClusterIP", 

729 ), 

730 ) 

731 

732 try: 

733 self.core_v1.create_namespaced_service( 

734 namespace, service, _request_timeout=self._k8s_timeout 

735 ) 

736 logger.info("Created service %s/%s", namespace, name) 

737 except ApiException as e: 

738 if e.status == 409: 738 ↛ 741line 738 didn't jump to line 741 because the condition on line 738 was always true

739 logger.info("Service %s/%s already exists", namespace, name) 

740 else: 

741 raise 

742 

743 def _ensure_service(self, name: str, namespace: str, spec: dict[str, Any]) -> None: 

744 """Ensure the Service exists, recreating it if missing.""" 

745 try: 

746 self.core_v1.read_namespaced_service( 

747 name, namespace, _request_timeout=self._k8s_timeout 

748 ) 

749 except ApiException as e: 

750 if e.status == 404: 

751 logger.warning("Service %s/%s missing, recreating", namespace, name) 

752 self._create_service(name, namespace, spec) 

753 else: 

754 raise 

755 

756 def _ensure_ingress( 

757 self, 

758 name: str, 

759 namespace: str, 

760 spec: dict[str, Any], 

761 endpoint: dict[str, Any], 

762 ) -> None: 

763 """Ensure the Ingress exists, recreating it if missing.""" 

764 try: 

765 self.networking_v1.read_namespaced_ingress( 

766 f"inference-{name}", namespace, _request_timeout=self._k8s_timeout 

767 ) 

768 except ApiException as e: 

769 if e.status == 404: 

770 logger.warning("Ingress for %s missing, recreating", name) 

771 self._update_ingress_rule(name, namespace, spec, endpoint) 

772 else: 

773 raise 

774 

    def _update_ingress_rule(
        self,
        name: str,
        namespace: str,
        spec: dict[str, Any],
        endpoint: dict[str, Any],
    ) -> None:
        """Create or update an Ingress for the inference endpoint.

        The Ingress is created in the same namespace as the Service and pods.
        IngressClassParams with group.name merges all Ingresses onto a single
        shared ALB regardless of namespace.

        Args:
            name: Endpoint name; the Ingress is named "inference-{name}".
            namespace: Namespace for the Ingress (same as Service and pods).
            spec: Endpoint spec (image, health_check_path, ...).
            endpoint: Full endpoint record; may carry a custom ingress_path.
        """
        ingress_path = endpoint.get("ingress_path", f"/inference/{name}")
        # Servers that receive --root-path (vLLM/TGI images) serve their
        # health endpoint under the ingress prefix; others serve it at root.
        image = spec.get("image", "")
        image_lower = image.lower()
        root_path_images = ("vllm", "text-generation-inference", "tgi")
        uses_root_path = any(tag in image_lower for tag in root_path_images)
        base_health = spec.get("health_check_path", "/health")
        health_path = f"/inference/{name}{base_health}" if uses_root_path else base_health

        ingress = client.V1Ingress(
            metadata=client.V1ObjectMeta(
                name=f"inference-{name}",
                namespace=namespace,
                labels={
                    "app": name,
                    "project": "gco",
                    "gco.io/type": "inference",
                },
                annotations={
                    "alb.ingress.kubernetes.io/healthcheck-path": health_path,
                    "alb.ingress.kubernetes.io/healthcheck-interval-seconds": "15",
                },
            ),
            spec=client.V1IngressSpec(
                ingress_class_name="alb",
                rules=[
                    client.V1IngressRule(
                        http=client.V1HTTPIngressRuleValue(
                            paths=[
                                client.V1HTTPIngressPath(
                                    path=ingress_path,
                                    path_type="Prefix",
                                    backend=client.V1IngressBackend(
                                        service=client.V1IngressServiceBackend(
                                            name=name,
                                            # Targets the Service's port 80.
                                            port=client.V1ServiceBackendPort(
                                                number=80,
                                            ),
                                        ),
                                    ),
                                )
                            ]
                        )
                    )
                ],
            ),
        )

        try:
            self.networking_v1.create_namespaced_ingress(
                namespace, ingress, _request_timeout=self._k8s_timeout
            )
            logger.info("Created ingress for %s at %s", name, ingress_path)
        except ApiException as e:
            if e.status == 409:
                # NOTE(review): patch is a merge — fields absent from the
                # desired Ingress are not removed from the live object;
                # confirm that is acceptable before relying on it.
                self.networking_v1.patch_namespaced_ingress(
                    f"inference-{name}", namespace, ingress, _request_timeout=self._k8s_timeout
                )
                logger.info("Updated ingress for %s", name)
            else:
                raise

848 

    def _check_health_watchdog(
        self,
        name: str,
        namespace: str,
        ready_replicas: int,
        desired_replicas: int,
        spec: dict[str, Any],
        endpoint: dict[str, Any],
    ) -> bool:
        """Health watchdog: remove Ingress for persistently unhealthy endpoints.

        If an endpoint has zero ready replicas for longer than the configured
        threshold, the watchdog removes its Ingress to protect the shared ALB.
        Global Accelerator considers an ALB unhealthy if ANY target group has
        zero healthy targets, so one bad endpoint can block all inference
        traffic to the region.

        When the endpoint recovers (ready_replicas > 0), the Ingress is
        automatically re-created by _ensure_ingress on the next cycle.

        Note: spec and endpoint are currently unused in this method; the
        caller passes them alongside the replica counts.

        Returns:
            True if the Ingress was removed (caller should skip _ensure_ingress).
            False if the endpoint is healthy or still within the grace period.
        """
        if ready_replicas > 0:
            # Endpoint is healthy — clear the tracker
            if name in self._unready_since:
                logger.info(
                    "Endpoint %s recovered, re-enabling Ingress",
                    name,
                )
                del self._unready_since[name]
            return False

        # Endpoint has zero ready replicas
        now = datetime.now(UTC)

        if name not in self._unready_since:
            # First time seeing this endpoint as unready — start the clock
            self._unready_since[name] = now
            logger.warning(
                "Endpoint %s has 0/%d ready replicas, starting health watchdog timer",
                name,
                desired_replicas,
            )
            return False

        # Check how long it's been unready
        unready_duration = (now - self._unready_since[name]).total_seconds()

        if unready_duration < self._ingress_removal_threshold:
            remaining = self._ingress_removal_threshold - unready_duration
            logger.warning(
                "Endpoint %s unready for %ds (removing Ingress in %ds)",
                name,
                int(unready_duration),
                int(remaining),
            )
            return False

        # Threshold exceeded — remove the Ingress to protect the ALB
        ingress_name = f"inference-{name}"
        try:
            self.networking_v1.delete_namespaced_ingress(
                ingress_name, namespace, _request_timeout=self._k8s_timeout
            )
            logger.warning(
                "WATCHDOG: Removed Ingress for unhealthy endpoint %s "
                "(unready for %ds > %ds threshold). "
                "Ingress will be re-created when the endpoint recovers.",
                name,
                int(unready_duration),
                self._ingress_removal_threshold,
            )
        except ApiException as e:
            if e.status == 404:
                logger.debug("Ingress for %s already removed", name)
            else:
                # Deletion failure is logged but True is still returned, so
                # the caller does not immediately recreate the Ingress.
                logger.error("Failed to remove Ingress for %s: %s", name, e)

        return True

930 

931 def _scale_deployment(self, name: str, namespace: str, replicas: int) -> None: 

932 """Scale a deployment to the desired replica count.""" 

933 self.apps_v1.patch_namespaced_deployment( 

934 name, 

935 namespace, 

936 body={"spec": {"replicas": replicas}}, 

937 _request_timeout=self._k8s_timeout, 

938 ) 

939 

940 def _update_deployment_image(self, name: str, namespace: str, image: str) -> None: 

941 """Update the container image of a deployment.""" 

942 self.apps_v1.patch_namespaced_deployment( 

943 name, 

944 namespace, 

945 body={ 

946 "spec": { 

947 "template": {"spec": {"containers": [{"name": "inference", "image": image}]}} 

948 } 

949 }, 

950 _request_timeout=self._k8s_timeout, 

951 ) 

952 

953 def _reconcile_canary( 

954 self, 

955 name: str, 

956 namespace: str, 

957 spec: dict[str, Any], 

958 canary: dict[str, Any], 

959 endpoint: dict[str, Any], 

960 ) -> None: 

961 """Reconcile canary deployment and weighted ingress routing. 

962 

963 Creates a canary deployment and service alongside the primary, 

964 then updates the ingress to use ALB action-based weighted routing. 

965 """ 

966 canary_name = f"{name}-canary" 

967 canary_image = canary.get("image", "") 

968 canary_replicas = canary.get("replicas", 1) 

969 canary_weight = canary.get("weight", 10) 

970 primary_weight = 100 - canary_weight 

971 

972 # Build canary spec (same as primary but with canary image/replicas) 

973 canary_spec = dict(spec) 

974 canary_spec["image"] = canary_image 

975 canary_spec["replicas"] = canary_replicas 

976 # Remove canary field from the canary spec to avoid recursion 

977 canary_spec.pop("canary", None) 

978 

979 # Create or update canary deployment 

980 canary_deployment = self._get_deployment(canary_name, namespace) 

981 if canary_deployment is None: 

982 logger.info("Creating canary deployment %s with image %s", canary_name, canary_image) 

983 self._create_deployment(canary_name, namespace, canary_spec) 

984 self._create_service(canary_name, namespace, canary_spec) 

985 else: 

986 # Update image if changed 

987 current_image = self._get_deployment_image(canary_deployment) 

988 if current_image != canary_image: 

989 self._update_deployment_image(canary_name, namespace, canary_image) 

990 # Update replicas if changed 

991 if (canary_deployment.spec.replicas or 1) != canary_replicas: 

992 self._scale_deployment(canary_name, namespace, canary_replicas) 

993 

994 # Update ingress with weighted routing via ALB actions annotation 

995 self._update_canary_ingress(name, namespace, spec, endpoint, primary_weight, canary_weight) 

996 

997 def _update_canary_ingress( 

998 self, 

999 name: str, 

1000 namespace: str, 

1001 spec: dict[str, Any], 

1002 endpoint: dict[str, Any], 

1003 primary_weight: int, 

1004 canary_weight: int, 

1005 ) -> None: 

1006 """Update ingress with ALB weighted target group routing.""" 

1007 import json as _json 

1008 

1009 ingress_path = endpoint.get("ingress_path", f"/inference/{name}") 

1010 image = spec.get("image", "") 

1011 image_lower = image.lower() 

1012 root_path_images = ("vllm", "text-generation-inference", "tgi") 

1013 uses_root_path = any(tag in image_lower for tag in root_path_images) 

1014 base_health = spec.get("health_check_path", "/health") 

1015 health_path = f"/inference/{name}{base_health}" if uses_root_path else base_health 

1016 

1017 # ALB weighted routing via forward action annotation 

1018 forward_config = _json.dumps( 

1019 { 

1020 "type": "forward", 

1021 "forwardConfig": { 

1022 "targetGroups": [ 

1023 { 

1024 "serviceName": name, 

1025 "servicePort": 80, 

1026 "weight": primary_weight, 

1027 }, 

1028 { 

1029 "serviceName": f"{name}-canary", 

1030 "servicePort": 80, 

1031 "weight": canary_weight, 

1032 }, 

1033 ] 

1034 }, 

1035 } 

1036 ) 

1037 

1038 ingress = client.V1Ingress( 

1039 metadata=client.V1ObjectMeta( 

1040 name=f"inference-{name}", 

1041 namespace=namespace, 

1042 labels={ 

1043 "app": name, 

1044 "project": "gco", 

1045 "gco.io/type": "inference", 

1046 "gco.io/canary": "true", 

1047 }, 

1048 annotations={ 

1049 "alb.ingress.kubernetes.io/healthcheck-path": health_path, 

1050 "alb.ingress.kubernetes.io/healthcheck-interval-seconds": "15", 

1051 "alb.ingress.kubernetes.io/actions.weighted-routing": forward_config, 

1052 }, 

1053 ), 

1054 spec=client.V1IngressSpec( 

1055 ingress_class_name="alb", 

1056 rules=[ 

1057 client.V1IngressRule( 

1058 http=client.V1HTTPIngressRuleValue( 

1059 paths=[ 

1060 client.V1HTTPIngressPath( 

1061 path=ingress_path, 

1062 path_type="Prefix", 

1063 backend=client.V1IngressBackend( 

1064 service=client.V1IngressServiceBackend( 

1065 name="weighted-routing", 

1066 port=client.V1ServiceBackendPort( 

1067 name="use-annotation", 

1068 ), 

1069 ), 

1070 ), 

1071 ) 

1072 ] 

1073 ) 

1074 ) 

1075 ], 

1076 ), 

1077 ) 

1078 

1079 try: 

1080 self.networking_v1.patch_namespaced_ingress( 

1081 f"inference-{name}", namespace, ingress, _request_timeout=self._k8s_timeout 

1082 ) 

1083 logger.info( 

1084 "Updated ingress for %s: primary=%d%% canary=%d%%", 

1085 name, 

1086 primary_weight, 

1087 canary_weight, 

1088 ) 

1089 except ApiException as e: 

1090 if e.status == 404: 

1091 self.networking_v1.create_namespaced_ingress( 

1092 namespace, ingress, _request_timeout=self._k8s_timeout 

1093 ) 

1094 logger.info("Created canary ingress for %s", name) 

1095 else: 

1096 raise 

1097 

1098 def _cleanup_canary(self, name: str, namespace: str) -> None: 

1099 """Remove canary deployment, service, and restore primary-only ingress.""" 

1100 canary_name = f"{name}-canary" 

1101 

1102 # Delete canary deployment 

1103 try: 

1104 self.apps_v1.delete_namespaced_deployment( 

1105 canary_name, namespace, _request_timeout=self._k8s_timeout 

1106 ) 

1107 logger.info("Deleted canary deployment %s", canary_name) 

1108 except ApiException as e: 

1109 if e.status != 404: 

1110 logger.error("Failed to delete canary deployment %s: %s", canary_name, e) 

1111 

1112 # Delete canary service 

1113 try: 

1114 self.core_v1.delete_namespaced_service( 

1115 canary_name, namespace, _request_timeout=self._k8s_timeout 

1116 ) 

1117 logger.info("Deleted canary service %s", canary_name) 

1118 except ApiException as e: 

1119 if e.status != 404: 

1120 logger.error("Failed to delete canary service %s: %s", canary_name, e) 

1121 

1122 def _delete_resources(self, name: str, namespace: str) -> None: 

1123 """Delete all Kubernetes resources for an endpoint.""" 

1124 # Delete canary resources first 

1125 self._cleanup_canary(name, namespace) 

1126 

1127 # Delete deployment 

1128 try: 

1129 self.apps_v1.delete_namespaced_deployment( 

1130 name, namespace, _request_timeout=self._k8s_timeout 

1131 ) 

1132 logger.info("Deleted deployment %s/%s", namespace, name) 

1133 except ApiException as e: 

1134 if e.status != 404: 

1135 logger.error("Failed to delete deployment %s: %s", name, e) 

1136 

1137 # Delete service 

1138 try: 

1139 self.core_v1.delete_namespaced_service( 

1140 name, namespace, _request_timeout=self._k8s_timeout 

1141 ) 

1142 logger.info("Deleted service %s/%s", namespace, name) 

1143 except ApiException as e: 

1144 if e.status != 404: 

1145 logger.error("Failed to delete service %s: %s", name, e) 

1146 

1147 # Delete ingress 

1148 try: 

1149 self.networking_v1.delete_namespaced_ingress( 

1150 f"inference-{name}", namespace, _request_timeout=self._k8s_timeout 

1151 ) 

1152 logger.info("Deleted ingress for %s", name) 

1153 except ApiException as e: 

1154 if e.status != 404: 

1155 logger.error("Failed to delete ingress for %s: %s", name, e) 

1156 

1157 # Delete HPA 

1158 try: 

1159 autoscaling_v2 = client.AutoscalingV2Api() 

1160 autoscaling_v2.delete_namespaced_horizontal_pod_autoscaler(name, namespace) 

1161 logger.info("Deleted HPA for %s", name) 

1162 except ApiException as e: 

1163 if e.status != 404: 

1164 logger.error("Failed to delete HPA for %s: %s", name, e) 

1165 

1166 def _create_or_update_hpa(self, name: str, namespace: str, spec: dict[str, Any]) -> None: 

1167 """Create or update a Horizontal Pod Autoscaler for an inference endpoint.""" 

1168 autoscaling_config = spec.get("autoscaling", {}) 

1169 if not autoscaling_config.get("enabled"): 

1170 return 

1171 

1172 min_replicas = autoscaling_config.get("min_replicas", 1) 

1173 max_replicas = autoscaling_config.get("max_replicas", 10) 

1174 metrics_config = autoscaling_config.get("metrics", [{"type": "cpu", "target": 70}]) 

1175 

1176 # Build HPA metrics 

1177 hpa_metrics = [] 

1178 for m in metrics_config: 

1179 metric_type = m.get("type", "cpu") 

1180 target_value = m.get("target", 70) 

1181 

1182 if metric_type == "cpu": 

1183 hpa_metrics.append( 

1184 client.V2MetricSpec( 

1185 type="Resource", 

1186 resource=client.V2ResourceMetricSource( 

1187 name="cpu", 

1188 target=client.V2MetricTarget( 

1189 type="Utilization", 

1190 average_utilization=target_value, 

1191 ), 

1192 ), 

1193 ) 

1194 ) 

1195 elif metric_type == "memory": 

1196 hpa_metrics.append( 

1197 client.V2MetricSpec( 

1198 type="Resource", 

1199 resource=client.V2ResourceMetricSource( 

1200 name="memory", 

1201 target=client.V2MetricTarget( 

1202 type="Utilization", 

1203 average_utilization=target_value, 

1204 ), 

1205 ), 

1206 ) 

1207 ) 

1208 

1209 if not hpa_metrics: 

1210 # Default to CPU if no recognized metrics 

1211 hpa_metrics.append( 

1212 client.V2MetricSpec( 

1213 type="Resource", 

1214 resource=client.V2ResourceMetricSource( 

1215 name="cpu", 

1216 target=client.V2MetricTarget( 

1217 type="Utilization", 

1218 average_utilization=70, 

1219 ), 

1220 ), 

1221 ) 

1222 ) 

1223 

1224 hpa = client.V2HorizontalPodAutoscaler( 

1225 metadata=client.V1ObjectMeta( 

1226 name=name, 

1227 namespace=namespace, 

1228 labels={ 

1229 "app": name, 

1230 "project": "gco", 

1231 "gco.io/type": "inference", 

1232 }, 

1233 ), 

1234 spec=client.V2HorizontalPodAutoscalerSpec( 

1235 scale_target_ref=client.V2CrossVersionObjectReference( 

1236 api_version="apps/v1", 

1237 kind="Deployment", 

1238 name=name, 

1239 ), 

1240 min_replicas=min_replicas, 

1241 max_replicas=max_replicas, 

1242 metrics=hpa_metrics, 

1243 ), 

1244 ) 

1245 

1246 autoscaling_v2 = client.AutoscalingV2Api() 

1247 try: 

1248 autoscaling_v2.create_namespaced_horizontal_pod_autoscaler(namespace, hpa) 

1249 logger.info("Created HPA for %s (min=%d, max=%d)", name, min_replicas, max_replicas) 

1250 except ApiException as e: 

1251 if e.status == 409: 1251 ↛ 1255line 1251 didn't jump to line 1255 because the condition on line 1251 was always true

1252 autoscaling_v2.patch_namespaced_horizontal_pod_autoscaler(name, namespace, hpa) 

1253 logger.info("Updated HPA for %s", name) 

1254 else: 

1255 raise 

1256 

1257 # ------------------------------------------------------------------ 

1258 # Metrics 

1259 # ------------------------------------------------------------------ 

1260 

1261 def get_metrics(self) -> dict[str, Any]: 

1262 return { 

1263 "cluster_id": self.cluster_id, 

1264 "region": self.region, 

1265 "running": self._running, 

1266 "reconcile_count": self._reconcile_count, 

1267 "errors_count": self._errors_count, 

1268 } 

1269 

1270 

def create_inference_monitor_from_env() -> InferenceMonitor:
    """Build an InferenceMonitor configured entirely from environment variables."""
    cluster_id = os.getenv("CLUSTER_NAME", "unknown-cluster")
    region = os.getenv("REGION", "unknown-region")

    # Emit structured JSON logs so CloudWatch Insights can query by field.
    configure_structured_logging(
        service_name="inference-monitor",
        cluster_id=cluster_id,
        region=region,
    )

    # The store resolves its own region (DYNAMODB_REGION, falling back to REGION).
    store = InferenceEndpointStore()

    return InferenceMonitor(
        cluster_id=cluster_id,
        region=region,
        store=store,
        namespace=os.getenv("INFERENCE_NAMESPACE", "gco-inference"),
        reconcile_interval=int(os.getenv("RECONCILE_INTERVAL_SECONDS", "15")),
    )

1294 

1295 

async def main() -> None:
    """Entry point: run the monitor forever, restarting after unexpected crashes."""
    monitor = create_inference_monitor_from_env()
    logger.info("Inference monitor initialized: %s", monitor.get_metrics())

    while True:
        try:
            await monitor.start()
        except KeyboardInterrupt:
            # Operator-initiated shutdown: stop cleanly and exit the loop.
            logger.info("Shutting down inference monitor")
            monitor.stop()
            return
        except Exception as e:
            # Crash-loop protection: log, reset, and retry after a pause.
            logger.error("Monitor crashed, restarting in 10s: %s", e, exc_info=True)
            monitor.stop()
            # Defensive reset of the run flag before the next start() attempt.
            monitor._running = False
            await asyncio.sleep(10)

1313 

1314 

# Script entrypoint: start the async reconcile loop (e.g. when this module is
# the container command inside a regional EKS cluster).
if __name__ == "__main__":
    asyncio.run(main())