Coverage report for cli/capacity/advisor.py: 88% covered (206 statements).
Generated by coverage.py v7.13.5 at 2026-04-30 21:47 +0000.
1"""Bedrock-powered AI capacity advisor."""
3from __future__ import annotations
5import json
6import logging
7from dataclasses import dataclass, field
8from datetime import UTC, datetime, timedelta
9from typing import Any
11import boto3
12from botocore.exceptions import ClientError
14from cli.config import GCOConfig, get_config
16from .checker import CapacityChecker
17from .multi_region import MultiRegionCapacityChecker, compute_price_trend
19logger = logging.getLogger(__name__)
@dataclass
class BedrockCapacityRecommendation:
    """AI-generated capacity recommendation from Bedrock."""

    # Region the model recommends placing the workload in.
    recommended_region: str
    # EC2 instance type the model recommends.
    recommended_instance_type: str
    # "spot" or "on-demand" (prompt also allows "odcr" / "capacity-block").
    recommended_capacity_type: str
    # Free-text explanation produced by the model.
    reasoning: str
    # Model-reported confidence: "high", "medium", or "low".
    confidence: str
    # Optional estimated hourly cost string, as returned by the model.
    cost_estimate: str | None = None
    # Fallback placements; each dict carries region/instance_type/capacity_type/reason.
    alternative_options: list[dict[str, Any]] = field(default_factory=list)
    # Caveats the model attached to its recommendation.
    warnings: list[str] = field(default_factory=list)
    # Unparsed model output, kept for debugging/auditing.
    raw_response: str = ""
class BedrockCapacityAdvisor:
    """
    AI-powered capacity advisor using Amazon Bedrock.

    Gathers comprehensive capacity data and uses an LLM to provide
    intelligent recommendations for workload placement.

    DISCLAIMER: Recommendations are AI-generated and should be validated
    before making production decisions.
    """

    # Default model to use if none specified.
    DEFAULT_MODEL = "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
51 def __init__(self, config: GCOConfig | None = None, model_id: str | None = None):
52 self.config = config or get_config()
53 self._session = boto3.Session()
54 self._capacity_checker = CapacityChecker(config)
55 self._multi_region_checker = MultiRegionCapacityChecker(config)
56 self.model_id = model_id or self.DEFAULT_MODEL
58 def _get_bedrock_client(self) -> Any:
59 """Get Bedrock runtime client."""
60 return self._session.client("bedrock-runtime", region_name="us-east-1")
62 def gather_capacity_data(
63 self,
64 instance_types: list[str] | None = None,
65 regions: list[str] | None = None,
66 ) -> dict[str, Any]:
67 """
68 Gather comprehensive capacity data for AI analysis.
70 Args:
71 instance_types: List of instance types to analyze (defaults to common GPU types)
72 regions: List of regions to check (defaults to deployed GCO regions)
74 Returns:
75 Dictionary containing all gathered capacity data
76 """
77 from cli.aws_client import get_aws_client
79 # Default to common GPU instance types if not specified
80 if not instance_types:
81 instance_types = [
82 "g4dn.xlarge",
83 "g4dn.2xlarge",
84 "g4dn.4xlarge",
85 "g5.xlarge",
86 "g5.2xlarge",
87 "g5.4xlarge",
88 "p3.2xlarge",
89 "p4d.24xlarge",
90 ]
92 # Get deployed regions if not specified
93 if not regions:
94 aws_client = get_aws_client(self.config)
95 stacks = aws_client.discover_regional_stacks()
96 regions = list(stacks.keys()) if stacks else [self.config.default_region]
98 data: dict[str, Any] = {
99 "timestamp": datetime.now(UTC).isoformat(),
100 "regions_analyzed": regions,
101 "instance_types_analyzed": instance_types,
102 "regional_capacity": {},
103 "spot_data": {},
104 "on_demand_data": {},
105 "cluster_metrics": [],
106 "queue_status": {},
107 }
109 # Gather regional cluster metrics
110 for region in regions:
111 try:
112 capacity = self._multi_region_checker.get_region_capacity(region)
113 data["cluster_metrics"].append(
114 {
115 "region": region,
116 "queue_depth": capacity.queue_depth,
117 "running_jobs": capacity.running_jobs,
118 "pending_jobs": capacity.pending_jobs,
119 "gpu_utilization": capacity.gpu_utilization,
120 "cpu_utilization": capacity.cpu_utilization,
121 "recommendation_score": capacity.recommendation_score,
122 }
123 )
124 except Exception as e:
125 logger.debug("Failed to get cluster metrics for %s: %s", region, e)
126 for instance_type in instance_types:
127 data["spot_data"][instance_type] = {}
128 data["on_demand_data"][instance_type] = {}
130 for region in regions:
131 try:
132 # Get spot placement scores and prices
133 spot_scores = self._capacity_checker.get_spot_placement_score(
134 instance_type, region
135 )
136 spot_prices = self._capacity_checker.get_spot_price_history(
137 instance_type, region, days=7
138 )
139 on_demand_price = self._capacity_checker.get_on_demand_price(
140 instance_type, region
141 )
143 data["spot_data"][instance_type][region] = {
144 "placement_scores": spot_scores,
145 "prices": [
146 {
147 "az": p.availability_zone,
148 "current": p.current_price,
149 "avg_7d": p.avg_price_7d,
150 "stability": p.price_stability,
151 }
152 for p in spot_prices
153 ],
154 }
156 # Spot price trend analysis per AZ (for AI interpretation)
157 try:
158 ec2 = self._session.client("ec2", region_name=region)
159 raw_resp = ec2.describe_spot_price_history(
160 InstanceTypes=[instance_type],
161 ProductDescriptions=["Linux/UNIX"],
162 StartTime=datetime.now(UTC) - timedelta(days=7),
163 EndTime=datetime.now(UTC),
164 )
165 az_raw: dict[str, list[float]] = {}
166 for item in raw_resp.get("SpotPriceHistory", []):
167 az = item["AvailabilityZone"]
168 if az not in az_raw:
169 az_raw[az] = []
170 az_raw[az].append(float(item["SpotPrice"]))
171 az_trends = {
172 az: compute_price_trend(prices)
173 for az, prices in az_raw.items()
174 if len(prices) >= 2
175 }
176 if az_trends:
177 data["spot_data"][instance_type][region]["price_trends"] = az_trends
178 except Exception as e:
179 logger.debug(
180 "Failed to get price trends for %s in %s: %s", instance_type, region, e
181 )
183 data["on_demand_data"][instance_type][region] = {
184 "price_per_hour": on_demand_price,
185 "available": self._capacity_checker.check_instance_available_in_region(
186 instance_type, region
187 ),
188 }
189 except Exception as e:
190 logger.debug(
191 "Failed to gather capacity data for %s in %s: %s", instance_type, region, e
192 )
194 # Gather capacity reservation and block data
195 data["reservations"] = {}
196 data["capacity_blocks"] = {}
197 for instance_type in instance_types:
198 data["reservations"][instance_type] = {}
199 data["capacity_blocks"][instance_type] = {}
200 for region in regions:
201 try:
202 odcrs = self._capacity_checker.list_capacity_reservations(
203 region, instance_type=instance_type
204 )
205 if odcrs:
206 data["reservations"][instance_type][region] = [
207 {
208 "az": r["availability_zone"],
209 "total": r["total_instances"],
210 "available": r["available_instances"],
211 "utilization_pct": r["utilization_pct"],
212 }
213 for r in odcrs
214 ]
215 except Exception as e:
216 logger.debug(
217 "Failed to list reservations for %s in %s: %s", instance_type, region, e
218 )
220 try:
221 blocks = self._capacity_checker.list_capacity_block_offerings(
222 region, instance_type=instance_type, instance_count=1, duration_hours=24
223 )
224 if blocks:
225 data["capacity_blocks"][instance_type][region] = [
226 {
227 "az": b["availability_zone"],
228 "duration_hours": b["duration_hours"],
229 "start_date": b["start_date"],
230 "upfront_fee": b["upfront_fee"],
231 }
232 for b in blocks
233 ]
234 except Exception as e:
235 logger.debug(
236 "Failed to list capacity blocks for %s in %s: %s", instance_type, region, e
237 )
239 # Capacity block availability trends (26-week regression per instance type per region)
240 data["capacity_block_trends"] = {}
241 for instance_type in instance_types:
242 data["capacity_block_trends"][instance_type] = {}
243 for region in regions:
244 try:
245 trend = self._capacity_checker.get_capacity_block_trend(instance_type, region)
246 if trend != 0.0:
247 data["capacity_block_trends"][instance_type][region] = {
248 "trend_score": trend,
249 "interpretation": (
250 "capacity growing"
251 if trend > 0.2
252 else "capacity shrinking" if trend < -0.2 else "stable"
253 ),
254 }
255 except Exception as e:
256 logger.debug(
257 "Failed to get capacity block trend for %s in %s: %s",
258 instance_type,
259 region,
260 e,
261 )
263 # Weighted recommendation scores (algorithmic ranking for AI context)
264 try:
265 weighted_results = self._multi_region_checker.recommend_region_for_job(
266 instance_type=instance_types[0] if instance_types else None,
267 )
268 data["weighted_recommendation"] = {
269 "top_region": weighted_results.get("region"),
270 "scoring_method": weighted_results.get("scoring_method", "simple"),
271 "all_regions": weighted_results.get("all_regions", []),
272 }
273 except Exception as e:
274 logger.debug("Failed to compute weighted recommendation: %s", e)
276 return data
278 def _build_prompt(
279 self,
280 capacity_data: dict[str, Any],
281 workload_description: str | None = None,
282 requirements: dict[str, Any] | None = None,
283 ) -> str:
284 """Build the prompt for Bedrock."""
285 requirements = requirements or {}
287 prompt = """You are an expert AWS capacity planning advisor for GPU/ML workloads.
288Analyze the following capacity data and provide a recommendation for where to place a workload.
290IMPORTANT DISCLAIMERS:
291- This is AI-generated advice and should be validated before production use
292- Capacity availability can change rapidly
293- Spot instances may be interrupted at any time
294- Pricing data may not reflect real-time prices
296"""
298 if workload_description:
299 prompt += f"WORKLOAD DESCRIPTION:\n{workload_description}\n\n"
301 if requirements:
302 prompt += "REQUIREMENTS:\n"
303 if requirements.get("gpu_required"): 303 ↛ 305line 303 didn't jump to line 305 because the condition on line 303 was always true
304 prompt += "- GPU Required: Yes\n"
305 if requirements.get("min_gpus"): 305 ↛ 307line 305 didn't jump to line 307 because the condition on line 305 was always true
306 prompt += f"- Minimum GPUs: {requirements['min_gpus']}\n"
307 if requirements.get("min_memory_gb"): 307 ↛ 308line 307 didn't jump to line 308 because the condition on line 307 was never true
308 prompt += f"- Minimum Memory: {requirements['min_memory_gb']} GB\n"
309 if requirements.get("fault_tolerance"): 309 ↛ 310line 309 didn't jump to line 310 because the condition on line 309 was never true
310 prompt += f"- Fault Tolerance: {requirements['fault_tolerance']}\n"
311 if requirements.get("max_cost_per_hour"): 311 ↛ 312line 311 didn't jump to line 312 because the condition on line 311 was never true
312 prompt += f"- Max Cost/Hour: ${requirements['max_cost_per_hour']}\n"
313 prompt += "\n"
315 prompt += "CAPACITY DATA:\n"
316 prompt += f"Timestamp: {capacity_data.get('timestamp', 'N/A')}\n"
317 prompt += f"Regions Analyzed: {', '.join(capacity_data.get('regions_analyzed', []))}\n"
318 prompt += (
319 f"Instance Types: {', '.join(capacity_data.get('instance_types_analyzed', []))}\n\n"
320 )
322 # Cluster metrics
323 if capacity_data.get("cluster_metrics"):
324 prompt += "CLUSTER METRICS BY REGION:\n"
325 for m in capacity_data["cluster_metrics"]:
326 prompt += f" {m['region']}:\n"
327 prompt += f" - Queue Depth: {m['queue_depth']}\n"
328 prompt += f" - Running Jobs: {m['running_jobs']}\n"
329 prompt += f" - GPU Utilization: {m['gpu_utilization']:.1f}%\n"
330 prompt += f" - CPU Utilization: {m['cpu_utilization']:.1f}%\n"
331 prompt += "\n"
333 # Spot data summary
334 prompt += "SPOT CAPACITY SUMMARY:\n"
335 for instance_type, regions_data in capacity_data.get("spot_data", {}).items():
336 prompt += f" {instance_type}:\n"
337 for region, spot_info in regions_data.items():
338 scores = spot_info.get("placement_scores", {})
339 regional_score = scores.get("regional", "N/A")
340 prices = spot_info.get("prices", [])
341 avg_price = sum(p["current"] for p in prices) / len(prices) if prices else "N/A"
342 prompt += f" {region}: Score={regional_score}/10, "
343 prompt += f"Avg Price=${avg_price if isinstance(avg_price, str) else f'{avg_price:.4f}'}/hr\n"
344 prompt += "\n"
346 # On-demand data summary
347 prompt += "ON-DEMAND PRICING:\n"
348 for instance_type, regions_data in capacity_data.get("on_demand_data", {}).items():
349 prompt += f" {instance_type}:\n"
350 for region, od_info in regions_data.items():
351 price = od_info.get("price_per_hour")
352 available = od_info.get("available", False)
353 prompt += f" {region}: ${price:.4f}/hr" if price else f" {region}: N/A"
354 prompt += f" (Available: {available})\n"
355 prompt += "\n"
357 # Capacity reservations (ODCRs)
358 reservations = capacity_data.get("reservations", {})
359 has_reservations = any(bool(regions_data) for regions_data in reservations.values())
360 if has_reservations: 360 ↛ 361line 360 didn't jump to line 361 because the condition on line 360 was never true
361 prompt += "CAPACITY RESERVATIONS (ODCRs):\n"
362 for instance_type, regions_data in reservations.items():
363 for region, odcrs in regions_data.items():
364 for r in odcrs:
365 prompt += (
366 f" {instance_type} in {region} ({r['az']}): "
367 f"{r['available']}/{r['total']} available "
368 f"({r['utilization_pct']}% used)\n"
369 )
370 prompt += "\n"
372 # Capacity Blocks for ML
373 blocks = capacity_data.get("capacity_blocks", {})
374 has_blocks = any(bool(regions_data) for regions_data in blocks.values())
375 if has_blocks: 375 ↛ 376line 375 didn't jump to line 376 because the condition on line 375 was never true
376 prompt += "CAPACITY BLOCK OFFERINGS (guaranteed GPU blocks):\n"
377 for instance_type, regions_data in blocks.items():
378 for region, offerings in regions_data.items():
379 for b in offerings:
380 prompt += (
381 f" {instance_type} in {region} ({b['az']}): "
382 f"{b['duration_hours']}h starting {b['start_date']}, "
383 f"${b['upfront_fee']}\n"
384 )
385 prompt += "\n"
387 prompt += """Based on this data, provide your recommendation in the following JSON format:
388{
389 "recommended_region": "region-name",
390 "recommended_instance_type": "instance-type",
391 "recommended_capacity_type": "spot, on-demand, odcr, or capacity-block",
392 "reasoning": "Detailed explanation of why this is the best choice",
393 "confidence": "high, medium, or low",
394 "cost_estimate": "Estimated hourly cost",
395 "reservation_advice": "If ODCRs or Capacity Blocks are available, explain how to use them. If not, suggest whether the user should consider purchasing a Capacity Block.",
396 "alternative_options": [
397 {"region": "...", "instance_type": "...", "capacity_type": "...", "reason": "..."}
398 ],
399 "warnings": ["Any important warnings or caveats"]
400}
402Respond ONLY with the JSON object, no additional text."""
404 return prompt
406 def get_recommendation(
407 self,
408 workload_description: str | None = None,
409 instance_types: list[str] | None = None,
410 regions: list[str] | None = None,
411 requirements: dict[str, Any] | None = None,
412 ) -> BedrockCapacityRecommendation:
413 """
414 Get an AI-powered capacity recommendation.
416 Args:
417 workload_description: Description of the workload
418 instance_types: List of instance types to consider
419 regions: List of regions to consider
420 requirements: Dictionary of requirements (gpu_required, min_gpus, etc.)
422 Returns:
423 BedrockCapacityRecommendation with the AI's recommendation
424 """
425 # Gather capacity data
426 capacity_data = self.gather_capacity_data(instance_types, regions)
428 # Build prompt
429 prompt = self._build_prompt(capacity_data, workload_description, requirements)
431 # Call Bedrock
432 bedrock = self._get_bedrock_client()
434 try:
435 # Use the Converse API for better compatibility across models
436 response = bedrock.converse(
437 modelId=self.model_id,
438 messages=[{"role": "user", "content": [{"text": prompt}]}],
439 inferenceConfig={"maxTokens": 2048, "temperature": 0.1},
440 )
442 # Extract response text
443 response_text = response["output"]["message"]["content"][0]["text"]
445 # Parse JSON response
446 # Find JSON in response (in case model adds extra text)
447 json_start = response_text.find("{")
448 json_end = response_text.rfind("}") + 1
449 if json_start >= 0 and json_end > json_start:
450 json_str = response_text[json_start:json_end]
451 result = json.loads(json_str)
452 else:
453 raise ValueError("No valid JSON found in response")
455 return BedrockCapacityRecommendation(
456 recommended_region=result.get("recommended_region", "unknown"),
457 recommended_instance_type=result.get("recommended_instance_type", "unknown"),
458 recommended_capacity_type=result.get("recommended_capacity_type", "spot"),
459 reasoning=result.get("reasoning", ""),
460 confidence=result.get("confidence", "low"),
461 cost_estimate=result.get("cost_estimate"),
462 alternative_options=result.get("alternative_options", []),
463 warnings=result.get("warnings", []),
464 raw_response=response_text,
465 )
467 except ClientError as e:
468 error_code = e.response.get("Error", {}).get("Code", "")
469 if error_code == "AccessDeniedException":
470 raise RuntimeError(
471 "Access denied to Bedrock. Ensure your IAM role has "
472 "bedrock:InvokeModel permission and the model is enabled in your account."
473 ) from e
474 if error_code == "ValidationException":
475 raise RuntimeError(
476 f"Model {self.model_id} may not be available. "
477 "Try a different model with --model option."
478 ) from e
479 raise RuntimeError(f"Bedrock API error: {e}") from e
480 except json.JSONDecodeError as e:
481 raise RuntimeError(f"Failed to parse AI response as JSON: {e}") from e
482 except Exception as e:
483 raise RuntimeError(f"Failed to get AI recommendation: {e}") from e
def get_bedrock_capacity_advisor(
    config: GCOConfig | None = None, model_id: str | None = None
) -> BedrockCapacityAdvisor:
    """Factory returning a configured BedrockCapacityAdvisor instance."""
    advisor = BedrockCapacityAdvisor(config=config, model_id=model_id)
    return advisor