Coverage for cli/capacity/advisor.py: 88%
206 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
1"""Bedrock-powered AI capacity advisor."""
3from __future__ import annotations
5import json
6import logging
7from dataclasses import dataclass, field
8from datetime import UTC, datetime, timedelta
9from typing import Any
11import boto3
12from botocore.exceptions import ClientError
14from cli.config import GCOConfig, get_config
16from .checker import CapacityChecker
17from .multi_region import MultiRegionCapacityChecker, compute_price_trend
19logger = logging.getLogger(__name__)
@dataclass
class BedrockCapacityRecommendation:
    """AI-generated capacity recommendation from Bedrock.

    Instances are built from the JSON object the model returns; every value
    is model output and should be validated before it drives a production
    decision.
    """

    # Region the model recommends placing the workload in.
    recommended_region: str
    # Instance type the model recommends.
    recommended_instance_type: str
    # The advisor prompt asks for one of "spot", "on-demand", "odcr" or
    # "capacity-block"; the value is not validated here.
    recommended_capacity_type: str
    # Free-text explanation supplied by the model.
    reasoning: str
    # The advisor prompt asks for "high", "medium" or "low".
    confidence: str
    # Estimated hourly cost as free text, when the model supplied one.
    cost_estimate: str | None = None
    # Other placements the model considered acceptable.
    alternative_options: list[dict[str, Any]] = field(default_factory=list)
    # Caveats the model attached to the recommendation.
    warnings: list[str] = field(default_factory=list)
    # Unparsed model output, kept for debugging and auditing.
    raw_response: str = ""
37class BedrockCapacityAdvisor:
38 """
39 AI-powered capacity advisor using Amazon Bedrock.
41 Gathers comprehensive capacity data and uses an LLM to provide
42 intelligent recommendations for workload placement.
44 DISCLAIMER: Recommendations are AI-generated and should be validated
45 before making production decisions.
46 """
48 # Default model to use if none specified
49 DEFAULT_MODEL = "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
51 def __init__(self, config: GCOConfig | None = None, model_id: str | None = None):
52 self.config = config or get_config()
53 self._session = boto3.Session()
54 self._capacity_checker = CapacityChecker(config)
55 self._multi_region_checker = MultiRegionCapacityChecker(config)
56 self.model_id = model_id or self.DEFAULT_MODEL
58 def _get_bedrock_client(self) -> Any:
59 """Get Bedrock runtime client."""
60 return self._session.client("bedrock-runtime", region_name="us-east-1")
62 def gather_capacity_data(
63 self,
64 instance_types: list[str] | None = None,
65 regions: list[str] | None = None,
66 ) -> dict[str, Any]:
67 """
68 Gather comprehensive capacity data for AI analysis.
70 Args:
71 instance_types: List of instance types to analyze (defaults to common GPU types)
72 regions: List of regions to check (defaults to deployed GCO regions)
74 Returns:
75 Dictionary containing all gathered capacity data
76 """
77 from cli.aws_client import get_aws_client
79 # Default to common GPU instance types if not specified
80 if not instance_types:
81 instance_types = [
82 "g4dn.xlarge",
83 "g4dn.2xlarge",
84 "g4dn.4xlarge",
85 "g5.xlarge",
86 "g5.2xlarge",
87 "g5.4xlarge",
88 "p3.2xlarge",
89 "p4d.24xlarge",
90 ]
92 # Get deployed regions if not specified
93 if not regions:
94 aws_client = get_aws_client(self.config)
95 stacks = aws_client.discover_regional_stacks()
96 regions = list(stacks.keys()) if stacks else [self.config.default_region]
98 data: dict[str, Any] = {
99 "timestamp": datetime.now(UTC).isoformat(),
100 "regions_analyzed": regions,
101 "instance_types_analyzed": instance_types,
102 "regional_capacity": {},
103 "spot_data": {},
104 "on_demand_data": {},
105 "cluster_metrics": [],
106 "queue_status": {},
107 }
109 # Gather regional cluster metrics
110 for region in regions:
111 try:
112 capacity = self._multi_region_checker.get_region_capacity(region)
113 data["cluster_metrics"].append(
114 {
115 "region": region,
116 "queue_depth": capacity.queue_depth,
117 "running_jobs": capacity.running_jobs,
118 "pending_jobs": capacity.pending_jobs,
119 "gpu_utilization": capacity.gpu_utilization,
120 "cpu_utilization": capacity.cpu_utilization,
121 "recommendation_score": capacity.recommendation_score,
122 }
123 )
124 except Exception as e:
125 logger.debug("Failed to get cluster metrics for %s: %s", region, e)
126 for instance_type in instance_types:
127 data["spot_data"][instance_type] = {}
128 data["on_demand_data"][instance_type] = {}
130 for region in regions:
131 try:
132 # Get spot placement scores and prices
133 spot_scores = self._capacity_checker.get_spot_placement_score(
134 instance_type, region
135 )
136 spot_prices = self._capacity_checker.get_spot_price_history(
137 instance_type, region, days=7
138 )
139 on_demand_price = self._capacity_checker.get_on_demand_price(
140 instance_type, region
141 )
143 data["spot_data"][instance_type][region] = {
144 "placement_scores": spot_scores,
145 "prices": [
146 {
147 "az": p.availability_zone,
148 "current": p.current_price,
149 "avg_7d": p.avg_price_7d,
150 "stability": p.price_stability,
151 }
152 for p in spot_prices
153 ],
154 }
156 # Spot price trend analysis per AZ (for AI interpretation)
157 try:
158 ec2 = self._session.client("ec2", region_name=region)
159 raw_resp = ec2.describe_spot_price_history(
160 InstanceTypes=[instance_type],
161 ProductDescriptions=["Linux/UNIX"],
162 StartTime=datetime.now(UTC) - timedelta(days=7),
163 EndTime=datetime.now(UTC),
164 )
165 az_raw: dict[str, list[float]] = {}
166 for item in raw_resp.get("SpotPriceHistory", []):
167 az = item["AvailabilityZone"]
168 if az not in az_raw:
169 az_raw[az] = []
170 az_raw[az].append(float(item["SpotPrice"]))
171 az_trends = {
172 az: compute_price_trend(prices)
173 for az, prices in az_raw.items()
174 if len(prices) >= 2
175 }
176 if az_trends:
177 data["spot_data"][instance_type][region]["price_trends"] = az_trends
178 except Exception as e:
179 logger.debug(
180 "Failed to get price trends for %s in %s: %s", instance_type, region, e
181 )
183 data["on_demand_data"][instance_type][region] = {
184 "price_per_hour": on_demand_price,
185 "available": self._capacity_checker.check_instance_available_in_region(
186 instance_type, region
187 ),
188 }
189 except Exception as e:
190 logger.debug(
191 "Failed to gather capacity data for %s in %s: %s", instance_type, region, e
192 )
194 # Gather capacity reservation and block data
195 data["reservations"] = {}
196 data["capacity_blocks"] = {}
197 for instance_type in instance_types:
198 data["reservations"][instance_type] = {}
199 data["capacity_blocks"][instance_type] = {}
200 for region in regions:
201 try:
202 odcrs = self._capacity_checker.list_capacity_reservations(
203 region, instance_type=instance_type
204 )
205 if odcrs:
206 data["reservations"][instance_type][region] = [
207 {
208 "az": r["availability_zone"],
209 "total": r["total_instances"],
210 "available": r["available_instances"],
211 "utilization_pct": r["utilization_pct"],
212 }
213 for r in odcrs
214 ]
215 except Exception as e:
216 logger.debug(
217 "Failed to list reservations for %s in %s: %s", instance_type, region, e
218 )
220 try:
221 blocks = self._capacity_checker.list_capacity_block_offerings(
222 region, instance_type=instance_type, instance_count=1, duration_hours=24
223 )
224 if blocks:
225 data["capacity_blocks"][instance_type][region] = [
226 {
227 "az": b["availability_zone"],
228 "duration_hours": b["duration_hours"],
229 "start_date": b["start_date"],
230 "upfront_fee": b["upfront_fee"],
231 }
232 for b in blocks
233 ]
234 except Exception as e:
235 logger.debug(
236 "Failed to list capacity blocks for %s in %s: %s", instance_type, region, e
237 )
239 # Capacity block availability trends (26-week regression per instance type per region)
240 data["capacity_block_trends"] = {}
241 for instance_type in instance_types:
242 data["capacity_block_trends"][instance_type] = {}
243 for region in regions:
244 try:
245 trend = self._capacity_checker.get_capacity_block_trend(instance_type, region)
246 if trend != 0.0:
247 data["capacity_block_trends"][instance_type][region] = {
248 "trend_score": trend,
249 "interpretation": (
250 "capacity growing"
251 if trend > 0.2
252 else "capacity shrinking"
253 if trend < -0.2
254 else "stable"
255 ),
256 }
257 except Exception as e:
258 logger.debug(
259 "Failed to get capacity block trend for %s in %s: %s",
260 instance_type,
261 region,
262 e,
263 )
265 # Weighted recommendation scores (algorithmic ranking for AI context)
266 try:
267 weighted_results = self._multi_region_checker.recommend_region_for_job(
268 instance_type=instance_types[0] if instance_types else None,
269 )
270 data["weighted_recommendation"] = {
271 "top_region": weighted_results.get("region"),
272 "scoring_method": weighted_results.get("scoring_method", "simple"),
273 "all_regions": weighted_results.get("all_regions", []),
274 }
275 except Exception as e:
276 logger.debug("Failed to compute weighted recommendation: %s", e)
278 return data
280 def _build_prompt(
281 self,
282 capacity_data: dict[str, Any],
283 workload_description: str | None = None,
284 requirements: dict[str, Any] | None = None,
285 ) -> str:
286 """Build the prompt for Bedrock."""
287 requirements = requirements or {}
289 prompt = """You are an expert AWS capacity planning advisor for GPU/ML workloads.
290Analyze the following capacity data and provide a recommendation for where to place a workload.
292IMPORTANT DISCLAIMERS:
293- This is AI-generated advice and should be validated before production use
294- Capacity availability can change rapidly
295- Spot instances may be interrupted at any time
296- Pricing data may not reflect real-time prices
298"""
300 if workload_description:
301 prompt += f"WORKLOAD DESCRIPTION:\n{workload_description}\n\n"
303 if requirements:
304 prompt += "REQUIREMENTS:\n"
305 if requirements.get("gpu_required"): 305 ↛ 307line 305 didn't jump to line 307 because the condition on line 305 was always true
306 prompt += "- GPU Required: Yes\n"
307 if requirements.get("min_gpus"): 307 ↛ 309line 307 didn't jump to line 309 because the condition on line 307 was always true
308 prompt += f"- Minimum GPUs: {requirements['min_gpus']}\n"
309 if requirements.get("min_memory_gb"): 309 ↛ 310line 309 didn't jump to line 310 because the condition on line 309 was never true
310 prompt += f"- Minimum Memory: {requirements['min_memory_gb']} GB\n"
311 if requirements.get("fault_tolerance"): 311 ↛ 312line 311 didn't jump to line 312 because the condition on line 311 was never true
312 prompt += f"- Fault Tolerance: {requirements['fault_tolerance']}\n"
313 if requirements.get("max_cost_per_hour"): 313 ↛ 314line 313 didn't jump to line 314 because the condition on line 313 was never true
314 prompt += f"- Max Cost/Hour: ${requirements['max_cost_per_hour']}\n"
315 prompt += "\n"
317 prompt += "CAPACITY DATA:\n"
318 prompt += f"Timestamp: {capacity_data.get('timestamp', 'N/A')}\n"
319 prompt += f"Regions Analyzed: {', '.join(capacity_data.get('regions_analyzed', []))}\n"
320 prompt += (
321 f"Instance Types: {', '.join(capacity_data.get('instance_types_analyzed', []))}\n\n"
322 )
324 # Cluster metrics
325 if capacity_data.get("cluster_metrics"):
326 prompt += "CLUSTER METRICS BY REGION:\n"
327 for m in capacity_data["cluster_metrics"]:
328 prompt += f" {m['region']}:\n"
329 prompt += f" - Queue Depth: {m['queue_depth']}\n"
330 prompt += f" - Running Jobs: {m['running_jobs']}\n"
331 prompt += f" - GPU Utilization: {m['gpu_utilization']:.1f}%\n"
332 prompt += f" - CPU Utilization: {m['cpu_utilization']:.1f}%\n"
333 prompt += "\n"
335 # Spot data summary
336 prompt += "SPOT CAPACITY SUMMARY:\n"
337 for instance_type, regions_data in capacity_data.get("spot_data", {}).items():
338 prompt += f" {instance_type}:\n"
339 for region, spot_info in regions_data.items():
340 scores = spot_info.get("placement_scores", {})
341 regional_score = scores.get("regional", "N/A")
342 prices = spot_info.get("prices", [])
343 avg_price = sum(p["current"] for p in prices) / len(prices) if prices else "N/A"
344 prompt += f" {region}: Score={regional_score}/10, "
345 prompt += f"Avg Price=${avg_price if isinstance(avg_price, str) else f'{avg_price:.4f}'}/hr\n"
346 prompt += "\n"
348 # On-demand data summary
349 prompt += "ON-DEMAND PRICING:\n"
350 for instance_type, regions_data in capacity_data.get("on_demand_data", {}).items():
351 prompt += f" {instance_type}:\n"
352 for region, od_info in regions_data.items():
353 price = od_info.get("price_per_hour")
354 available = od_info.get("available", False)
355 prompt += f" {region}: ${price:.4f}/hr" if price else f" {region}: N/A"
356 prompt += f" (Available: {available})\n"
357 prompt += "\n"
359 # Capacity reservations (ODCRs)
360 reservations = capacity_data.get("reservations", {})
361 has_reservations = any(bool(regions_data) for regions_data in reservations.values())
362 if has_reservations: 362 ↛ 363line 362 didn't jump to line 363 because the condition on line 362 was never true
363 prompt += "CAPACITY RESERVATIONS (ODCRs):\n"
364 for instance_type, regions_data in reservations.items():
365 for region, odcrs in regions_data.items():
366 for r in odcrs:
367 prompt += (
368 f" {instance_type} in {region} ({r['az']}): "
369 f"{r['available']}/{r['total']} available "
370 f"({r['utilization_pct']}% used)\n"
371 )
372 prompt += "\n"
374 # Capacity Blocks for ML
375 blocks = capacity_data.get("capacity_blocks", {})
376 has_blocks = any(bool(regions_data) for regions_data in blocks.values())
377 if has_blocks: 377 ↛ 378line 377 didn't jump to line 378 because the condition on line 377 was never true
378 prompt += "CAPACITY BLOCK OFFERINGS (guaranteed GPU blocks):\n"
379 for instance_type, regions_data in blocks.items():
380 for region, offerings in regions_data.items():
381 for b in offerings:
382 prompt += (
383 f" {instance_type} in {region} ({b['az']}): "
384 f"{b['duration_hours']}h starting {b['start_date']}, "
385 f"${b['upfront_fee']}\n"
386 )
387 prompt += "\n"
389 prompt += """Based on this data, provide your recommendation in the following JSON format:
390{
391 "recommended_region": "region-name",
392 "recommended_instance_type": "instance-type",
393 "recommended_capacity_type": "spot, on-demand, odcr, or capacity-block",
394 "reasoning": "Detailed explanation of why this is the best choice",
395 "confidence": "high, medium, or low",
396 "cost_estimate": "Estimated hourly cost",
397 "reservation_advice": "If ODCRs or Capacity Blocks are available, explain how to use them. If not, suggest whether the user should consider purchasing a Capacity Block.",
398 "alternative_options": [
399 {"region": "...", "instance_type": "...", "capacity_type": "...", "reason": "..."}
400 ],
401 "warnings": ["Any important warnings or caveats"]
402}
404Respond ONLY with the JSON object, no additional text."""
406 return prompt
408 def get_recommendation(
409 self,
410 workload_description: str | None = None,
411 instance_types: list[str] | None = None,
412 regions: list[str] | None = None,
413 requirements: dict[str, Any] | None = None,
414 ) -> BedrockCapacityRecommendation:
415 """
416 Get an AI-powered capacity recommendation.
418 Args:
419 workload_description: Description of the workload
420 instance_types: List of instance types to consider
421 regions: List of regions to consider
422 requirements: Dictionary of requirements (gpu_required, min_gpus, etc.)
424 Returns:
425 BedrockCapacityRecommendation with the AI's recommendation
426 """
427 # Gather capacity data
428 capacity_data = self.gather_capacity_data(instance_types, regions)
430 # Build prompt
431 prompt = self._build_prompt(capacity_data, workload_description, requirements)
433 # Call Bedrock
434 bedrock = self._get_bedrock_client()
436 try:
437 # Use the Converse API for better compatibility across models
438 response = bedrock.converse(
439 modelId=self.model_id,
440 messages=[{"role": "user", "content": [{"text": prompt}]}],
441 inferenceConfig={"maxTokens": 2048, "temperature": 0.1},
442 )
444 # Extract response text
445 response_text = response["output"]["message"]["content"][0]["text"]
447 # Parse JSON response
448 # Find JSON in response (in case model adds extra text)
449 json_start = response_text.find("{")
450 json_end = response_text.rfind("}") + 1
451 if json_start >= 0 and json_end > json_start:
452 json_str = response_text[json_start:json_end]
453 result = json.loads(json_str)
454 else:
455 raise ValueError("No valid JSON found in response")
457 return BedrockCapacityRecommendation(
458 recommended_region=result.get("recommended_region", "unknown"),
459 recommended_instance_type=result.get("recommended_instance_type", "unknown"),
460 recommended_capacity_type=result.get("recommended_capacity_type", "spot"),
461 reasoning=result.get("reasoning", ""),
462 confidence=result.get("confidence", "low"),
463 cost_estimate=result.get("cost_estimate"),
464 alternative_options=result.get("alternative_options", []),
465 warnings=result.get("warnings", []),
466 raw_response=response_text,
467 )
469 except ClientError as e:
470 error_code = e.response.get("Error", {}).get("Code", "")
471 if error_code == "AccessDeniedException":
472 raise RuntimeError(
473 "Access denied to Bedrock. Ensure your IAM role has "
474 "bedrock:InvokeModel permission and the model is enabled in your account."
475 ) from e
476 if error_code == "ValidationException":
477 raise RuntimeError(
478 f"Model {self.model_id} may not be available. "
479 "Try a different model with --model option."
480 ) from e
481 raise RuntimeError(f"Bedrock API error: {e}") from e
482 except json.JSONDecodeError as e:
483 raise RuntimeError(f"Failed to parse AI response as JSON: {e}") from e
484 except Exception as e:
485 raise RuntimeError(f"Failed to get AI recommendation: {e}") from e
def get_bedrock_capacity_advisor(
    config: GCOConfig | None = None, model_id: str | None = None
) -> BedrockCapacityAdvisor:
    """Build and return a ``BedrockCapacityAdvisor``.

    Args:
        config: Passed to the advisor unchanged; may be ``None``.
        model_id: Passed to the advisor unchanged; may be ``None``.

    Returns:
        A newly constructed advisor.
    """
    advisor = BedrockCapacityAdvisor(config, model_id)
    return advisor