Coverage for cli / models.py: 100%
79 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
1"""
2Model weight management for GCO CLI.
4Provides functionality to upload, list, and manage model weights
5in the central S3 model bucket. Models uploaded here are automatically
6available to inference endpoints across all regions via init container sync.
7"""
9from __future__ import annotations
11import logging
12import os
13from pathlib import Path
14from typing import Any
16import boto3
18from .config import GCOConfig, get_config
20logger = logging.getLogger(__name__)
class ModelManager:
    """Manages model weights in the central S3 bucket.

    The bucket name is discovered once from SSM Parameter Store (parameter
    ``/<project>/model-bucket-name`` in the global region) and cached on the
    instance for the lifetime of the manager.
    """

    def __init__(self, config: GCOConfig | None = None):
        self.config = config or get_config()
        # Lazily resolved from SSM on first use; None until then.
        self._bucket_name: str | None = None

    def _get_bucket_name(self) -> str:
        """Discover the model bucket name from SSM (cached after first lookup).

        Raises:
            RuntimeError: If the SSM parameter does not exist, typically
                because the global stack has not been deployed yet.
        """
        if self._bucket_name:
            return self._bucket_name

        ssm = boto3.client("ssm", region_name=self.config.global_region)
        try:
            response = ssm.get_parameter(Name=f"/{self.config.project_name}/model-bucket-name")
            self._bucket_name = response["Parameter"]["Value"]
            return self._bucket_name
        except Exception as e:
            raise RuntimeError(
                "Model bucket not found. Deploy the global stack first "
                "with 'gco stacks deploy gco-global'."
            ) from e

    def _get_s3_client(self) -> Any:
        """Get an S3 client pinned to the global region (where the bucket lives)."""
        return boto3.client("s3", region_name=self.config.global_region)

    def upload(
        self,
        local_path: str,
        model_name: str,
        prefix: str = "models",
    ) -> dict[str, Any]:
        """
        Upload model weights to S3.

        A single file is uploaded as ``<prefix>/<model_name>/<filename>``;
        a directory is walked recursively and its relative structure is
        mirrored under ``<prefix>/<model_name>/``.

        Args:
            local_path: Local file or directory path
            model_name: Name for the model in the bucket
            prefix: S3 prefix (default: "models")

        Returns:
            Upload result with S3 URI and file count

        Raises:
            FileNotFoundError: If ``local_path`` is neither a file nor a directory.
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        s3_prefix = f"{prefix}/{model_name}"

        local = Path(local_path)
        uploaded = 0

        if local.is_file():
            key = f"{s3_prefix}/{local.name}"
            s3.upload_file(str(local), bucket, key)
            uploaded = 1
        elif local.is_dir():
            for root, _dirs, files in os.walk(local):
                for fname in files:
                    file_path = Path(root) / fname
                    # Preserve the directory layout relative to the upload root.
                    relative = file_path.relative_to(local)
                    key = f"{s3_prefix}/{relative}"
                    s3.upload_file(str(file_path), bucket, key)
                    uploaded += 1
        else:
            raise FileNotFoundError(f"Path not found: {local_path}")

        s3_uri = f"s3://{bucket}/{s3_prefix}"
        return {
            "model_name": model_name,
            "s3_uri": s3_uri,
            "bucket": bucket,
            "prefix": s3_prefix,
            "files_uploaded": uploaded,
        }

    def list_models(self, prefix: str = "models") -> list[dict[str, Any]]:
        """List all models in the bucket with file counts and total sizes."""
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        paginator = s3.get_paginator("list_objects_v2")

        models = []
        # Paginate the delimiter listing: a single list_objects_v2 call
        # returns at most 1000 CommonPrefixes, which would silently
        # truncate the model list for large buckets.
        for listing in paginator.paginate(
            Bucket=bucket,
            Prefix=f"{prefix}/",
            Delimiter="/",
        ):
            for cp in listing.get("CommonPrefixes", []):
                model_prefix = cp["Prefix"]
                model_name = model_prefix.rstrip("/").split("/")[-1]

                # Walk every object under the model prefix to get
                # total size and file count.
                total_size = 0
                file_count = 0
                for page in paginator.paginate(Bucket=bucket, Prefix=model_prefix):
                    for obj in page.get("Contents", []):
                        total_size += obj.get("Size", 0)
                        file_count += 1

                models.append(
                    {
                        "model_name": model_name,
                        "s3_uri": f"s3://{bucket}/{model_prefix.rstrip('/')}",
                        "files": file_count,
                        "total_size_gb": round(total_size / (1024**3), 2),
                    }
                )

        return models

    def get_model_uri(self, model_name: str, prefix: str = "models") -> str:
        """Get the S3 URI for a model (does not verify the model exists)."""
        bucket = self._get_bucket_name()
        return f"s3://{bucket}/{prefix}/{model_name}"

    def delete_model(self, model_name: str, prefix: str = "models") -> int:
        """Delete a model and all its files from S3.

        Returns:
            Number of objects deleted (0 if the model does not exist).
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        s3_prefix = f"{prefix}/{model_name}/"

        # Each paginator page yields at most 1000 keys, which matches the
        # delete_objects per-request limit, so one delete call per page.
        deleted = 0
        paginator = s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=s3_prefix):
            objects = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
            if objects:
                s3.delete_objects(Bucket=bucket, Delete={"Objects": objects})
                deleted += len(objects)

        return deleted
def get_model_manager(config: GCOConfig | None = None) -> ModelManager:
    """Build a :class:`ModelManager`, falling back to the default config when none is given."""
    return ModelManager(config=config)