Coverage for cli/models.py: 100%
78 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
1"""
2Model weight management for GCO CLI.
4Provides functionality to upload, list, and manage model weights
5in the central S3 model bucket. Models uploaded here are automatically
6available to inference endpoints across all regions via init container sync.
7"""
9from __future__ import annotations
11import logging
12import os
13from pathlib import Path
14from typing import Any
16import boto3
18from .config import GCOConfig, get_config
logger = logging.getLogger(__name__)


class ModelManager:
    """Manages model weights in the central S3 model bucket.

    The bucket name is discovered from SSM on first use and cached on the
    instance; S3 clients are created for the configured global region.
    """

    def __init__(self, config: GCOConfig | None = None):
        """Create a manager.

        Args:
            config: CLI configuration. When omitted (or falsy), the result of
                ``get_config()`` is used instead.
        """
        self.config = config or get_config()
        # Populated by _get_bucket_name() after the first successful lookup.
        self._bucket_name: str | None = None

    def _get_bucket_name(self) -> str:
        """Return the model bucket name, looking it up in SSM on first use.

        The SSM parameter read is ``/<project_name>/model-bucket-name`` in
        ``config.global_region``.

        Raises:
            RuntimeError: If the SSM lookup raises anything; the original
                exception is chained as ``__cause__``.
        """
        if self._bucket_name:
            return self._bucket_name

        # Function-scope import, presumably to keep CLI start-up light or to
        # avoid an import cycle -- TODO confirm.
        from gco.services.aws_ssm import get_ssm_parameter

        try:
            self._bucket_name = get_ssm_parameter(
                f"/{self.config.project_name}/model-bucket-name",
                region=self.config.global_region,
            )
        except Exception as e:
            # NOTE(review): every failure (credentials, network, ...) is
            # reported as "bucket not found"; the real error stays in
            # __cause__.
            raise RuntimeError(
                "Model bucket not found. Deploy the global stack first "
                "with 'gco stacks deploy gco-global'."
            ) from e
        return self._bucket_name

    def _get_s3_client(self) -> Any:
        """Create a boto3 S3 client bound to the configured global region."""
        return boto3.client("s3", region_name=self.config.global_region)

    def upload(
        self,
        local_path: str,
        model_name: str,
        prefix: str = "models",
    ) -> dict[str, Any]:
        """
        Upload model weights to S3.

        A single file is stored as ``<prefix>/<model_name>/<file name>``. A
        directory is walked recursively, and every file is stored under
        ``<prefix>/<model_name>/`` with its path relative to that directory.

        Args:
            local_path: Local file or directory path
            model_name: Name for the model in the bucket
            prefix: S3 prefix (default: "models")

        Returns:
            Dict with ``model_name``, ``s3_uri``, ``bucket``, ``prefix`` (the
            key prefix the files were written under) and ``files_uploaded``.

        Raises:
            FileNotFoundError: If ``local_path`` is neither a file nor a
                directory.
            RuntimeError: If the model bucket cannot be discovered.
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        s3_prefix = f"{prefix}/{model_name}"

        local = Path(local_path)
        uploaded = 0

        if local.is_file():
            s3.upload_file(str(local), bucket, f"{s3_prefix}/{local.name}")
            uploaded = 1
        elif local.is_dir():
            for root, _dirs, files in os.walk(local):
                for fname in files:
                    file_path = Path(root) / fname
                    # S3 keys are always "/"-separated; str() of a relative
                    # path would yield backslash-separated keys on Windows.
                    relative = file_path.relative_to(local).as_posix()
                    key = f"{s3_prefix}/{relative}"
                    s3.upload_file(str(file_path), bucket, key)
                    uploaded += 1
        else:
            raise FileNotFoundError(f"Path not found: {local_path}")

        return {
            "model_name": model_name,
            "s3_uri": f"s3://{bucket}/{s3_prefix}",
            "bucket": bucket,
            "prefix": s3_prefix,
            "files_uploaded": uploaded,
        }

    def list_models(self, prefix: str = "models") -> list[dict[str, Any]]:
        """List every model stored under ``<prefix>/`` in the bucket.

        A "model" is each first-level "directory" (S3 common prefix) below
        ``<prefix>/``. Objects sitting directly under ``<prefix>/`` are not
        reported.

        Args:
            prefix: S3 prefix the models live under (default: "models").

        Returns:
            One dict per model, in the order S3 returns the common prefixes,
            with ``model_name``, ``s3_uri``, ``files`` (object count) and
            ``total_size_gb`` (total size in units of 1024**3 bytes, rounded
            to 2 decimal places).
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        paginator = s3.get_paginator("list_objects_v2")

        models: list[dict[str, Any]] = []
        # Paginate the delimiter listing too: one ListObjectsV2 response holds
        # at most 1,000 keys + common prefixes, so a single call would
        # silently drop models beyond the first page.
        listing = paginator.paginate(
            Bucket=bucket,
            Prefix=f"{prefix}/",
            Delimiter="/",
        )
        for page in listing:
            for cp in page.get("CommonPrefixes", []):
                model_prefix = cp["Prefix"]
                file_count, total_size = self._prefix_usage(
                    paginator, bucket, model_prefix
                )
                models.append(
                    {
                        "model_name": model_prefix.rstrip("/").split("/")[-1],
                        "s3_uri": f"s3://{bucket}/{model_prefix.rstrip('/')}",
                        "files": file_count,
                        "total_size_gb": round(total_size / (1024**3), 2),
                    }
                )
        return models

    @staticmethod
    def _prefix_usage(
        paginator: Any, bucket: str, key_prefix: str
    ) -> tuple[int, int]:
        """Return ``(object_count, total_bytes)`` for every key under *key_prefix*."""
        file_count = 0
        total_size = 0
        for page in paginator.paginate(Bucket=bucket, Prefix=key_prefix):
            for obj in page.get("Contents", []):
                total_size += obj.get("Size", 0)
                file_count += 1
        return file_count, total_size

    def get_model_uri(self, model_name: str, prefix: str = "models") -> str:
        """Return ``s3://<bucket>/<prefix>/<model_name>`` for a model.

        Only the bucket name is resolved; whether the model exists is not
        checked.
        """
        bucket = self._get_bucket_name()
        return f"s3://{bucket}/{prefix}/{model_name}"

    def delete_model(self, model_name: str, prefix: str = "models") -> int:
        """Delete every object stored under ``<prefix>/<model_name>/``.

        Args:
            model_name: Name of the model to delete.
            prefix: S3 prefix the model lives under (default: "models").

        Returns:
            Number of objects actually deleted. Keys that S3 reports as failed
            are logged as warnings and not counted.
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        # Trailing "/" so deleting "llama" does not also match "llama-2/...".
        s3_prefix = f"{prefix}/{model_name}/"

        deleted = 0
        paginator = s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=s3_prefix):
            objects = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
            if not objects:
                continue
            # A listing page holds at most 1,000 keys by default, which matches
            # the DeleteObjects per-request limit.
            response = s3.delete_objects(
                Bucket=bucket, Delete={"Objects": objects}
            )
            # DeleteObjects does not raise for per-key failures; it reports
            # them under "Errors" in an otherwise successful response.
            errors = response.get("Errors", [])
            for err in errors:
                logger.warning(
                    "Failed to delete s3://%s/%s: %s %s",
                    bucket,
                    err.get("Key"),
                    err.get("Code"),
                    err.get("Message"),
                )
            deleted += len(objects) - len(errors)
        return deleted
def get_model_manager(config: GCOConfig | None = None) -> ModelManager:
    """Build a new :class:`ModelManager`.

    Args:
        config: Optional CLI configuration, handed unchanged to the
            ``ModelManager`` constructor.

    Returns:
        A freshly constructed ``ModelManager``.
    """
    manager = ModelManager(config)
    return manager