Coverage for gco/config/config_loader.py: 99%
268 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
"""
Configuration loader for GCO (Global Capacity Orchestrator on AWS).

This module loads and validates configuration from CDK context (cdk.json).
It provides type-safe access to all configuration values with sensible defaults
and comprehensive validation.

Configuration Sections:
- project_name: Unique identifier for the deployment
- deployment_regions: Home regions for the global, API Gateway and monitoring
  stacks, plus the ``regional`` list of regions that get an EKS cluster
- kubernetes_version: EKS Kubernetes version
- resource_thresholds: CPU/memory/GPU utilization and pending-work thresholds
- global_accelerator: Global Accelerator settings
- alb_config: Application Load Balancer health check settings
- manifest_processor: Manifest processor service settings (image, replicas,
  resource limits)
- job_validation_policy: Validation policy shared by the manifest processor
  and the queue processor (allowed namespaces, resource quotas, registries)
- api_gateway: Throttling and logging configuration
- eks_cluster: EKS API endpoint access mode
- fsx_lustre / fsx_lustre_regions: Optional FSx for Lustre settings and
  per-region overrides
- valkey: Valkey Serverless cache settings
- aurora_pgvector: Aurora Serverless v2 + pgvector settings
- analytics_environment: Optional SageMaker analytics environment settings
- tags: Common tags applied to all resources

Usage:
    config = ConfigLoader(app)
    regions = config.get_regions()
    cluster_config = config.get_cluster_config("us-east-1")
"""
25from __future__ import annotations
27import logging
28from typing import Any, cast
30import boto3
31from aws_cdk import App
33from gco.models import ClusterConfig, ResourceThresholds
logger = logging.getLogger(__name__)


class ConfigValidationError(Exception):
    """Signal that the CDK context configuration failed validation.

    ConfigLoader raises this with a message naming the offending key.
    """
class ConfigLoader:
    """
    Load and validate configuration from CDK context (cdk.json).

    The whole configuration is validated once in ``__init__``. When the
    context carries no ``project_name`` the loader assumes it is running
    outside a CDK app (for example in unit tests), skips validation, and the
    ``get_*`` accessors fall back to their built-in defaults.
    """

    # Valid AWS regions (subset of commonly used regions)
    VALID_REGIONS = {
        "us-east-1",
        "us-east-2",
        "us-west-1",
        "us-west-2",
        "eu-west-1",
        "eu-west-2",
        "eu-west-3",
        "eu-central-1",
        "ap-southeast-1",
        "ap-southeast-2",
        "ap-northeast-1",
        "ap-northeast-2",
        "ca-central-1",
        "sa-east-1",
    }

    # Fallback resource thresholds. One table serves both the "section absent"
    # path and the "key missing from the section" path so the two can never
    # disagree. A value of -1 disables the threshold.
    _DEFAULT_RESOURCE_THRESHOLDS: dict[str, int] = {
        "cpu_threshold": 60,
        "memory_threshold": 60,
        "gpu_threshold": -1,
        "pending_pods_threshold": 10,
        "pending_requested_cpu_vcpus": 100,
        "pending_requested_memory_gb": 200,
        "pending_requested_gpus": 8,
    }

    def __init__(self, app: App):
        """Store the CDK app and validate its context.

        Args:
            app: CDK app whose ``node.try_get_context`` supplies every
                configuration section.

        Raises:
            ConfigValidationError: If the context is present but invalid.
        """
        self.app = app
        self._validate_configuration()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_strict_int(value: Any) -> bool:
        """Return True for real integers; ``bool`` is rejected even though it subclasses ``int``."""
        return isinstance(value, int) and not isinstance(value, bool)

    @staticmethod
    def _require_mapping(value: Any, name: str) -> dict[str, Any]:
        """Return ``value`` unchanged, raising ConfigValidationError if it is not a dict."""
        if not isinstance(value, dict):
            raise ConfigValidationError(f"{name} must be a mapping, got {type(value).__name__}")
        return value

    def _get_context_dict(self, key: str) -> dict[str, Any]:
        """Return the CDK context value for ``key`` when it is a dict, otherwise ``{}``."""
        value = self.app.node.try_get_context(key)
        return value if isinstance(value, dict) else {}

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def _validate_configuration(self) -> None:
        """Validate the entire configuration.

        Raises:
            ConfigValidationError: If a required section is missing or any
                section is malformed.
        """
        # Check if we have any context at all (might be running outside CDK)
        if self.app.node.try_get_context("project_name") is None:
            # Running outside CDK context, skip validation
            return

        # Validate required fields exist
        for field in ("project_name", "kubernetes_version", "resource_thresholds"):
            if not self.app.node.try_get_context(field):
                raise ConfigValidationError(f"Required configuration field '{field}' is missing")

        # Check for deployment_regions
        deployment_regions = self.app.node.try_get_context("deployment_regions")
        if not deployment_regions:
            raise ConfigValidationError(
                "Required configuration field 'deployment_regions' is missing"
            )
        # The accessors call .get() on this section, so it must be a dict.
        self._require_mapping(deployment_regions, "deployment_regions")

        self._validate_regions()
        self._validate_resource_thresholds()
        self._validate_global_accelerator_config()
        self._validate_alb_config()
        self._validate_manifest_processor_config()
        self._validate_api_gateway_config()
        self._validate_eks_cluster_config()
        # Optional block; absence is valid.
        self._validate_analytics_environment_config()

    def _validate_regions(self) -> None:
        """Validate the regional EKS deployment region list.

        Raises:
            ConfigValidationError: If the list is empty, has more than 10
                entries, contains a region outside VALID_REGIONS, or contains
                duplicates.
        """
        regions = self.get_regions()

        if not regions:
            raise ConfigValidationError("At least one region must be specified")

        if len(regions) > 10:
            raise ConfigValidationError("Maximum of 10 regions supported")

        for region in regions:
            if region not in self.VALID_REGIONS:
                raise ConfigValidationError(
                    f"Invalid region '{region}'. Valid regions: {sorted(self.VALID_REGIONS)}"
                )

        # Check for duplicates
        if len(regions) != len(set(regions)):
            raise ConfigValidationError("Duplicate regions found in configuration")

    def _validate_resource_thresholds(self) -> None:
        """Validate the resource_thresholds section.

        The cpu/memory/gpu thresholds are required percentages (0-100, or -1
        to disable). The pending-work thresholds are optional non-negative
        integers (or -1 to disable).

        Raises:
            ConfigValidationError: On a missing or out-of-range threshold.
        """
        thresholds_config = self._require_mapping(
            self.app.node.try_get_context("resource_thresholds"), "resource_thresholds"
        )

        for threshold in ("cpu_threshold", "memory_threshold", "gpu_threshold"):
            if threshold not in thresholds_config:
                raise ConfigValidationError(f"Missing threshold configuration: {threshold}")

            value = thresholds_config[threshold]
            if not self._is_strict_int(value) or (value != -1 and not 0 <= value <= 100):
                raise ConfigValidationError(
                    f"{threshold} must be an integer between 0 and 100 (or -1 to disable), got {value}"
                )

        # Validate optional thresholds if present
        for opt_threshold in (
            "pending_pods_threshold",
            "pending_requested_cpu_vcpus",
            "pending_requested_memory_gb",
            "pending_requested_gpus",
        ):
            if opt_threshold in thresholds_config:
                value = thresholds_config[opt_threshold]
                if not self._is_strict_int(value) or (value != -1 and value < 0):
                    raise ConfigValidationError(
                        f"{opt_threshold} must be a non-negative integer (or -1 to disable), got {value}"
                    )

    def _validate_global_accelerator_config(self) -> None:
        """Validate the global_accelerator section.

        Raises:
            ConfigValidationError: If the section or a required field is
                missing, a timing value is not a positive integer, the health
                check path does not start with '/', or client_affinity is not
                NONE/SOURCE_IP (case-insensitive).
        """
        ga_config = self.app.node.try_get_context("global_accelerator")
        if not ga_config:
            raise ConfigValidationError("global_accelerator configuration is required")
        self._require_mapping(ga_config, "global_accelerator")

        required_fields = [
            "name",
            "health_check_grace_period",
            "health_check_interval",
            "health_check_timeout",
            "health_check_path",
        ]
        for field in required_fields:
            if field not in ga_config:
                raise ConfigValidationError(f"Missing global_accelerator configuration: {field}")

        # Validate timing values
        for field in ("health_check_grace_period", "health_check_interval", "health_check_timeout"):
            value = ga_config[field]
            if not self._is_strict_int(value) or value <= 0:
                raise ConfigValidationError(f"{field} must be a positive integer, got {value}")

        # Validate health check path (guard the type so a non-string value
        # reports a config error instead of an AttributeError).
        health_check_path = ga_config["health_check_path"]
        if not isinstance(health_check_path, str) or not health_check_path.startswith("/"):
            raise ConfigValidationError("health_check_path must start with '/'")

        # Validate optional client affinity. Omitting the key is allowed and
        # defaults to "NONE" in get_global_accelerator_config().
        if "client_affinity" in ga_config:
            allowed_affinity = {"NONE", "SOURCE_IP"}
            value = ga_config["client_affinity"]
            if not isinstance(value, str) or value.upper() not in allowed_affinity:
                raise ConfigValidationError(
                    f"client_affinity must be one of {sorted(allowed_affinity)}, got {value!r}"
                )

    def _validate_alb_config(self) -> None:
        """Validate the alb_config section; every field must be a positive integer.

        Raises:
            ConfigValidationError: If the section or a field is missing or invalid.
        """
        alb_config = self.app.node.try_get_context("alb_config")
        if not alb_config:
            raise ConfigValidationError("alb_config configuration is required")
        self._require_mapping(alb_config, "alb_config")

        required_fields = [
            "health_check_interval",
            "health_check_timeout",
            "healthy_threshold",
            "unhealthy_threshold",
        ]
        for field in required_fields:
            if field not in alb_config:
                raise ConfigValidationError(f"Missing alb_config configuration: {field}")

            value = alb_config[field]
            if not self._is_strict_int(value) or value <= 0:
                raise ConfigValidationError(f"{field} must be a positive integer, got {value}")

    def _validate_manifest_processor_config(self) -> None:
        """Validate manifest processor configuration.

        The manifest processor section in cdk.json holds service-specific
        settings only. The shared validation policy (allowed_namespaces,
        resource_quotas, trusted_registries, trusted_dockerhub_orgs,
        manifest_security_policy, allowed_kinds) lives under
        ``job_validation_policy`` because the queue_processor reads the
        same values.

        Raises:
            ConfigValidationError: If either section or a required field is
                missing or malformed.
        """
        mp_config = self.app.node.try_get_context("manifest_processor")
        if not mp_config:
            raise ConfigValidationError("manifest_processor configuration is required")
        self._require_mapping(mp_config, "manifest_processor")

        for field in ("image", "replicas", "resource_limits"):
            if field not in mp_config:
                raise ConfigValidationError(f"Missing manifest_processor configuration: {field}")

        # Validate replicas
        replicas = mp_config["replicas"]
        if not self._is_strict_int(replicas) or replicas <= 0:
            raise ConfigValidationError("manifest_processor replicas must be a positive integer")

        # Validate the shared policy section separately so a misconfigured
        # policy block surfaces a clear error pointing at the right key.
        policy = self.app.node.try_get_context("job_validation_policy")
        if policy is None:
            raise ConfigValidationError(
                "job_validation_policy configuration is required (shared between "
                "manifest_processor and queue_processor)"
            )
        self._require_mapping(policy, "job_validation_policy")
        for policy_field in ("allowed_namespaces", "resource_quotas"):
            if policy_field not in policy:
                raise ConfigValidationError(
                    f"Missing job_validation_policy configuration: {policy_field}"
                )

        # Validate resource limits
        resource_limits = mp_config["resource_limits"]
        if (
            not isinstance(resource_limits, dict)
            or "cpu" not in resource_limits
            or "memory" not in resource_limits
        ):
            raise ConfigValidationError(
                "manifest_processor resource_limits must contain 'cpu' and 'memory'"
            )

        # Validate allowed namespaces (lives under job_validation_policy).
        if not isinstance(policy["allowed_namespaces"], list):
            raise ConfigValidationError("job_validation_policy.allowed_namespaces must be a list")

    def _validate_api_gateway_config(self) -> None:
        """Validate the api_gateway section.

        Raises:
            ConfigValidationError: If a field is missing, a throttle limit is
                not a positive integer, the burst limit is below the rate
                limit, the log level is unknown, or a flag is not a bool.
        """
        api_gw_config = self.app.node.try_get_context("api_gateway")
        if not api_gw_config:
            raise ConfigValidationError("api_gateway configuration is required")
        self._require_mapping(api_gw_config, "api_gateway")

        required_fields = [
            "throttle_rate_limit",
            "throttle_burst_limit",
            "log_level",
            "metrics_enabled",
            "tracing_enabled",
        ]
        for field in required_fields:
            if field not in api_gw_config:
                raise ConfigValidationError(f"Missing api_gateway configuration: {field}")

        # Validate throttle limits
        throttle_rate = api_gw_config["throttle_rate_limit"]
        throttle_burst = api_gw_config["throttle_burst_limit"]

        if not self._is_strict_int(throttle_rate) or throttle_rate <= 0:
            raise ConfigValidationError(
                f"throttle_rate_limit must be a positive integer, got {throttle_rate}"
            )

        if not self._is_strict_int(throttle_burst) or throttle_burst <= 0:
            raise ConfigValidationError(
                f"throttle_burst_limit must be a positive integer, got {throttle_burst}"
            )

        if throttle_burst < throttle_rate:
            raise ConfigValidationError(
                "throttle_burst_limit should be greater than or equal to throttle_rate_limit"
            )

        # Validate log level
        valid_log_levels = ["OFF", "ERROR", "INFO"]
        log_level = api_gw_config["log_level"]
        if log_level not in valid_log_levels:
            raise ConfigValidationError(
                f"log_level must be one of {valid_log_levels}, got {log_level}"
            )

        # Validate boolean flags
        if not isinstance(api_gw_config["metrics_enabled"], bool):
            raise ConfigValidationError("metrics_enabled must be a bool")

        if not isinstance(api_gw_config["tracing_enabled"], bool):
            raise ConfigValidationError("tracing_enabled must be a boolean")

    def _validate_eks_cluster_config(self) -> None:
        """Validate the optional eks_cluster section.

        Raises:
            ConfigValidationError: If the section is not a mapping or
                endpoint_access is not PRIVATE / PUBLIC_AND_PRIVATE.
        """
        eks_config = self._require_mapping(
            self.app.node.try_get_context("eks_cluster") or {}, "eks_cluster"
        )

        # Validate endpoint_access if present
        if "endpoint_access" in eks_config:
            valid_access_modes = ["PRIVATE", "PUBLIC_AND_PRIVATE"]
            if eks_config["endpoint_access"] not in valid_access_modes:
                raise ConfigValidationError(
                    f"endpoint_access must be one of {valid_access_modes}, "
                    f"got {eks_config['endpoint_access']}"
                )

    def _validate_analytics_environment_config(self) -> None:
        """Validate the optional analytics_environment block in cdk.json.

        The block is entirely optional; absence means the feature is disabled
        and no validation is needed. When present, we validate:

        - ``enabled``: must be a bool if present (defaults to False via merge).
        - ``hyperpod.enabled`` and ``canvas.enabled``: must be bools if present
          (both default to False).
        - ``cognito.removal_policy`` and ``efs.removal_policy``: must be the
          literal strings ``"destroy"`` or ``"retain"`` (case sensitive — they
          are passed verbatim to CDK's ``RemovalPolicy`` lookup by the
          consumer).

        Raises:
            ConfigValidationError: On any of the violations above.
        """
        analytics_ctx = self.app.node.try_get_context("analytics_environment")
        if not isinstance(analytics_ctx, dict):
            # Block is absent or malformed — defaults apply, nothing to validate.
            return

        # Top-level `enabled` must be a bool if provided.
        if "enabled" in analytics_ctx and not isinstance(analytics_ctx["enabled"], bool):
            raise ConfigValidationError(
                f"analytics_environment.enabled must be a bool, got "
                f"{type(analytics_ctx['enabled']).__name__}: {analytics_ctx['enabled']!r}"
            )

        # `<feature>.enabled` must be a bool if the sub-block is a dict and
        # carries the key.
        for feature in ("hyperpod", "canvas"):
            feature_ctx = analytics_ctx.get(feature)
            if not isinstance(feature_ctx, dict) or "enabled" not in feature_ctx:
                continue
            flag = feature_ctx["enabled"]
            if not isinstance(flag, bool):
                raise ConfigValidationError(
                    f"analytics_environment.{feature}.enabled must be a bool, got "
                    f"{type(flag).__name__}: {flag!r}"
                )

        valid_removal_policies = {"destroy", "retain"}

        for sub_block in ("cognito", "efs"):
            sub_ctx = analytics_ctx.get(sub_block)
            if not isinstance(sub_ctx, dict) or "removal_policy" not in sub_ctx:
                continue
            removal_policy = sub_ctx["removal_policy"]
            # The str check keeps unhashable values (lists, dicts) from raising
            # TypeError on the set lookup instead of a clear config error.
            if not isinstance(removal_policy, str) or removal_policy not in valid_removal_policies:
                raise ConfigValidationError(
                    f"analytics_environment.{sub_block}.removal_policy must be one of "
                    f"{sorted(valid_removal_policies)}, got {removal_policy!r}"
                )

    # ------------------------------------------------------------------
    # Accessors
    # ------------------------------------------------------------------

    def get_project_name(self) -> str:
        """Get project name from configuration (default: "gco")."""
        return self.app.node.try_get_context("project_name") or "gco"

    def get_deployment_regions(self) -> dict[str, Any]:
        """Get deployment regions configuration.

        Returns a dict with:
        - global: Region for Global Accelerator and SSM parameters (default: us-east-2)
        - api_gateway: Region for API Gateway stack (default: us-east-2)
        - monitoring: Region for Monitoring stack (default: us-east-2)
        - regional: List of regions for EKS clusters (default: ["us-east-1"])

        Note: Global Accelerator is a global service but requires a "home" region
        for CloudFormation deployment. us-east-2 is used by default to keep
        global infrastructure separate from workload regions.
        """
        deployment_regions = self._get_context_dict("deployment_regions")

        return {
            "global": deployment_regions.get("global", "us-east-2"),
            "api_gateway": deployment_regions.get("api_gateway", "us-east-2"),
            "monitoring": deployment_regions.get("monitoring", "us-east-2"),
            "regional": deployment_regions.get("regional", ["us-east-1"]),
        }

    def get_global_region(self) -> str:
        """Get the region for global resources (Global Accelerator, SSM params)."""
        return str(self.get_deployment_regions()["global"])

    def get_api_gateway_region(self) -> str:
        """Get the region for API Gateway stack."""
        return str(self.get_deployment_regions()["api_gateway"])

    def get_monitoring_region(self) -> str:
        """Get the region for Monitoring stack."""
        return str(self.get_deployment_regions()["monitoring"])

    def get_regions(self) -> list[str]:
        """Get list of regions for EKS cluster deployment.

        A scalar ``regional`` value is wrapped into a one-element list.
        """
        regional = self.get_deployment_regions()["regional"]
        return list(regional) if isinstance(regional, list) else [str(regional)]

    def get_kubernetes_version(self) -> str:
        """Get Kubernetes version from configuration (default: "1.36")."""
        return self.app.node.try_get_context("kubernetes_version") or "1.36"

    def get_resource_thresholds(self) -> ResourceThresholds:
        """Get resource thresholds configuration.

        Values from the resource_thresholds section are layered over
        ``_DEFAULT_RESOURCE_THRESHOLDS``, so a partial section (possible when
        validation is skipped) no longer raises KeyError.
        """
        thresholds = {
            **self._DEFAULT_RESOURCE_THRESHOLDS,
            **self._get_context_dict("resource_thresholds"),
        }
        return ResourceThresholds(
            cpu_threshold=thresholds["cpu_threshold"],
            memory_threshold=thresholds["memory_threshold"],
            gpu_threshold=thresholds["gpu_threshold"],
            pending_pods_threshold=thresholds["pending_pods_threshold"],
            pending_requested_cpu_vcpus=thresholds["pending_requested_cpu_vcpus"],
            pending_requested_memory_gb=thresholds["pending_requested_memory_gb"],
            pending_requested_gpus=thresholds["pending_requested_gpus"],
        )

    def get_cluster_config(self, region: str) -> ClusterConfig:
        """Get complete cluster configuration for a region.

        The cluster name is ``<project_name>-<region>``.
        """
        return ClusterConfig(
            region=region,
            cluster_name=f"{self.get_project_name()}-{region}",
            kubernetes_version=self.get_kubernetes_version(),
            addons=["metrics-server"],
            resource_thresholds=self.get_resource_thresholds(),
        )

    def get_global_accelerator_config(self) -> dict[str, Any]:
        """Get Global Accelerator configuration.

        Context values are layered over the defaults below, so omitting
        ``client_affinity`` yields "NONE". A string ``client_affinity`` is
        upper-cased because validation accepts it case-insensitively.
        """
        default_config: dict[str, Any] = {
            "name": "gco-accelerator",
            "health_check_grace_period": 30,
            "health_check_interval": 30,
            "health_check_timeout": 5,
            "health_check_path": "/api/v1/health",
            "client_affinity": "NONE",
        }
        merged_config = {**default_config, **self._get_context_dict("global_accelerator")}
        affinity = merged_config.get("client_affinity")
        if isinstance(affinity, str):
            merged_config["client_affinity"] = affinity.upper()
        return merged_config

    def get_alb_config(self) -> dict[str, Any]:
        """Get ALB configuration, with context values layered over the defaults."""
        default_config: dict[str, Any] = {
            "health_check_interval": 30,
            "health_check_timeout": 5,
            "healthy_threshold": 2,
            "unhealthy_threshold": 2,
        }
        return {**default_config, **self._get_context_dict("alb_config")}

    def get_manifest_processor_config(self) -> dict[str, Any]:
        """Get manifest processor configuration.

        Layers two cdk.json sections over the built-in defaults below
        (later sources win on key collisions):

        - ``manifest_processor``: service-specific settings (replicas, image,
          resource_limits, validation_enabled, max_request_body_bytes,
          yaml_max_depth)
        - ``job_validation_policy``: shared validation policy (allowed_namespaces,
          resource_quotas, trusted_registries, trusted_dockerhub_orgs,
          manifest_security_policy, allowed_kinds). Pulled in verbatim so the
          REST path reads the same policy the SQS queue processor enforces.

        The merge is shallow: a nested dict such as ``resource_quotas`` from
        context replaces the default dict entirely.

        Note: The 'image' field is a placeholder default. In practice, the actual
        image is built from dockerfiles/manifest-processor-dockerfile and pushed
        to ECR during CDK deployment. The {{MANIFEST_PROCESSOR_IMAGE}} placeholder
        in manifests is replaced with the ECR image URI.
        """
        default_config = {
            "image": "gco/manifest-processor:latest",  # Placeholder, replaced by ECR image
            "replicas": 3,
            "resource_limits": {"cpu": "1000m", "memory": "2Gi"},
            "validation_enabled": True,
            # allowed_namespaces, resource_quotas, trusted_registries,
            # trusted_dockerhub_orgs, manifest_security_policy, and
            # allowed_kinds are merged in below from job_validation_policy.
            "allowed_namespaces": ["default", "gco-jobs"],
            "resource_quotas": {
                "max_cpu_per_manifest": "10",
                "max_memory_per_manifest": "32Gi",
                "max_gpu_per_manifest": 4,
            },
            "trusted_registries": [
                "docker.io",
                "gcr.io",
                "quay.io",
                "registry.k8s.io",
                "k8s.gcr.io",
                "public.ecr.aws",
                "nvcr.io",
                "gco",
            ],
            "trusted_dockerhub_orgs": [
                "nvidia",
                "pytorch",
                "rayproject",
                "tensorflow",
                "huggingface",
                "amazon",
                "bitnami",
            ],
        }
        context_config = self._get_context_dict("manifest_processor")

        # Merge in the shared job_validation_policy section. These keys apply
        # to BOTH the manifest processor and the queue processor; they live
        # in their own top-level cdk.json section so neither service "owns"
        # them. We flatten them into the manifest processor's runtime config
        # so service code keeps its existing attribute layout.
        shared_policy = self._get_context_dict("job_validation_policy")
        return {**default_config, **context_config, **shared_policy}

    def get_api_gateway_config(self) -> dict[str, Any]:
        """Get API Gateway configuration.

        Returns:
            API Gateway configuration dictionary with the following keys:
            - throttle_rate_limit: Requests per second limit
            - throttle_burst_limit: Burst capacity
            - log_level: CloudWatch logging level (OFF, ERROR, INFO)
            - metrics_enabled: Enable CloudWatch metrics
            - tracing_enabled: Enable X-Ray tracing
            - regional_api_enabled: Enable regional API Gateways for private access
              When true, deploys a regional API Gateway with VPC Lambda in each
              region, allowing API access when the ALB is internal-only.
        """
        default_config = {
            "throttle_rate_limit": 1000,
            "throttle_burst_limit": 2000,
            "log_level": "INFO",
            "metrics_enabled": True,
            "tracing_enabled": True,
            "regional_api_enabled": False,
        }
        return {**default_config, **self._get_context_dict("api_gateway")}

    def get_eks_cluster_config(self) -> dict[str, Any]:
        """Get EKS cluster configuration.

        Returns:
            EKS cluster configuration dictionary with the following keys:
            - endpoint_access: EKS API endpoint access mode
              - "PRIVATE": API server only accessible from within VPC (default, most secure)
              - "PUBLIC_AND_PRIVATE": API server accessible from internet and VPC

        Note:
            PRIVATE endpoint is recommended for production. Job submission still works
            via API Gateway → Lambda (in VPC) or SQS queues. For kubectl access with
            PRIVATE endpoint, use a bastion host, VPN, or AWS SSM Session Manager.
        """
        default_config = {
            "endpoint_access": "PRIVATE",
        }
        return {**default_config, **self._get_context_dict("eks_cluster")}

    def get_fsx_lustre_config(self, region: str | None = None) -> dict[str, Any]:
        """Get FSx for Lustre configuration.

        Precedence (lowest to highest): built-in defaults, the global
        ``fsx_lustre`` section, then ``fsx_lustre_regions[region]``. The
        ``node_group`` sub-dict is merged key-by-key at every level, so a
        partial override keeps the remaining node-group fields.

        Args:
            region: Optional region to get config for. If provided, checks for
                region-specific overrides first.

        Returns:
            FSx configuration dictionary with the following keys:
            - enabled: Whether FSx is enabled
            - storage_capacity_gib: Storage capacity in GiB (min 1200)
            - deployment_type: SCRATCH_1, SCRATCH_2, PERSISTENT_1, PERSISTENT_2
            - file_system_type_version: Lustre version (2.12 or 2.15, default: 2.15)
              IMPORTANT: Use 2.15 for kernel 6.x compatibility (AL2023, Bottlerocket)
            - per_unit_storage_throughput: Throughput for PERSISTENT types
            - data_compression_type: LZ4 or NONE
            - import_path: S3 path for data import
            - export_path: S3 path for data export
            - auto_import_policy: NEW, NEW_CHANGED, NEW_CHANGED_DELETED
            - node_group: Node group configuration for FSx workloads
              - instance_types: List of instance types
              - min_size: Minimum nodes (default: 0)
              - max_size: Maximum nodes (default: 10)
              - desired_size: Desired nodes (default: 1)
              - ami_type: AMI type - one of:
                AL2023_X86_64_STANDARD (default), AL2023_ARM_64_STANDARD,
                AL2023_X86_64_NVIDIA, AL2023_ARM_64_NVIDIA, AL2023_X86_64_NEURON
              - capacity_type: ON_DEMAND (default) or SPOT
              - disk_size: Root disk size in GB (default: 100)
              - labels: Additional node labels (dict)
        """
        default_node_group: dict[str, Any] = {
            "instance_types": ["m5.large", "m5.xlarge", "m6i.large", "m6i.xlarge"],
            "min_size": 0,
            "max_size": 10,
            "desired_size": 1,
            "ami_type": "AL2023_X86_64_STANDARD",
            "capacity_type": "ON_DEMAND",
            "disk_size": 100,
            "labels": {},
        }
        default_config: dict[str, Any] = {
            "enabled": False,
            "storage_capacity_gib": 1200,
            "deployment_type": "SCRATCH_2",
            "file_system_type_version": "2.15",  # Use 2.15 for kernel 6.x compatibility
            "per_unit_storage_throughput": 200,
            "data_compression_type": "LZ4",
            "import_path": None,
            "export_path": None,
            "auto_import_policy": "NEW_CHANGED_DELETED",
            "node_group": default_node_group,
        }

        # Layer the global FSx config over the defaults.
        global_config = self._get_context_dict("fsx_lustre")
        merged_config: dict[str, Any] = {**default_config, **global_config}

        # Ensure node_group has all required fields with defaults
        global_node_group = global_config.get("node_group")
        if isinstance(global_node_group, dict):
            merged_config["node_group"] = {**default_node_group, **global_node_group}

        # Check for region-specific override
        if region:
            region_config = self._get_context_dict("fsx_lustre_regions").get(region)
            if isinstance(region_config, dict):
                # Capture the node_group resolved so far BEFORE the shallow
                # merge below: that merge replaces merged_config["node_group"]
                # with the region's (possibly partial) dict, which previously
                # discarded the global/default node-group fields.
                base_node_group = merged_config.get("node_group")
                if not isinstance(base_node_group, dict):
                    base_node_group = default_node_group

                merged_config = {**merged_config, **region_config}

                # Handle nested node_group override
                region_node_group = region_config.get("node_group")
                if isinstance(region_node_group, dict):
                    merged_config["node_group"] = {**base_node_group, **region_node_group}

        return merged_config

    def get_valkey_config(self) -> dict[str, Any]:
        """Get Valkey Serverless cache configuration.

        Returns:
            Valkey configuration dictionary with the following keys:
            - enabled: Whether Valkey cache is enabled (default: True)
            - max_data_storage_gb: Maximum data storage in GB (default: 5)
            - max_ecpu_per_second: Maximum ECPUs per second (default: 5000)
            - snapshot_retention_limit: Daily snapshots to retain (default: 1)
        """
        default_config: dict[str, Any] = {
            "enabled": True,
            "max_data_storage_gb": 5,
            "max_ecpu_per_second": 5000,
            "snapshot_retention_limit": 1,
        }
        return {**default_config, **self._get_context_dict("valkey")}

    def get_aurora_pgvector_config(self) -> dict[str, Any]:
        """Get Aurora Serverless v2 + pgvector vector database configuration.

        Returns:
            Aurora pgvector configuration dictionary with the following keys:
            - enabled: Whether Aurora pgvector is enabled (default: False)
            - min_acu: Minimum Aurora Capacity Units (default: 0, scales to zero)
            - max_acu: Maximum Aurora Capacity Units (default: 16)
            - backup_retention_days: Number of days to retain automated backups (default: 7)
            - deletion_protection: Whether deletion protection is enabled (default: False)
        """
        default_config: dict[str, Any] = {
            "enabled": False,
            "min_acu": 0,
            "max_acu": 16,
            "backup_retention_days": 7,
            "deletion_protection": False,
        }
        return {**default_config, **self._get_context_dict("aurora_pgvector")}

    def get_analytics_config(self) -> dict[str, Any]:
        """Get optional analytics environment configuration.

        Returns the fully-merged analytics_environment block from cdk.json
        layered on top of the defaults below. Sub-blocks (``hyperpod``,
        ``canvas``, ``cognito``, ``efs``, ``studio``) are deep-merged so a
        user who overrides a single nested key (e.g. ``cognito.domain_prefix``)
        does not inadvertently wipe the sub-block's other defaults — mirroring
        the nested-merge pattern used by ``get_fsx_lustre_config`` for its
        ``node_group`` sub-block.

        Returns:
            Analytics configuration dictionary with the following keys:
            - enabled: Whether the analytics environment stack is deployed
              (default: False — the feature is off unless explicitly opted in)
            - hyperpod: SageMaker HyperPod integration sub-block
              - enabled: Whether to add the HyperPod IAM grants to
                SageMaker_Execution_Role (default: False)
            - canvas: SageMaker Canvas integration sub-block
              - enabled: Whether to enable the SageMaker Canvas app on
                the Studio domain and attach ``AmazonSageMakerCanvasFullAccess``
                to the SageMaker_Execution_Role (default: False)
            - cognito: Cognito user-pool sub-block
              - domain_prefix: UserPoolDomain prefix, or None to let the
                analytics stack derive one (default: None)
              - removal_policy: "destroy" (default) or "retain" — controls
                the Cognito pool's CloudFormation DeletionPolicy
            - efs: Studio_EFS sub-block
              - removal_policy: "destroy" (default) or "retain" — controls
                the Studio EFS file system's CloudFormation DeletionPolicy
            - studio: SageMaker Studio sub-block
              - user_profile_name_prefix: Optional prefix for per-user
                profile names, or None to use the Cognito username verbatim
                (default: None)
        """
        default_config: dict[str, Any] = {
            "enabled": False,
            "hyperpod": {"enabled": False},
            "canvas": {"enabled": False},
            "cognito": {"domain_prefix": None, "removal_policy": "destroy"},
            "efs": {"removal_policy": "destroy"},
            "studio": {"user_profile_name_prefix": None},
        }
        analytics_config = self._get_context_dict("analytics_environment")
        merged_config: dict[str, Any] = {**default_config, **analytics_config}

        # Deep-merge each nested sub-block so a partial override does not
        # drop the other defaults in the same sub-block.
        for sub_block in ("hyperpod", "canvas", "cognito", "efs", "studio"):
            override = analytics_config.get(sub_block)
            if isinstance(override, dict):
                default_sub = cast(dict[str, Any], default_config[sub_block])
                merged_config[sub_block] = {**default_sub, **override}

        return merged_config

    def get_analytics_enabled(self) -> bool:
        """Return whether the analytics environment stack is enabled.

        Thin wrapper around ``get_analytics_config()["enabled"]`` to mirror
        the existing ``get_valkey_config`` / ``get_aurora_pgvector_config``
        access pattern without forcing every call site to index into the
        merged dict.
        """
        return bool(self.get_analytics_config()["enabled"])

    def get_tags(self) -> dict[str, str]:
        """Get common tags from configuration (empty dict when unset)."""
        return self.app.node.try_get_context("tags") or {}

    def validate_region_availability(self, region: str) -> bool:
        """Validate that a region is available in the current AWS account.

        Best-effort: any failure (credentials, network, unknown region) is
        logged at debug level and reported as False.
        """
        try:
            ec2 = boto3.client("ec2", region_name=region)
            ec2.describe_regions(RegionNames=[region])
            return True
        except Exception as e:
            logger.debug("Region %s not available: %s", region, e)
            return False

    def get_available_regions(self) -> list[str]:
        """Get list of available AWS regions for the current account.

        Falls back to the sorted VALID_REGIONS list when the EC2 call fails,
        so the fallback order is deterministic.
        """
        try:
            ec2 = boto3.client("ec2")
            response = ec2.describe_regions()
            return [region["RegionName"] for region in response["Regions"]]
        except Exception as e:
            logger.debug("Failed to list regions, using defaults: %s", e)
            return sorted(self.VALID_REGIONS)