Coverage for cli/capacity/advisor.py: 88%

206 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 15:07 +0000

1"""Bedrock-powered AI capacity advisor.""" 

2 

3from __future__ import annotations 

4 

5import json 

6import logging 

7from dataclasses import dataclass, field 

8from datetime import UTC, datetime, timedelta 

9from typing import Any 

10 

11import boto3 

12from botocore.exceptions import ClientError 

13 

14from cli.config import GCOConfig, get_config 

15 

16from .checker import CapacityChecker 

17from .multi_region import MultiRegionCapacityChecker, compute_price_trend 

18 

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

20 

21 

@dataclass
class BedrockCapacityRecommendation:
    """AI-generated capacity recommendation from Bedrock.

    Values are taken from the JSON object returned by the model and are not
    validated against the allowed choices listed in the comments below.
    """

    # AWS region the model recommends for the workload.
    recommended_region: str
    # EC2 instance type the model recommends.
    recommended_instance_type: str
    # Purchase option chosen by the model. The prompt asks for one of
    # "spot", "on-demand", "odcr" or "capacity-block".
    recommended_capacity_type: str
    # Free-text justification written by the model.
    reasoning: str
    # Model's self-reported confidence; the prompt asks for "high", "medium" or "low".
    confidence: str
    # Model's free-text hourly cost estimate, if it provided one.
    cost_estimate: str | None = None
    # Other options suggested by the model; the prompt requests dicts with
    # "region", "instance_type", "capacity_type" and "reason" keys.
    alternative_options: list[dict[str, Any]] = field(default_factory=list)
    # Caveats reported by the model.
    warnings: list[str] = field(default_factory=list)
    # Full text returned by the model, including any fields not mapped above.
    raw_response: str = ""

35 

36 

37class BedrockCapacityAdvisor: 

38 """ 

39 AI-powered capacity advisor using Amazon Bedrock. 

40 

41 Gathers comprehensive capacity data and uses an LLM to provide 

42 intelligent recommendations for workload placement. 

43 

44 DISCLAIMER: Recommendations are AI-generated and should be validated 

45 before making production decisions. 

46 """ 

47 

48 # Default model to use if none specified 

49 DEFAULT_MODEL = "us.anthropic.claude-sonnet-4-5-20250929-v1:0" 

50 

51 def __init__(self, config: GCOConfig | None = None, model_id: str | None = None): 

52 self.config = config or get_config() 

53 self._session = boto3.Session() 

54 self._capacity_checker = CapacityChecker(config) 

55 self._multi_region_checker = MultiRegionCapacityChecker(config) 

56 self.model_id = model_id or self.DEFAULT_MODEL 

57 

58 def _get_bedrock_client(self) -> Any: 

59 """Get Bedrock runtime client.""" 

60 return self._session.client("bedrock-runtime", region_name="us-east-1") 

61 

62 def gather_capacity_data( 

63 self, 

64 instance_types: list[str] | None = None, 

65 regions: list[str] | None = None, 

66 ) -> dict[str, Any]: 

67 """ 

68 Gather comprehensive capacity data for AI analysis. 

69 

70 Args: 

71 instance_types: List of instance types to analyze (defaults to common GPU types) 

72 regions: List of regions to check (defaults to deployed GCO regions) 

73 

74 Returns: 

75 Dictionary containing all gathered capacity data 

76 """ 

77 from cli.aws_client import get_aws_client 

78 

79 # Default to common GPU instance types if not specified 

80 if not instance_types: 

81 instance_types = [ 

82 "g4dn.xlarge", 

83 "g4dn.2xlarge", 

84 "g4dn.4xlarge", 

85 "g5.xlarge", 

86 "g5.2xlarge", 

87 "g5.4xlarge", 

88 "p3.2xlarge", 

89 "p4d.24xlarge", 

90 ] 

91 

92 # Get deployed regions if not specified 

93 if not regions: 

94 aws_client = get_aws_client(self.config) 

95 stacks = aws_client.discover_regional_stacks() 

96 regions = list(stacks.keys()) if stacks else [self.config.default_region] 

97 

98 data: dict[str, Any] = { 

99 "timestamp": datetime.now(UTC).isoformat(), 

100 "regions_analyzed": regions, 

101 "instance_types_analyzed": instance_types, 

102 "regional_capacity": {}, 

103 "spot_data": {}, 

104 "on_demand_data": {}, 

105 "cluster_metrics": [], 

106 "queue_status": {}, 

107 } 

108 

109 # Gather regional cluster metrics 

110 for region in regions: 

111 try: 

112 capacity = self._multi_region_checker.get_region_capacity(region) 

113 data["cluster_metrics"].append( 

114 { 

115 "region": region, 

116 "queue_depth": capacity.queue_depth, 

117 "running_jobs": capacity.running_jobs, 

118 "pending_jobs": capacity.pending_jobs, 

119 "gpu_utilization": capacity.gpu_utilization, 

120 "cpu_utilization": capacity.cpu_utilization, 

121 "recommendation_score": capacity.recommendation_score, 

122 } 

123 ) 

124 except Exception as e: 

125 logger.debug("Failed to get cluster metrics for %s: %s", region, e) 

126 for instance_type in instance_types: 

127 data["spot_data"][instance_type] = {} 

128 data["on_demand_data"][instance_type] = {} 

129 

130 for region in regions: 

131 try: 

132 # Get spot placement scores and prices 

133 spot_scores = self._capacity_checker.get_spot_placement_score( 

134 instance_type, region 

135 ) 

136 spot_prices = self._capacity_checker.get_spot_price_history( 

137 instance_type, region, days=7 

138 ) 

139 on_demand_price = self._capacity_checker.get_on_demand_price( 

140 instance_type, region 

141 ) 

142 

143 data["spot_data"][instance_type][region] = { 

144 "placement_scores": spot_scores, 

145 "prices": [ 

146 { 

147 "az": p.availability_zone, 

148 "current": p.current_price, 

149 "avg_7d": p.avg_price_7d, 

150 "stability": p.price_stability, 

151 } 

152 for p in spot_prices 

153 ], 

154 } 

155 

156 # Spot price trend analysis per AZ (for AI interpretation) 

157 try: 

158 ec2 = self._session.client("ec2", region_name=region) 

159 raw_resp = ec2.describe_spot_price_history( 

160 InstanceTypes=[instance_type], 

161 ProductDescriptions=["Linux/UNIX"], 

162 StartTime=datetime.now(UTC) - timedelta(days=7), 

163 EndTime=datetime.now(UTC), 

164 ) 

165 az_raw: dict[str, list[float]] = {} 

166 for item in raw_resp.get("SpotPriceHistory", []): 

167 az = item["AvailabilityZone"] 

168 if az not in az_raw: 

169 az_raw[az] = [] 

170 az_raw[az].append(float(item["SpotPrice"])) 

171 az_trends = { 

172 az: compute_price_trend(prices) 

173 for az, prices in az_raw.items() 

174 if len(prices) >= 2 

175 } 

176 if az_trends: 

177 data["spot_data"][instance_type][region]["price_trends"] = az_trends 

178 except Exception as e: 

179 logger.debug( 

180 "Failed to get price trends for %s in %s: %s", instance_type, region, e 

181 ) 

182 

183 data["on_demand_data"][instance_type][region] = { 

184 "price_per_hour": on_demand_price, 

185 "available": self._capacity_checker.check_instance_available_in_region( 

186 instance_type, region 

187 ), 

188 } 

189 except Exception as e: 

190 logger.debug( 

191 "Failed to gather capacity data for %s in %s: %s", instance_type, region, e 

192 ) 

193 

194 # Gather capacity reservation and block data 

195 data["reservations"] = {} 

196 data["capacity_blocks"] = {} 

197 for instance_type in instance_types: 

198 data["reservations"][instance_type] = {} 

199 data["capacity_blocks"][instance_type] = {} 

200 for region in regions: 

201 try: 

202 odcrs = self._capacity_checker.list_capacity_reservations( 

203 region, instance_type=instance_type 

204 ) 

205 if odcrs: 

206 data["reservations"][instance_type][region] = [ 

207 { 

208 "az": r["availability_zone"], 

209 "total": r["total_instances"], 

210 "available": r["available_instances"], 

211 "utilization_pct": r["utilization_pct"], 

212 } 

213 for r in odcrs 

214 ] 

215 except Exception as e: 

216 logger.debug( 

217 "Failed to list reservations for %s in %s: %s", instance_type, region, e 

218 ) 

219 

220 try: 

221 blocks = self._capacity_checker.list_capacity_block_offerings( 

222 region, instance_type=instance_type, instance_count=1, duration_hours=24 

223 ) 

224 if blocks: 

225 data["capacity_blocks"][instance_type][region] = [ 

226 { 

227 "az": b["availability_zone"], 

228 "duration_hours": b["duration_hours"], 

229 "start_date": b["start_date"], 

230 "upfront_fee": b["upfront_fee"], 

231 } 

232 for b in blocks 

233 ] 

234 except Exception as e: 

235 logger.debug( 

236 "Failed to list capacity blocks for %s in %s: %s", instance_type, region, e 

237 ) 

238 

239 # Capacity block availability trends (26-week regression per instance type per region) 

240 data["capacity_block_trends"] = {} 

241 for instance_type in instance_types: 

242 data["capacity_block_trends"][instance_type] = {} 

243 for region in regions: 

244 try: 

245 trend = self._capacity_checker.get_capacity_block_trend(instance_type, region) 

246 if trend != 0.0: 

247 data["capacity_block_trends"][instance_type][region] = { 

248 "trend_score": trend, 

249 "interpretation": ( 

250 "capacity growing" 

251 if trend > 0.2 

252 else "capacity shrinking" 

253 if trend < -0.2 

254 else "stable" 

255 ), 

256 } 

257 except Exception as e: 

258 logger.debug( 

259 "Failed to get capacity block trend for %s in %s: %s", 

260 instance_type, 

261 region, 

262 e, 

263 ) 

264 

265 # Weighted recommendation scores (algorithmic ranking for AI context) 

266 try: 

267 weighted_results = self._multi_region_checker.recommend_region_for_job( 

268 instance_type=instance_types[0] if instance_types else None, 

269 ) 

270 data["weighted_recommendation"] = { 

271 "top_region": weighted_results.get("region"), 

272 "scoring_method": weighted_results.get("scoring_method", "simple"), 

273 "all_regions": weighted_results.get("all_regions", []), 

274 } 

275 except Exception as e: 

276 logger.debug("Failed to compute weighted recommendation: %s", e) 

277 

278 return data 

279 

280 def _build_prompt( 

281 self, 

282 capacity_data: dict[str, Any], 

283 workload_description: str | None = None, 

284 requirements: dict[str, Any] | None = None, 

285 ) -> str: 

286 """Build the prompt for Bedrock.""" 

287 requirements = requirements or {} 

288 

289 prompt = """You are an expert AWS capacity planning advisor for GPU/ML workloads. 

290Analyze the following capacity data and provide a recommendation for where to place a workload. 

291 

292IMPORTANT DISCLAIMERS: 

293- This is AI-generated advice and should be validated before production use 

294- Capacity availability can change rapidly 

295- Spot instances may be interrupted at any time 

296- Pricing data may not reflect real-time prices 

297 

298""" 

299 

300 if workload_description: 

301 prompt += f"WORKLOAD DESCRIPTION:\n{workload_description}\n\n" 

302 

303 if requirements: 

304 prompt += "REQUIREMENTS:\n" 

305 if requirements.get("gpu_required"): 305 ↛ 307line 305 didn't jump to line 307 because the condition on line 305 was always true

306 prompt += "- GPU Required: Yes\n" 

307 if requirements.get("min_gpus"): 307 ↛ 309line 307 didn't jump to line 309 because the condition on line 307 was always true

308 prompt += f"- Minimum GPUs: {requirements['min_gpus']}\n" 

309 if requirements.get("min_memory_gb"): 309 ↛ 310line 309 didn't jump to line 310 because the condition on line 309 was never true

310 prompt += f"- Minimum Memory: {requirements['min_memory_gb']} GB\n" 

311 if requirements.get("fault_tolerance"): 311 ↛ 312line 311 didn't jump to line 312 because the condition on line 311 was never true

312 prompt += f"- Fault Tolerance: {requirements['fault_tolerance']}\n" 

313 if requirements.get("max_cost_per_hour"): 313 ↛ 314line 313 didn't jump to line 314 because the condition on line 313 was never true

314 prompt += f"- Max Cost/Hour: ${requirements['max_cost_per_hour']}\n" 

315 prompt += "\n" 

316 

317 prompt += "CAPACITY DATA:\n" 

318 prompt += f"Timestamp: {capacity_data.get('timestamp', 'N/A')}\n" 

319 prompt += f"Regions Analyzed: {', '.join(capacity_data.get('regions_analyzed', []))}\n" 

320 prompt += ( 

321 f"Instance Types: {', '.join(capacity_data.get('instance_types_analyzed', []))}\n\n" 

322 ) 

323 

324 # Cluster metrics 

325 if capacity_data.get("cluster_metrics"): 

326 prompt += "CLUSTER METRICS BY REGION:\n" 

327 for m in capacity_data["cluster_metrics"]: 

328 prompt += f" {m['region']}:\n" 

329 prompt += f" - Queue Depth: {m['queue_depth']}\n" 

330 prompt += f" - Running Jobs: {m['running_jobs']}\n" 

331 prompt += f" - GPU Utilization: {m['gpu_utilization']:.1f}%\n" 

332 prompt += f" - CPU Utilization: {m['cpu_utilization']:.1f}%\n" 

333 prompt += "\n" 

334 

335 # Spot data summary 

336 prompt += "SPOT CAPACITY SUMMARY:\n" 

337 for instance_type, regions_data in capacity_data.get("spot_data", {}).items(): 

338 prompt += f" {instance_type}:\n" 

339 for region, spot_info in regions_data.items(): 

340 scores = spot_info.get("placement_scores", {}) 

341 regional_score = scores.get("regional", "N/A") 

342 prices = spot_info.get("prices", []) 

343 avg_price = sum(p["current"] for p in prices) / len(prices) if prices else "N/A" 

344 prompt += f" {region}: Score={regional_score}/10, " 

345 prompt += f"Avg Price=${avg_price if isinstance(avg_price, str) else f'{avg_price:.4f}'}/hr\n" 

346 prompt += "\n" 

347 

348 # On-demand data summary 

349 prompt += "ON-DEMAND PRICING:\n" 

350 for instance_type, regions_data in capacity_data.get("on_demand_data", {}).items(): 

351 prompt += f" {instance_type}:\n" 

352 for region, od_info in regions_data.items(): 

353 price = od_info.get("price_per_hour") 

354 available = od_info.get("available", False) 

355 prompt += f" {region}: ${price:.4f}/hr" if price else f" {region}: N/A" 

356 prompt += f" (Available: {available})\n" 

357 prompt += "\n" 

358 

359 # Capacity reservations (ODCRs) 

360 reservations = capacity_data.get("reservations", {}) 

361 has_reservations = any(bool(regions_data) for regions_data in reservations.values()) 

362 if has_reservations: 362 ↛ 363line 362 didn't jump to line 363 because the condition on line 362 was never true

363 prompt += "CAPACITY RESERVATIONS (ODCRs):\n" 

364 for instance_type, regions_data in reservations.items(): 

365 for region, odcrs in regions_data.items(): 

366 for r in odcrs: 

367 prompt += ( 

368 f" {instance_type} in {region} ({r['az']}): " 

369 f"{r['available']}/{r['total']} available " 

370 f"({r['utilization_pct']}% used)\n" 

371 ) 

372 prompt += "\n" 

373 

374 # Capacity Blocks for ML 

375 blocks = capacity_data.get("capacity_blocks", {}) 

376 has_blocks = any(bool(regions_data) for regions_data in blocks.values()) 

377 if has_blocks: 377 ↛ 378line 377 didn't jump to line 378 because the condition on line 377 was never true

378 prompt += "CAPACITY BLOCK OFFERINGS (guaranteed GPU blocks):\n" 

379 for instance_type, regions_data in blocks.items(): 

380 for region, offerings in regions_data.items(): 

381 for b in offerings: 

382 prompt += ( 

383 f" {instance_type} in {region} ({b['az']}): " 

384 f"{b['duration_hours']}h starting {b['start_date']}, " 

385 f"${b['upfront_fee']}\n" 

386 ) 

387 prompt += "\n" 

388 

389 prompt += """Based on this data, provide your recommendation in the following JSON format: 

390{ 

391 "recommended_region": "region-name", 

392 "recommended_instance_type": "instance-type", 

393 "recommended_capacity_type": "spot, on-demand, odcr, or capacity-block", 

394 "reasoning": "Detailed explanation of why this is the best choice", 

395 "confidence": "high, medium, or low", 

396 "cost_estimate": "Estimated hourly cost", 

397 "reservation_advice": "If ODCRs or Capacity Blocks are available, explain how to use them. If not, suggest whether the user should consider purchasing a Capacity Block.", 

398 "alternative_options": [ 

399 {"region": "...", "instance_type": "...", "capacity_type": "...", "reason": "..."} 

400 ], 

401 "warnings": ["Any important warnings or caveats"] 

402} 

403 

404Respond ONLY with the JSON object, no additional text.""" 

405 

406 return prompt 

407 

408 def get_recommendation( 

409 self, 

410 workload_description: str | None = None, 

411 instance_types: list[str] | None = None, 

412 regions: list[str] | None = None, 

413 requirements: dict[str, Any] | None = None, 

414 ) -> BedrockCapacityRecommendation: 

415 """ 

416 Get an AI-powered capacity recommendation. 

417 

418 Args: 

419 workload_description: Description of the workload 

420 instance_types: List of instance types to consider 

421 regions: List of regions to consider 

422 requirements: Dictionary of requirements (gpu_required, min_gpus, etc.) 

423 

424 Returns: 

425 BedrockCapacityRecommendation with the AI's recommendation 

426 """ 

427 # Gather capacity data 

428 capacity_data = self.gather_capacity_data(instance_types, regions) 

429 

430 # Build prompt 

431 prompt = self._build_prompt(capacity_data, workload_description, requirements) 

432 

433 # Call Bedrock 

434 bedrock = self._get_bedrock_client() 

435 

436 try: 

437 # Use the Converse API for better compatibility across models 

438 response = bedrock.converse( 

439 modelId=self.model_id, 

440 messages=[{"role": "user", "content": [{"text": prompt}]}], 

441 inferenceConfig={"maxTokens": 2048, "temperature": 0.1}, 

442 ) 

443 

444 # Extract response text 

445 response_text = response["output"]["message"]["content"][0]["text"] 

446 

447 # Parse JSON response 

448 # Find JSON in response (in case model adds extra text) 

449 json_start = response_text.find("{") 

450 json_end = response_text.rfind("}") + 1 

451 if json_start >= 0 and json_end > json_start: 

452 json_str = response_text[json_start:json_end] 

453 result = json.loads(json_str) 

454 else: 

455 raise ValueError("No valid JSON found in response") 

456 

457 return BedrockCapacityRecommendation( 

458 recommended_region=result.get("recommended_region", "unknown"), 

459 recommended_instance_type=result.get("recommended_instance_type", "unknown"), 

460 recommended_capacity_type=result.get("recommended_capacity_type", "spot"), 

461 reasoning=result.get("reasoning", ""), 

462 confidence=result.get("confidence", "low"), 

463 cost_estimate=result.get("cost_estimate"), 

464 alternative_options=result.get("alternative_options", []), 

465 warnings=result.get("warnings", []), 

466 raw_response=response_text, 

467 ) 

468 

469 except ClientError as e: 

470 error_code = e.response.get("Error", {}).get("Code", "") 

471 if error_code == "AccessDeniedException": 

472 raise RuntimeError( 

473 "Access denied to Bedrock. Ensure your IAM role has " 

474 "bedrock:InvokeModel permission and the model is enabled in your account." 

475 ) from e 

476 if error_code == "ValidationException": 

477 raise RuntimeError( 

478 f"Model {self.model_id} may not be available. " 

479 "Try a different model with --model option." 

480 ) from e 

481 raise RuntimeError(f"Bedrock API error: {e}") from e 

482 except json.JSONDecodeError as e: 

483 raise RuntimeError(f"Failed to parse AI response as JSON: {e}") from e 

484 except Exception as e: 

485 raise RuntimeError(f"Failed to get AI recommendation: {e}") from e 

486 

487 

def get_bedrock_capacity_advisor(
    config: GCOConfig | None = None, model_id: str | None = None
) -> BedrockCapacityAdvisor:
    """Construct a ``BedrockCapacityAdvisor``.

    Thin factory: both arguments are forwarded unchanged to the constructor.

    Args:
        config: Configuration to use; None lets the advisor load its own.
        model_id: Bedrock model ID; None selects the advisor's default model.

    Returns:
        A new ``BedrockCapacityAdvisor`` instance.
    """
    advisor = BedrockCapacityAdvisor(config=config, model_id=model_id)
    return advisor