Coverage for cli/inference.py: 98%

148 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 21:47 +0000

1""" 

2Inference endpoint management for GCO CLI. 

3 

4Provides functionality to deploy, manage, and monitor inference endpoints 

5across multi-region EKS clusters via the DynamoDB-backed reconciliation 

6pattern (inference_monitor). 

7""" 

8 

9from __future__ import annotations 

10 

11import logging 

12from typing import TYPE_CHECKING, Any 

13 

14from .aws_client import get_aws_client 

15from .config import GCOConfig, get_config 

16 

17if TYPE_CHECKING: 

18 from gco.services.inference_store import InferenceEndpointStore 

19 

# Module-level logger, following the stdlib `logging.getLogger(__name__)` convention.
logger = logging.getLogger(__name__)

21 

22 

class InferenceManager:
    """Manages inference endpoints via the DynamoDB store.

    Endpoint specs are written to a DynamoDB table; the per-region
    ``inference_monitor`` reconciliation loop picks the records up and
    creates/updates the corresponding Kubernetes resources.
    """

    def __init__(self, config: GCOConfig | None = None):
        """Initialize the manager.

        Args:
            config: Optional CLI configuration; defaults to ``get_config()``.
        """
        self.config = config or get_config()
        # Fix: build the AWS client from the *resolved* configuration.
        # Previously the raw `config` argument (possibly None) was passed,
        # bypassing the get_config() fallback applied just above.
        self._aws_client = get_aws_client(self.config)

    def _get_store(self, region: str | None = None) -> InferenceEndpointStore:
        """Get an InferenceEndpointStore for the global region.

        Args:
            region: Optional region override; defaults to the configured
                global region (the same region that hosts the job store).

        Returns:
            A store bound to the chosen DynamoDB region.
        """
        # Imported lazily so CLI startup does not pay for (or require)
        # the services package unless a store is actually used.
        from gco.services.inference_store import InferenceEndpointStore

        # Use the global region for DynamoDB (same as job store)
        store_region = region or self.config.global_region
        return InferenceEndpointStore(region=store_region)

    def deploy(
        self,
        endpoint_name: str,
        image: str,
        target_regions: list[str] | None = None,
        replicas: int = 1,
        gpu_count: int = 1,
        gpu_type: str | None = None,
        port: int = 8000,
        model_path: str | None = None,
        model_source: str | None = None,
        health_check_path: str = "/health",
        env: dict[str, str] | None = None,
        namespace: str = "gco-inference",
        labels: dict[str, str] | None = None,
        autoscaling: dict[str, Any] | None = None,
        capacity_type: str | None = None,
        extra_args: list[str] | None = None,
        accelerator: str = "nvidia",
        node_selector: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """
        Deploy an inference endpoint to one or more regions.

        The endpoint spec is written to DynamoDB. The inference_monitor
        in each target region picks it up and creates the K8s resources.

        Args:
            endpoint_name: Unique name for the endpoint
            image: Container image (e.g. vllm/vllm-openai:v0.8.0)
            target_regions: Regions to deploy to (default: all deployed regions)
            replicas: Number of replicas per region
            gpu_count: GPUs per replica
            gpu_type: GPU instance type hint for node selector
            port: Container port
            model_path: EFS path for model weights
            model_source: Source location for model weights (stored verbatim
                in the spec; interpreted by the monitor)
            health_check_path: Health check endpoint path
            env: Environment variables
            namespace: Kubernetes namespace
            labels: Labels for the endpoint
            autoscaling: Autoscaling settings, passed through in the spec
            capacity_type: Capacity type hint (passed through verbatim;
                semantics defined by the monitor)
            extra_args: Extra container args (stored under spec["args"])
            accelerator: Accelerator family; recorded only when not "nvidia"
            node_selector: Extra Kubernetes node selector labels

        Returns:
            Created endpoint record

        Raises:
            ValueError: If no regions were given and none are deployed.
        """
        if not target_regions:
            # Default to every region that has deployed infrastructure.
            stacks = self._aws_client.discover_regional_stacks()
            target_regions = list(stacks.keys())
            if not target_regions:
                raise ValueError("No deployed regions found. Deploy infrastructure first.")

        spec = {
            "image": image,
            "port": port,
            "replicas": replicas,
            "gpu_count": gpu_count,
            "health_check_path": health_check_path,
        }
        # Optional fields are only written when provided so the stored
        # spec stays minimal.
        if gpu_type:
            spec["gpu_type"] = gpu_type
        if model_path:
            spec["model_path"] = model_path
        if model_source:
            spec["model_source"] = model_source
        if env:
            spec["env"] = env
        if autoscaling:
            spec["autoscaling"] = autoscaling
        if capacity_type:
            spec["capacity_type"] = capacity_type
        if extra_args:
            spec["args"] = extra_args
        if accelerator != "nvidia":
            # "nvidia" is the implicit default and is deliberately omitted.
            spec["accelerator"] = accelerator
        if node_selector:
            spec["node_selector"] = node_selector

        store = self._get_store()
        result: dict[str, Any] = store.create_endpoint(
            endpoint_name=endpoint_name,
            spec=spec,
            target_regions=target_regions,
            namespace=namespace,
            labels=labels,
        )
        return result

    def list_endpoints(
        self,
        desired_state: str | None = None,
        region: str | None = None,
    ) -> list[dict[str, Any]]:
        """List all inference endpoints.

        Args:
            desired_state: Optional filter on the desired state.
            region: Optional filter on a target region.

        Returns:
            Matching endpoint records.
        """
        store = self._get_store()
        result: list[dict[str, Any]] = store.list_endpoints(
            desired_state=desired_state,
            target_region=region,
        )
        return result

    def get_endpoint(self, endpoint_name: str) -> dict[str, Any] | None:
        """Get details of a specific endpoint, or None if it does not exist."""
        store = self._get_store()
        result: dict[str, Any] | None = store.get_endpoint(endpoint_name)
        return result

    def scale(self, endpoint_name: str, replicas: int) -> dict[str, Any] | None:
        """Scale an endpoint to a new replica count."""
        store = self._get_store()
        result: dict[str, Any] | None = store.scale_endpoint(endpoint_name, replicas)
        return result

    def stop(self, endpoint_name: str) -> dict[str, Any] | None:
        """Stop an endpoint (scale to zero, keep resources)."""
        store = self._get_store()
        result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "stopped")
        return result

    def start(self, endpoint_name: str) -> dict[str, Any] | None:
        """Start a stopped endpoint."""
        store = self._get_store()
        result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "running")
        return result

    def delete(self, endpoint_name: str) -> dict[str, Any] | None:
        """Mark an endpoint for deletion (inference_monitor cleans up)."""
        store = self._get_store()
        result: dict[str, Any] | None = store.update_desired_state(endpoint_name, "deleted")
        return result

    def update_image(self, endpoint_name: str, image: str) -> dict[str, Any] | None:
        """Update the container image for an endpoint.

        Returns:
            The updated record, or None if the endpoint does not exist.
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None
        spec = endpoint.get("spec", {})
        spec["image"] = image
        result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
        return result

    def _set_target_regions(
        self,
        store: InferenceEndpointStore,
        endpoint_name: str,
        regions: list[str],
        action: str,
    ) -> dict[str, Any] | None:
        """Persist a new target_regions list via a raw DynamoDB update.

        Shared helper for add_region/remove_region, which previously
        duplicated this update logic inline.

        Args:
            store: Endpoint store whose backing table is updated.
            endpoint_name: Endpoint to update.
            regions: Full replacement list of target regions.
            action: "add" or "remove" -- used only in the error log.

        Returns:
            The updated item attributes, or None on failure.
        """
        from datetime import UTC, datetime

        try:
            # Raw table update: the store exposes no region-list mutator.
            response = store._table.update_item(
                Key={"endpoint_name": endpoint_name},
                UpdateExpression="SET target_regions = :r, updated_at = :u",
                ExpressionAttributeValues={
                    ":r": regions,
                    ":u": datetime.now(UTC).isoformat(),
                },
                ReturnValues="ALL_NEW",
            )
            result: dict[str, Any] | None = response.get("Attributes")
            return result
        except Exception as e:
            # Best-effort CLI operation: log and signal failure via None
            # rather than crashing. Message matches the original per-method
            # wording ("Failed to add region" / "Failed to remove region").
            logger.error("Failed to %s region: %s", action, e)
            return None

    def add_region(self, endpoint_name: str, region: str) -> dict[str, Any] | None:
        """Add a region to an existing endpoint.

        Returns:
            The updated record, or None if the endpoint does not exist or
            the DynamoDB update fails.
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None
        regions = endpoint.get("target_regions", [])
        if region not in regions:
            regions.append(region)
        # Update is issued even if the region was already present,
        # matching the original behavior (refreshes updated_at).
        return self._set_target_regions(store, endpoint_name, regions, "add")

    def remove_region(self, endpoint_name: str, region: str) -> dict[str, Any] | None:
        """Remove a region from an existing endpoint.

        Returns:
            The updated record, or None if the endpoint does not exist or
            the DynamoDB update fails.
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None
        regions = endpoint.get("target_regions", [])
        if region in regions:
            regions.remove(region)
        # Update is issued even if the region was absent, matching the
        # original behavior (refreshes updated_at).
        return self._set_target_regions(store, endpoint_name, regions, "remove")

    def canary_deploy(
        self,
        endpoint_name: str,
        image: str,
        weight: int = 10,
        replicas: int = 1,
    ) -> dict[str, Any] | None:
        """Start a canary deployment for an existing endpoint.

        Creates a canary variant with the new image receiving `weight`%
        of traffic. The primary deployment continues serving the rest.

        Args:
            endpoint_name: Existing endpoint to canary
            image: New container image for the canary
            weight: Percentage of traffic to route to canary (1-99)
            replicas: Number of canary replicas

        Returns:
            Updated endpoint record, or None if endpoint not found

        Raises:
            ValueError: If weight is outside 1-99, or the endpoint is not
                in a "running"/"deploying" state.
        """
        if not 1 <= weight <= 99:
            raise ValueError("Canary weight must be between 1 and 99")

        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None

        if endpoint.get("desired_state") not in ("running", "deploying"):
            raise ValueError(
                f"Cannot canary an endpoint in '{endpoint.get('desired_state')}' state. "
                "Endpoint must be running."
            )

        spec = endpoint.get("spec", {})
        spec["canary"] = {
            "image": image,
            "weight": weight,
            "replicas": replicas,
        }

        result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
        return result

    def promote_canary(self, endpoint_name: str) -> dict[str, Any] | None:
        """Promote the canary to primary, removing the canary deployment.

        The primary image is replaced with the canary image, and the
        canary config is removed. All traffic goes to the new image.

        Returns:
            Updated endpoint record, or None if endpoint not found

        Raises:
            ValueError: If there is no active canary, or the canary record
                is missing its image.
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None

        spec = endpoint.get("spec", {})
        canary = spec.get("canary")
        if not canary:
            raise ValueError(f"Endpoint '{endpoint_name}' has no active canary deployment")

        if "image" not in canary:
            raise ValueError(
                f"Canary deployment for '{endpoint_name}' is missing the 'image' field"
            )

        # Swap primary image to canary image
        spec["image"] = canary["image"]
        # Remove canary config
        del spec["canary"]

        result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
        return result

    def rollback_canary(self, endpoint_name: str) -> dict[str, Any] | None:
        """Remove the canary deployment, keeping the primary unchanged.

        All traffic returns to the primary deployment.

        Returns:
            Updated endpoint record, or None if endpoint not found

        Raises:
            ValueError: If there is no active canary deployment.
        """
        store = self._get_store()
        endpoint = store.get_endpoint(endpoint_name)
        if not endpoint:
            return None

        spec = endpoint.get("spec", {})
        if "canary" not in spec:
            raise ValueError(f"Endpoint '{endpoint_name}' has no active canary deployment")

        del spec["canary"]
        result: dict[str, Any] | None = store.update_spec(endpoint_name, spec)
        return result

330 

331 

def get_inference_manager(config: GCOConfig | None = None) -> InferenceManager:
    """Build and return an InferenceManager bound to *config* (or defaults)."""
    manager = InferenceManager(config)
    return manager