Coverage for gco/config/config_loader.py: 99%

268 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 15:07 +0000

"""
Configuration loader for GCO (Global Capacity Orchestrator on AWS).

This module loads and validates configuration from CDK context (cdk.json).
It provides type-safe access to all configuration values with sensible defaults
and comprehensive validation. Validation runs when a ``ConfigLoader`` is
constructed and is skipped entirely when the ``project_name`` context key is
absent (e.g. when used outside a CDK app); accessors then return defaults.

Configuration Sections:
- project_name: Unique identifier for the deployment
- deployment_regions: Home regions for the global, API Gateway and monitoring
  stacks, plus the ``regional`` list of AWS regions that get EKS clusters
- kubernetes_version: EKS Kubernetes version
- resource_thresholds: CPU/memory/GPU utilization and pending-work thresholds
- global_accelerator: Global Accelerator settings
- alb_config: Application Load Balancer health check settings
- manifest_processor: Manifest processor service settings (image, replicas,
  resource limits)
- job_validation_policy: Validation policy shared by the manifest processor
  and the queue processor (allowed namespaces, resource quotas, registries)
- api_gateway: Throttling and logging configuration
- eks_cluster: EKS API endpoint access mode (optional)
- fsx_lustre / fsx_lustre_regions: FSx for Lustre settings and per-region
  overrides (optional)
- valkey: Valkey Serverless cache settings (optional)
- aurora_pgvector: Aurora Serverless v2 + pgvector settings (optional)
- analytics_environment: Optional analytics environment feature block
- tags: Common tags applied to all resources

Usage:
    config = ConfigLoader(app)
    regions = config.get_regions()
    cluster_config = config.get_cluster_config("us-east-1")
"""

24 

25from __future__ import annotations 

26 

27import logging 

28from typing import Any, cast 

29 

30import boto3 

31from aws_cdk import App 

32 

33from gco.models import ClusterConfig, ResourceThresholds 

34 

# Module-level logger; only used for debug output when AWS region lookups fail.
logger: logging.Logger = logging.getLogger(__name__)

36 

37 

class ConfigValidationError(Exception):
    """Signal that the CDK context configuration failed validation.

    Raised by ``ConfigLoader`` while checking cdk.json context values; the
    exception message describes which setting is missing or invalid.
    """

42 

43 

class ConfigLoader:
    """Load and validate GCO configuration from CDK context (cdk.json).

    The configuration is validated once, from ``__init__``. Validation is
    skipped when the ``project_name`` context key is absent (for example
    when the loader is used outside a CDK app); the ``get_*`` accessors then
    fall back to their built-in defaults.
    """

    # Valid AWS regions (subset of commonly used regions). Every entry in
    # deployment_regions.regional must be one of these.
    VALID_REGIONS = {
        "us-east-1",
        "us-east-2",
        "us-west-1",
        "us-west-2",
        "eu-west-1",
        "eu-west-2",
        "eu-west-3",
        "eu-central-1",
        "ap-southeast-1",
        "ap-southeast-2",
        "ap-northeast-1",
        "ap-northeast-2",
        "ca-central-1",
        "sa-east-1",
    }

    def __init__(self, app: App):
        """Store the CDK app and validate its context.

        Args:
            app: The CDK ``App`` whose context (cdk.json) is read.

        Raises:
            ConfigValidationError: If the context is present but invalid.
        """
        self.app = app
        self._validate_configuration()

    @staticmethod
    def _is_int(value: Any) -> bool:
        """Return True for genuine integers, rejecting ``bool``.

        ``bool`` is a subclass of ``int``, so a bare ``isinstance(x, int)``
        would accept JSON ``true``/``false`` as 1/0.
        """
        return isinstance(value, int) and not isinstance(value, bool)

    def _validate_configuration(self) -> None:
        """Validate the entire configuration.

        Raises:
            ConfigValidationError: On the first missing or invalid setting.
        """
        # Check if we have any context at all (might be running outside CDK)
        project_name = self.app.node.try_get_context("project_name")
        if project_name is None:
            # Running outside CDK context, skip validation
            return

        # Validate required fields exist (empty values count as missing)
        required_fields = [
            "project_name",
            "kubernetes_version",
            "resource_thresholds",
        ]
        for field in required_fields:
            if not self.app.node.try_get_context(field):
                raise ConfigValidationError(f"Required configuration field '{field}' is missing")

        # Check for deployment_regions
        deployment_regions = self.app.node.try_get_context("deployment_regions")
        if not deployment_regions:
            raise ConfigValidationError(
                "Required configuration field 'deployment_regions' is missing"
            )
        # get_deployment_regions() ignores non-object values, so reject them
        # here instead of silently deploying to the default regions.
        if not isinstance(deployment_regions, dict):
            raise ConfigValidationError("deployment_regions configuration must be an object")

        # Validate regions
        self._validate_regions()

        # Validate resource thresholds
        self._validate_resource_thresholds()

        # Validate Global Accelerator config
        self._validate_global_accelerator_config()

        # Validate ALB config
        self._validate_alb_config()

        # Validate manifest processor config
        self._validate_manifest_processor_config()

        # Validate API Gateway config
        self._validate_api_gateway_config()

        # Validate EKS cluster config
        self._validate_eks_cluster_config()

        # Validate analytics environment config (optional block)
        self._validate_analytics_environment_config()

    def _validate_regions(self) -> None:
        """Validate the regional (EKS) region list.

        Raises:
            ConfigValidationError: If the list is empty, longer than 10,
                contains an unknown region, or contains duplicates.
        """
        regions = self.get_regions()

        if not regions:
            raise ConfigValidationError("At least one region must be specified")

        if len(regions) > 10:
            raise ConfigValidationError("Maximum of 10 regions supported")

        for region in regions:
            # The isinstance check keeps unhashable entries (e.g. nested
            # objects) from raising TypeError on the set lookup.
            if not isinstance(region, str) or region not in self.VALID_REGIONS:
                raise ConfigValidationError(
                    f"Invalid region '{region}'. Valid regions: {sorted(self.VALID_REGIONS)}"
                )

        # Check for duplicates
        if len(regions) != len(set(regions)):
            raise ConfigValidationError("Duplicate regions found in configuration")

    def _validate_resource_thresholds(self) -> None:
        """Validate resource threshold configuration.

        cpu/memory/gpu thresholds are required integers in 0..100 or -1
        (disabled); the pending_* thresholds are optional non-negative
        integers or -1.

        Raises:
            ConfigValidationError: On a missing or out-of-range threshold.
        """
        thresholds_config = self.app.node.try_get_context("resource_thresholds")
        if not isinstance(thresholds_config, dict):
            raise ConfigValidationError(
                "resource_thresholds configuration must be an object mapping "
                "threshold names to integers"
            )

        required_thresholds = ["cpu_threshold", "memory_threshold", "gpu_threshold"]
        for threshold in required_thresholds:
            if threshold not in thresholds_config:
                raise ConfigValidationError(f"Missing threshold configuration: {threshold}")

            value = thresholds_config[threshold]
            if not self._is_int(value) or (value != -1 and not 0 <= value <= 100):
                raise ConfigValidationError(
                    f"{threshold} must be an integer between 0 and 100 (or -1 to disable), got {value}"
                )

        # Validate optional thresholds if present
        for opt_threshold in [
            "pending_pods_threshold",
            "pending_requested_cpu_vcpus",
            "pending_requested_memory_gb",
            "pending_requested_gpus",
        ]:
            if opt_threshold in thresholds_config:
                value = thresholds_config[opt_threshold]
                if not self._is_int(value) or (value != -1 and value < 0):
                    raise ConfigValidationError(
                        f"{opt_threshold} must be a non-negative integer (or -1 to disable), got {value}"
                    )

    def _validate_global_accelerator_config(self) -> None:
        """Validate Global Accelerator configuration.

        Raises:
            ConfigValidationError: If the block is missing, malformed, lacks a
                required field, or has an invalid value.
        """
        ga_config = self.app.node.try_get_context("global_accelerator")
        if not ga_config:
            raise ConfigValidationError("global_accelerator configuration is required")
        if not isinstance(ga_config, dict):
            raise ConfigValidationError("global_accelerator configuration must be an object")

        required_fields = [
            "name",
            "health_check_grace_period",
            "health_check_interval",
            "health_check_timeout",
            "health_check_path",
        ]
        for field in required_fields:
            if field not in ga_config:
                raise ConfigValidationError(f"Missing global_accelerator configuration: {field}")

        # Validate timing values
        for field in ["health_check_grace_period", "health_check_interval", "health_check_timeout"]:
            value = ga_config[field]
            if not self._is_int(value) or value <= 0:
                raise ConfigValidationError(f"{field} must be a positive integer, got {value}")

        # Validate health check path (a non-string would otherwise raise
        # AttributeError on .startswith)
        health_check_path = ga_config["health_check_path"]
        if not isinstance(health_check_path, str) or not health_check_path.startswith("/"):
            raise ConfigValidationError("health_check_path must start with '/'")

        # Validate optional client affinity (case-insensitive). Omitting the
        # key is allowed; get_global_accelerator_config() fills in "NONE" and
        # upper-cases whatever value is given.
        if "client_affinity" in ga_config:
            allowed_affinity = {"NONE", "SOURCE_IP"}
            value = ga_config["client_affinity"]
            if not isinstance(value, str) or value.upper() not in allowed_affinity:
                raise ConfigValidationError(
                    f"client_affinity must be one of {sorted(allowed_affinity)}, got {value!r}"
                )

    def _validate_alb_config(self) -> None:
        """Validate ALB configuration (all health check fields are required
        positive integers).

        Raises:
            ConfigValidationError: If the block or a field is missing/invalid.
        """
        alb_config = self.app.node.try_get_context("alb_config")
        if not alb_config:
            raise ConfigValidationError("alb_config configuration is required")
        if not isinstance(alb_config, dict):
            raise ConfigValidationError("alb_config configuration must be an object")

        required_fields = [
            "health_check_interval",
            "health_check_timeout",
            "healthy_threshold",
            "unhealthy_threshold",
        ]
        for field in required_fields:
            if field not in alb_config:
                raise ConfigValidationError(f"Missing alb_config configuration: {field}")

            value = alb_config[field]
            if not self._is_int(value) or value <= 0:
                raise ConfigValidationError(f"{field} must be a positive integer, got {value}")

    def _validate_manifest_processor_config(self) -> None:
        """Validate manifest processor configuration.

        The manifest processor section in cdk.json holds service-specific
        settings only. The shared validation policy (allowed_namespaces,
        resource_quotas, trusted_registries, trusted_dockerhub_orgs,
        manifest_security_policy, allowed_kinds) lives under
        ``job_validation_policy`` because the queue_processor reads the
        same values.

        Raises:
            ConfigValidationError: If either section is missing, malformed, or
                lacks a required field.
        """
        mp_config = self.app.node.try_get_context("manifest_processor")
        if not mp_config:
            raise ConfigValidationError("manifest_processor configuration is required")
        if not isinstance(mp_config, dict):
            raise ConfigValidationError("manifest_processor configuration must be an object")

        required_fields = [
            "image",
            "replicas",
            "resource_limits",
        ]
        for field in required_fields:
            if field not in mp_config:
                raise ConfigValidationError(f"Missing manifest_processor configuration: {field}")

        # Validate replicas
        if not self._is_int(mp_config["replicas"]) or mp_config["replicas"] <= 0:
            raise ConfigValidationError("manifest_processor replicas must be a positive integer")

        # Validate the shared policy section separately so a misconfigured
        # policy block surfaces a clear error pointing at the right key.
        policy = self.app.node.try_get_context("job_validation_policy")
        if policy is None:
            raise ConfigValidationError(
                "job_validation_policy configuration is required (shared between "
                "manifest_processor and queue_processor)"
            )
        if not isinstance(policy, dict):
            raise ConfigValidationError("job_validation_policy configuration must be an object")
        for policy_field in ("allowed_namespaces", "resource_quotas"):
            if policy_field not in policy:
                raise ConfigValidationError(
                    f"Missing job_validation_policy configuration: {policy_field}"
                )

        # Validate resource limits. Require an object: on a string, "cpu" in
        # resource_limits would be a substring test and could pass.
        resource_limits = mp_config["resource_limits"]
        if (
            not isinstance(resource_limits, dict)
            or "cpu" not in resource_limits
            or "memory" not in resource_limits
        ):
            raise ConfigValidationError(
                "manifest_processor resource_limits must contain 'cpu' and 'memory'"
            )

        # Validate allowed namespaces (lives under job_validation_policy).
        if not isinstance(policy["allowed_namespaces"], list):
            raise ConfigValidationError("job_validation_policy.allowed_namespaces must be a list")

    def _validate_api_gateway_config(self) -> None:
        """Validate API Gateway configuration.

        Raises:
            ConfigValidationError: If the block or a field is missing/invalid,
                or the burst limit is below the rate limit.
        """
        api_gw_config = self.app.node.try_get_context("api_gateway")
        if not api_gw_config:
            raise ConfigValidationError("api_gateway configuration is required")
        if not isinstance(api_gw_config, dict):
            raise ConfigValidationError("api_gateway configuration must be an object")

        required_fields = [
            "throttle_rate_limit",
            "throttle_burst_limit",
            "log_level",
            "metrics_enabled",
            "tracing_enabled",
        ]
        for field in required_fields:
            if field not in api_gw_config:
                raise ConfigValidationError(f"Missing api_gateway configuration: {field}")

        # Validate throttle limits
        throttle_rate = api_gw_config["throttle_rate_limit"]
        throttle_burst = api_gw_config["throttle_burst_limit"]

        if not self._is_int(throttle_rate) or throttle_rate <= 0:
            raise ConfigValidationError(
                f"throttle_rate_limit must be a positive integer, got {throttle_rate}"
            )

        if not self._is_int(throttle_burst) or throttle_burst <= 0:
            raise ConfigValidationError(
                f"throttle_burst_limit must be a positive integer, got {throttle_burst}"
            )

        if throttle_burst < throttle_rate:
            raise ConfigValidationError(
                "throttle_burst_limit should be greater than or equal to throttle_rate_limit"
            )

        # Validate log level
        valid_log_levels = ["OFF", "ERROR", "INFO"]
        log_level = api_gw_config["log_level"]
        if log_level not in valid_log_levels:
            raise ConfigValidationError(
                f"log_level must be one of {valid_log_levels}, got {log_level}"
            )

        # Validate boolean flags
        if not isinstance(api_gw_config["metrics_enabled"], bool):
            raise ConfigValidationError("metrics_enabled must be a boolean")

        if not isinstance(api_gw_config["tracing_enabled"], bool):
            raise ConfigValidationError("tracing_enabled must be a boolean")

    def _validate_eks_cluster_config(self) -> None:
        """Validate the optional EKS cluster configuration.

        Raises:
            ConfigValidationError: If the block is not an object or
                ``endpoint_access`` is not a supported mode.
        """
        eks_config = self.app.node.try_get_context("eks_cluster") or {}
        # get_eks_cluster_config() ignores non-object values; reject them here
        # so a malformed block is reported instead of silently defaulted.
        if not isinstance(eks_config, dict):
            raise ConfigValidationError("eks_cluster configuration must be an object")

        # Validate endpoint_access if present
        if "endpoint_access" in eks_config:
            valid_access_modes = ["PRIVATE", "PUBLIC_AND_PRIVATE"]
            if eks_config["endpoint_access"] not in valid_access_modes:
                raise ConfigValidationError(
                    f"endpoint_access must be one of {valid_access_modes}, "
                    f"got {eks_config['endpoint_access']}"
                )

    def _validate_analytics_environment_config(self) -> None:
        """Validate the optional analytics_environment block in cdk.json.

        The block is entirely optional; absence means the feature is disabled
        and no validation is needed. When present, we validate:

        - ``enabled``: must be a bool if present (defaults to False via merge).
        - ``hyperpod.enabled``: must be a bool if present (defaults to False).
        - ``canvas.enabled``: must be a bool if present (defaults to False).
        - ``cognito.removal_policy`` and ``efs.removal_policy``: must be the
          literal strings ``"destroy"`` or ``"retain"`` (case sensitive — they
          are passed verbatim to CDK's ``RemovalPolicy`` lookup by the
          consumer).

        Raises:
            ConfigValidationError: If any of the checks above fails.
        """
        analytics_ctx = self.app.node.try_get_context("analytics_environment")
        if not isinstance(analytics_ctx, dict):
            # Block is absent or malformed — defaults apply, nothing to validate.
            return

        # Top-level `enabled` must be a bool if provided.
        if "enabled" in analytics_ctx and not isinstance(analytics_ctx["enabled"], bool):
            raise ConfigValidationError(
                f"analytics_environment.enabled must be a bool, got "
                f"{type(analytics_ctx['enabled']).__name__}: {analytics_ctx['enabled']!r}"
            )

        # `hyperpod.enabled` must be a bool if the sub-block is a dict and
        # carries the key.
        hyperpod_ctx = analytics_ctx.get("hyperpod")
        if (
            isinstance(hyperpod_ctx, dict)
            and "enabled" in hyperpod_ctx
            and not isinstance(hyperpod_ctx["enabled"], bool)
        ):
            raise ConfigValidationError(
                f"analytics_environment.hyperpod.enabled must be a bool, got "
                f"{type(hyperpod_ctx['enabled']).__name__}: {hyperpod_ctx['enabled']!r}"
            )

        # `canvas.enabled` must be a bool if the sub-block is a dict and
        # carries the key. Mirrors the hyperpod validation above.
        canvas_ctx = analytics_ctx.get("canvas")
        if (
            isinstance(canvas_ctx, dict)
            and "enabled" in canvas_ctx
            and not isinstance(canvas_ctx["enabled"], bool)
        ):
            raise ConfigValidationError(
                f"analytics_environment.canvas.enabled must be a bool, got "
                f"{type(canvas_ctx['enabled']).__name__}: {canvas_ctx['enabled']!r}"
            )

        valid_removal_policies = {"destroy", "retain"}

        for sub_block in ("cognito", "efs"):
            sub_ctx = analytics_ctx.get(sub_block)
            if not isinstance(sub_ctx, dict):
                continue
            if "removal_policy" not in sub_ctx:
                continue
            removal_policy = sub_ctx["removal_policy"]
            # isinstance first: an unhashable value would raise TypeError on
            # the set membership test.
            if not isinstance(removal_policy, str) or removal_policy not in valid_removal_policies:
                raise ConfigValidationError(
                    f"analytics_environment.{sub_block}.removal_policy must be one of "
                    f"{sorted(valid_removal_policies)}, got {removal_policy!r}"
                )

    def get_project_name(self) -> str:
        """Get project name from configuration (default: "gco")."""
        return self.app.node.try_get_context("project_name") or "gco"

    def get_deployment_regions(self) -> dict[str, Any]:
        """Get deployment regions configuration.

        Returns a dict with:
        - global: Region for Global Accelerator and SSM parameters (default: us-east-2)
        - api_gateway: Region for API Gateway stack (default: us-east-2)
        - monitoring: Region for Monitoring stack (default: us-east-2)
        - regional: List of regions for EKS clusters (default: ["us-east-1"])

        A context value that is not an object is ignored (defaults apply);
        validation rejects such values when running under CDK.

        Note: Global Accelerator is a global service but requires a "home" region
        for CloudFormation deployment. us-east-2 is used by default to keep
        global infrastructure separate from workload regions.
        """
        regions_ctx = self.app.node.try_get_context("deployment_regions")
        deployment_regions: dict[str, Any] = regions_ctx if isinstance(regions_ctx, dict) else {}

        return {
            "global": deployment_regions.get("global", "us-east-2"),
            "api_gateway": deployment_regions.get("api_gateway", "us-east-2"),
            "monitoring": deployment_regions.get("monitoring", "us-east-2"),
            "regional": deployment_regions.get("regional", ["us-east-1"]),
        }

    def get_global_region(self) -> str:
        """Get the region for global resources (Global Accelerator, SSM params)."""
        region = self.get_deployment_regions()["global"]
        return str(region)

    def get_api_gateway_region(self) -> str:
        """Get the region for API Gateway stack."""
        region = self.get_deployment_regions()["api_gateway"]
        return str(region)

    def get_monitoring_region(self) -> str:
        """Get the region for Monitoring stack."""
        region = self.get_deployment_regions()["monitoring"]
        return str(region)

    def get_regions(self) -> list[str]:
        """Get list of regions for EKS cluster deployment.

        A scalar ``regional`` value is wrapped into a one-element list.
        """
        deployment_regions = self.get_deployment_regions()
        regional = deployment_regions["regional"]
        return list(regional) if isinstance(regional, list) else [str(regional)]

    def get_kubernetes_version(self) -> str:
        """Get Kubernetes version from configuration (default: "1.36")."""
        return self.app.node.try_get_context("kubernetes_version") or "1.36"

    def get_resource_thresholds(self) -> ResourceThresholds:
        """Get resource thresholds configuration.

        Context values are layered over the defaults below, so a partial
        ``resource_thresholds`` block keeps the defaults for omitted keys and
        an absent (or non-object) block yields exactly these defaults.
        """
        default_thresholds: dict[str, int] = {
            "cpu_threshold": 60,
            "memory_threshold": 60,
            "gpu_threshold": -1,
            "pending_pods_threshold": 10,
            "pending_requested_cpu_vcpus": 100,
            "pending_requested_memory_gb": 200,
            # Same default whether the block is absent or merely omits the key.
            "pending_requested_gpus": 8,
        }
        thresholds_ctx = self.app.node.try_get_context("resource_thresholds")
        overrides: dict[str, Any] = thresholds_ctx if isinstance(thresholds_ctx, dict) else {}
        thresholds_config: dict[str, Any] = {**default_thresholds, **overrides}
        return ResourceThresholds(
            cpu_threshold=thresholds_config["cpu_threshold"],
            memory_threshold=thresholds_config["memory_threshold"],
            gpu_threshold=thresholds_config["gpu_threshold"],
            pending_pods_threshold=thresholds_config["pending_pods_threshold"],
            pending_requested_cpu_vcpus=thresholds_config["pending_requested_cpu_vcpus"],
            pending_requested_memory_gb=thresholds_config["pending_requested_memory_gb"],
            pending_requested_gpus=thresholds_config["pending_requested_gpus"],
        )

    def get_cluster_config(self, region: str) -> ClusterConfig:
        """Get complete cluster configuration for a region.

        The cluster name is ``"<project_name>-<region>"``.
        """
        return ClusterConfig(
            region=region,
            cluster_name=f"{self.get_project_name()}-{region}",
            kubernetes_version=self.get_kubernetes_version(),
            addons=["metrics-server"],
            resource_thresholds=self.get_resource_thresholds(),
        )

    def get_global_accelerator_config(self) -> dict[str, Any]:
        """Get Global Accelerator configuration.

        Context values are layered over the defaults below, so omitting an
        optional key such as ``client_affinity`` still yields its default
        ("NONE"). A string ``client_affinity`` is upper-cased to match the
        case-insensitive check in ``_validate_global_accelerator_config``.
        """
        default_config: dict[str, Any] = {
            "name": "gco-accelerator",
            "health_check_grace_period": 30,
            "health_check_interval": 30,
            "health_check_timeout": 5,
            "health_check_path": "/api/v1/health",
            "client_affinity": "NONE",
        }
        ga_ctx = self.app.node.try_get_context("global_accelerator")
        ga_config: dict[str, Any] = ga_ctx if isinstance(ga_ctx, dict) else {}
        merged_config: dict[str, Any] = {**default_config, **ga_config}
        affinity = merged_config["client_affinity"]
        if isinstance(affinity, str):
            merged_config["client_affinity"] = affinity.upper()
        return merged_config

    def get_alb_config(self) -> dict[str, Any]:
        """Get ALB configuration (context values layered over defaults)."""
        default_config: dict[str, Any] = {
            "health_check_interval": 30,
            "health_check_timeout": 5,
            "healthy_threshold": 2,
            "unhealthy_threshold": 2,
        }
        alb_ctx = self.app.node.try_get_context("alb_config")
        alb_config: dict[str, Any] = alb_ctx if isinstance(alb_ctx, dict) else {}
        return {**default_config, **alb_config}

    def get_manifest_processor_config(self) -> dict[str, Any]:
        """Get manifest processor configuration.

        Merges two cdk.json sections over the built-in defaults into a single
        runtime config (later sources win: defaults < manifest_processor <
        job_validation_policy):

        - ``manifest_processor``: service-specific settings (e.g. image,
          replicas, resource_limits, validation_enabled,
          max_request_body_bytes, yaml_max_depth)
        - ``job_validation_policy``: shared validation policy
          (allowed_namespaces, resource_quotas, trusted_registries,
          trusted_dockerhub_orgs, manifest_security_policy, allowed_kinds).
          Pulled in verbatim so the REST path reads the same policy the SQS
          queue processor enforces.

        Non-object values for either section are ignored.

        Note: The 'image' field is a placeholder default. In practice, the actual
        image is built from dockerfiles/manifest-processor-dockerfile and pushed
        to ECR during CDK deployment. The {{MANIFEST_PROCESSOR_IMAGE}} placeholder
        in manifests is replaced with the ECR image URI.
        """
        default_config = {
            "image": "gco/manifest-processor:latest",  # Placeholder, replaced by ECR image
            "replicas": 3,
            "resource_limits": {"cpu": "1000m", "memory": "2Gi"},
            "validation_enabled": True,
            # allowed_namespaces, resource_quotas, trusted_registries,
            # trusted_dockerhub_orgs, manifest_security_policy, and
            # allowed_kinds are merged in below from job_validation_policy.
            "allowed_namespaces": ["default", "gco-jobs"],
            "resource_quotas": {
                "max_cpu_per_manifest": "10",
                "max_memory_per_manifest": "32Gi",
                "max_gpu_per_manifest": 4,
            },
            "trusted_registries": [
                "docker.io",
                "gcr.io",
                "quay.io",
                "registry.k8s.io",
                "k8s.gcr.io",
                "public.ecr.aws",
                "nvcr.io",
                "gco",
            ],
            "trusted_dockerhub_orgs": [
                "nvidia",
                "pytorch",
                "rayproject",
                "tensorflow",
                "huggingface",
                "amazon",
                "bitnami",
            ],
        }
        mp_ctx = self.app.node.try_get_context("manifest_processor")
        context_config: dict[str, Any] = mp_ctx if isinstance(mp_ctx, dict) else {}

        # Merge in the shared job_validation_policy section. These keys apply
        # to BOTH the manifest processor and the queue processor; they live
        # in their own top-level cdk.json section so neither service "owns"
        # them. We flatten them into the manifest processor's runtime config
        # so service code keeps its existing attribute layout.
        policy_ctx = self.app.node.try_get_context("job_validation_policy")
        shared_policy: dict[str, Any] = policy_ctx if isinstance(policy_ctx, dict) else {}
        return {**default_config, **context_config, **shared_policy}

    def get_api_gateway_config(self) -> dict[str, Any]:
        """Get API Gateway configuration.

        Returns:
            API Gateway configuration dictionary with the following keys:
            - throttle_rate_limit: Requests per second limit
            - throttle_burst_limit: Burst capacity
            - log_level: CloudWatch logging level (OFF, ERROR, INFO)
            - metrics_enabled: Enable CloudWatch metrics
            - tracing_enabled: Enable X-Ray tracing
            - regional_api_enabled: Enable regional API Gateways for private access
              When true, deploys a regional API Gateway with VPC Lambda in each
              region, allowing API access when the ALB is internal-only.
        """
        default_config = {
            "throttle_rate_limit": 1000,
            "throttle_burst_limit": 2000,
            "log_level": "INFO",
            "metrics_enabled": True,
            "tracing_enabled": True,
            "regional_api_enabled": False,
        }
        api_gw_ctx = self.app.node.try_get_context("api_gateway")
        api_gw_config: dict[str, Any] = api_gw_ctx if isinstance(api_gw_ctx, dict) else {}
        return {**default_config, **api_gw_config}

    def get_eks_cluster_config(self) -> dict[str, Any]:
        """Get EKS cluster configuration.

        Returns:
            EKS cluster configuration dictionary with the following keys:
            - endpoint_access: EKS API endpoint access mode
              - "PRIVATE": API server only accessible from within VPC (default, most secure)
              - "PUBLIC_AND_PRIVATE": API server accessible from internet and VPC

        Note:
            PRIVATE endpoint is recommended for production. Job submission still works
            via API Gateway → Lambda (in VPC) or SQS queues. For kubectl access with
            PRIVATE endpoint, use a bastion host, VPN, or AWS SSM Session Manager.
        """
        default_config = {
            "endpoint_access": "PRIVATE",
        }
        eks_ctx = self.app.node.try_get_context("eks_cluster")
        eks_config: dict[str, Any] = eks_ctx if isinstance(eks_ctx, dict) else {}
        return {**default_config, **eks_config}

    def get_fsx_lustre_config(self, region: str | None = None) -> dict[str, Any]:
        """Get FSx for Lustre configuration.

        Layers, lowest to highest precedence: built-in defaults, the global
        ``fsx_lustre`` block, then ``fsx_lustre_regions[region]``. At each
        layer ``node_group`` is merged key by key, so a partial node_group
        override keeps the keys it does not mention.

        Args:
            region: Optional region to get config for. If provided, checks for
                region-specific overrides first.

        Returns:
            FSx configuration dictionary with the following keys:
            - enabled: Whether FSx is enabled
            - storage_capacity_gib: Storage capacity in GiB (min 1200)
            - deployment_type: SCRATCH_1, SCRATCH_2, PERSISTENT_1, PERSISTENT_2
            - file_system_type_version: Lustre version (2.12 or 2.15, default: 2.15)
              IMPORTANT: Use 2.15 for kernel 6.x compatibility (AL2023, Bottlerocket)
            - per_unit_storage_throughput: Throughput for PERSISTENT types
            - data_compression_type: LZ4 or NONE
            - import_path: S3 path for data import
            - export_path: S3 path for data export
            - auto_import_policy: NEW, NEW_CHANGED, NEW_CHANGED_DELETED
            - node_group: Node group configuration for FSx workloads
              - instance_types: List of instance types
              - min_size: Minimum nodes (default: 0)
              - max_size: Maximum nodes (default: 10)
              - desired_size: Desired nodes (default: 1)
              - ami_type: AMI type - one of:
                  AL2023_X86_64_STANDARD (default), AL2023_ARM_64_STANDARD,
                  AL2023_X86_64_NVIDIA, AL2023_ARM_64_NVIDIA, AL2023_X86_64_NEURON
              - capacity_type: ON_DEMAND (default) or SPOT
              - disk_size: Root disk size in GB (default: 100)
              - labels: Additional node labels (dict)
        """
        default_config = {
            "enabled": False,
            "storage_capacity_gib": 1200,
            "deployment_type": "SCRATCH_2",
            "file_system_type_version": "2.15",  # Use 2.15 for kernel 6.x compatibility
            "per_unit_storage_throughput": 200,
            "data_compression_type": "LZ4",
            "import_path": None,
            "export_path": None,
            "auto_import_policy": "NEW_CHANGED_DELETED",
            "node_group": {
                "instance_types": ["m5.large", "m5.xlarge", "m6i.large", "m6i.xlarge"],
                "min_size": 0,
                "max_size": 10,
                "desired_size": 1,
                "ami_type": "AL2023_X86_64_STANDARD",
                "capacity_type": "ON_DEMAND",
                "disk_size": 100,
                "labels": {},
            },
        }
        default_node_group = cast(dict[str, Any], default_config["node_group"])

        # Get global FSx config
        global_ctx = self.app.node.try_get_context("fsx_lustre")
        global_config: dict[str, Any] = global_ctx if isinstance(global_ctx, dict) else {}
        merged_config: dict[str, Any] = {**default_config, **global_config}

        # Ensure node_group has all required fields with defaults
        global_node_group = global_config.get("node_group")
        if isinstance(global_node_group, dict):
            merged_config["node_group"] = {**default_node_group, **global_node_group}

        # Check for region-specific override
        if not region:
            return merged_config
        region_overrides_ctx = self.app.node.try_get_context("fsx_lustre_regions")
        region_overrides: dict[str, Any] = (
            region_overrides_ctx if isinstance(region_overrides_ctx, dict) else {}
        )
        region_config = region_overrides.get(region)
        if not isinstance(region_config, dict):
            return merged_config

        # Capture the node_group built so far BEFORE the shallow merge: the
        # merge replaces it with the region's (possibly partial) node_group,
        # and merging that with itself would drop every default/global key.
        base_node_group = merged_config.get("node_group")
        merged_config = {**merged_config, **region_config}

        # Handle nested node_group override
        region_node_group = region_config.get("node_group")
        if isinstance(region_node_group, dict):
            if not isinstance(base_node_group, dict):
                base_node_group = default_node_group
            merged_config["node_group"] = {**base_node_group, **region_node_group}

        return merged_config

    def get_valkey_config(self) -> dict[str, Any]:
        """Get Valkey Serverless cache configuration.

        Returns:
            Valkey configuration dictionary with the following keys:
            - enabled: Whether Valkey cache is enabled (default: True)
            - max_data_storage_gb: Maximum data storage in GB (default: 5)
            - max_ecpu_per_second: Maximum ECPUs per second (default: 5000)
            - snapshot_retention_limit: Daily snapshots to retain (default: 1)
        """
        default_config: dict[str, Any] = {
            "enabled": True,
            "max_data_storage_gb": 5,
            "max_ecpu_per_second": 5000,
            "snapshot_retention_limit": 1,
        }
        valkey_ctx = self.app.node.try_get_context("valkey")
        valkey_config: dict[str, Any] = valkey_ctx if isinstance(valkey_ctx, dict) else {}
        return {**default_config, **valkey_config}

    def get_aurora_pgvector_config(self) -> dict[str, Any]:
        """Get Aurora Serverless v2 + pgvector vector database configuration.

        Returns:
            Aurora pgvector configuration dictionary with the following keys:
            - enabled: Whether Aurora pgvector is enabled (default: False)
            - min_acu: Minimum Aurora Capacity Units (default: 0, scales to zero)
            - max_acu: Maximum Aurora Capacity Units (default: 16)
            - backup_retention_days: Number of days to retain automated backups (default: 7)
            - deletion_protection: Whether deletion protection is enabled (default: False)
        """
        default_config: dict[str, Any] = {
            "enabled": False,
            "min_acu": 0,
            "max_acu": 16,
            "backup_retention_days": 7,
            "deletion_protection": False,
        }
        aurora_ctx = self.app.node.try_get_context("aurora_pgvector")
        aurora_config: dict[str, Any] = aurora_ctx if isinstance(aurora_ctx, dict) else {}
        return {**default_config, **aurora_config}

    def get_analytics_config(self) -> dict[str, Any]:
        """Get optional analytics environment configuration.

        Returns the fully-merged analytics_environment block from cdk.json
        layered on top of the defaults below. Sub-blocks (``hyperpod``,
        ``canvas``, ``cognito``, ``efs``, ``studio``) are deep-merged so a user
        who overrides a single nested key (e.g. ``cognito.domain_prefix``) does
        not inadvertently wipe the sub-block's other defaults — mirroring the
        nested-merge pattern used by ``get_fsx_lustre_config`` for its
        ``node_group`` sub-block.

        Returns:
            Analytics configuration dictionary with the following keys:
            - enabled: Whether the analytics environment stack is deployed
              (default: False — the feature is off unless explicitly opted in)
            - hyperpod: SageMaker HyperPod integration sub-block
              - enabled: Whether to add the HyperPod IAM grants to
                SageMaker_Execution_Role (default: False)
            - canvas: SageMaker Canvas integration sub-block
              - enabled: Whether to enable the SageMaker Canvas app on
                the Studio domain and attach ``AmazonSageMakerCanvasFullAccess``
                to the SageMaker_Execution_Role (default: False)
            - cognito: Cognito user-pool sub-block
              - domain_prefix: UserPoolDomain prefix, or None to let the
                analytics stack derive one (default: None)
              - removal_policy: "destroy" (default) or "retain" — controls
                the Cognito pool's CloudFormation DeletionPolicy
            - efs: Studio_EFS sub-block
              - removal_policy: "destroy" (default) or "retain" — controls
                the Studio EFS file system's CloudFormation DeletionPolicy
            - studio: SageMaker Studio sub-block
              - user_profile_name_prefix: Optional prefix for per-user
                profile names, or None to use the Cognito username verbatim
                (default: None)
        """
        default_config: dict[str, Any] = {
            "enabled": False,
            "hyperpod": {"enabled": False},
            "canvas": {"enabled": False},
            "cognito": {"domain_prefix": None, "removal_policy": "destroy"},
            "efs": {"removal_policy": "destroy"},
            "studio": {"user_profile_name_prefix": None},
        }
        analytics_ctx = self.app.node.try_get_context("analytics_environment")
        analytics_config: dict[str, Any] = analytics_ctx if isinstance(analytics_ctx, dict) else {}
        merged_config: dict[str, Any] = {**default_config, **analytics_config}

        # Deep-merge each nested sub-block so a partial override does not
        # drop the other defaults in the same sub-block.
        for sub_block in ("hyperpod", "canvas", "cognito", "efs", "studio"):
            override = analytics_config.get(sub_block)
            if isinstance(override, dict):
                default_sub = cast(dict[str, Any], default_config[sub_block])
                merged_config[sub_block] = {**default_sub, **override}

        return merged_config

    def get_analytics_enabled(self) -> bool:
        """Return whether the analytics environment stack is enabled.

        Thin wrapper around ``get_analytics_config()["enabled"]`` to mirror
        the existing ``get_valkey_config`` / ``get_aurora_pgvector_config``
        access pattern without forcing every call site to index into the
        merged dict.
        """
        return bool(self.get_analytics_config()["enabled"])

    def get_tags(self) -> dict[str, str]:
        """Get common tags from configuration (empty dict when unset)."""
        return self.app.node.try_get_context("tags") or {}

    def validate_region_availability(self, region: str) -> bool:
        """Validate that a region is available in the current AWS account.

        Best-effort: any error from the EC2 DescribeRegions call (missing
        credentials, unknown region, network failure) is logged at debug
        level and reported as False.
        """
        try:
            ec2 = boto3.client("ec2", region_name=region)
            ec2.describe_regions(RegionNames=[region])
            return True
        except Exception as e:
            logger.debug("Region %s not available: %s", region, e)
            return False

    def get_available_regions(self) -> list[str]:
        """Get list of available AWS regions for the current account.

        Falls back to the built-in VALID_REGIONS (sorted, so the result is
        deterministic) when the EC2 DescribeRegions call fails.
        """
        try:
            ec2 = boto3.client("ec2")
            response = ec2.describe_regions()
            return [region["RegionName"] for region in response["Regions"]]
        except Exception as e:
            logger.debug("Failed to list regions, using defaults: %s", e)
            return sorted(self.VALID_REGIONS)