Coverage for cli/models.py: 100%

78 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 15:07 +0000

1""" 

2Model weight management for GCO CLI. 

3 

4Provides functionality to upload, list, and manage model weights 

5in the central S3 model bucket. Models uploaded here are automatically 

6available to inference endpoints across all regions via init container sync. 

7""" 

8 

9from __future__ import annotations 

10 

11import logging 

12import os 

13from pathlib import Path 

14from typing import Any 

15 

16import boto3 

17 

18from .config import GCOConfig, get_config 

19 

20logger = logging.getLogger(__name__) 

21 

22 

class ModelManager:
    """Manages model weights in the central S3 bucket.

    The bucket name is looked up from SSM on first use and cached on the
    instance. All S3 calls go through a client created for
    ``config.global_region``.
    """

    def __init__(self, config: GCOConfig | None = None):
        """Create a manager.

        Args:
            config: CLI configuration. When omitted (or falsy), the result
                of ``get_config()`` is used.
        """
        self.config = config or get_config()
        # Filled in by _get_bucket_name() after the first successful lookup.
        self._bucket_name: str | None = None

    def _get_bucket_name(self) -> str:
        """Discover the model bucket name from SSM.

        Returns:
            The bucket name; cached on the instance after the first lookup.

        Raises:
            RuntimeError: If the SSM lookup fails for any reason. The
                underlying exception is chained as the cause.
        """
        if self._bucket_name:
            return self._bucket_name

        # Deferred import: only reached when the name is not cached yet.
        from gco.services.aws_ssm import get_ssm_parameter

        try:
            self._bucket_name = get_ssm_parameter(
                f"/{self.config.project_name}/model-bucket-name",
                region=self.config.global_region,
            )
            return self._bucket_name
        except Exception as e:
            raise RuntimeError(
                "Model bucket not found. Deploy the global stack first "
                "with 'gco stacks deploy gco-global'."
            ) from e

    def _get_s3_client(self) -> Any:
        """Get S3 client for the global region."""
        return boto3.client("s3", region_name=self.config.global_region)

    def upload(
        self,
        local_path: str,
        model_name: str,
        prefix: str = "models",
    ) -> dict[str, Any]:
        """
        Upload model weights to S3.

        A single file is stored as ``<prefix>/<model_name>/<file name>``.
        A directory is walked recursively and every file is stored under
        ``<prefix>/<model_name>/`` at its path relative to the directory.

        Args:
            local_path: Local file or directory path
            model_name: Name for the model in the bucket
            prefix: S3 prefix (default: "models")

        Returns:
            Upload result with S3 URI and file count

        Raises:
            FileNotFoundError: If ``local_path`` is neither a file nor a
                directory.
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        s3_prefix = f"{prefix}/{model_name}"

        local = Path(local_path)
        uploaded = 0

        if local.is_file():
            s3.upload_file(str(local), bucket, f"{s3_prefix}/{local.name}")
            uploaded = 1
        elif local.is_dir():
            for root, _dirs, files in os.walk(local):
                for fname in files:
                    file_path = Path(root) / fname
                    # S3 keys always use "/" as the separator; as_posix()
                    # keeps nested keys correct on Windows, where str() of
                    # a relative path would contain backslashes.
                    relative = file_path.relative_to(local).as_posix()
                    s3.upload_file(str(file_path), bucket, f"{s3_prefix}/{relative}")
                    uploaded += 1
        else:
            raise FileNotFoundError(f"Path not found: {local_path}")

        return {
            "model_name": model_name,
            "s3_uri": f"s3://{bucket}/{s3_prefix}",
            "bucket": bucket,
            "prefix": s3_prefix,
            "files_uploaded": uploaded,
        }

    def list_models(self, prefix: str = "models") -> list[dict[str, Any]]:
        """List all models in the bucket.

        Args:
            prefix: S3 prefix under which models are stored
                (default: "models")

        Returns:
            One dict per top-level "directory" under ``<prefix>/`` with the
            keys ``model_name``, ``s3_uri``, ``files`` (object count) and
            ``total_size_gb`` (total size in GiB, rounded to 2 decimals).
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        paginator = s3.get_paginator("list_objects_v2")

        models: list[dict[str, Any]] = []

        # List top-level "directories" under the prefix. The listing is
        # paginated so that models beyond the first response page are not
        # silently dropped.
        top_level_pages = paginator.paginate(
            Bucket=bucket,
            Prefix=f"{prefix}/",
            Delimiter="/",
        )
        for listing in top_level_pages:
            for cp in listing.get("CommonPrefixes", []):
                model_prefix = cp["Prefix"]
                model_name = model_prefix.rstrip("/").split("/")[-1]

                # Get total size and file count
                total_size = 0
                file_count = 0
                for page in paginator.paginate(Bucket=bucket, Prefix=model_prefix):
                    for obj in page.get("Contents", []):
                        total_size += obj.get("Size", 0)
                        file_count += 1

                models.append(
                    {
                        "model_name": model_name,
                        "s3_uri": f"s3://{bucket}/{model_prefix.rstrip('/')}",
                        "files": file_count,
                        "total_size_gb": round(total_size / (1024**3), 2),
                    }
                )

        return models

    def get_model_uri(self, model_name: str, prefix: str = "models") -> str:
        """Get the S3 URI for a model.

        Args:
            model_name: Name of the model in the bucket
            prefix: S3 prefix (default: "models")

        Returns:
            ``s3://<bucket>/<prefix>/<model_name>``. The URI is built from
            the names only; no S3 request is made.
        """
        bucket = self._get_bucket_name()
        return f"s3://{bucket}/{prefix}/{model_name}"

    def delete_model(self, model_name: str, prefix: str = "models") -> int:
        """Delete a model and all its files from S3.

        Args:
            model_name: Name of the model in the bucket
            prefix: S3 prefix (default: "models")

        Returns:
            Number of objects deleted. Keys that S3 reports in the
            ``Errors`` list of a ``delete_objects`` response are logged as
            warnings and are not counted.
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        # Trailing slash so that only this model's "directory" matches and
        # not a sibling whose name merely starts with model_name.
        s3_prefix = f"{prefix}/{model_name}/"

        deleted = 0
        paginator = s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=s3_prefix):
            objects = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
            if not objects:
                continue

            response = s3.delete_objects(Bucket=bucket, Delete={"Objects": objects})

            # delete_objects reports per-key failures in the response body
            # instead of raising, so they have to be checked explicitly.
            errors = (response or {}).get("Errors") or []
            for err in errors:
                logging.getLogger(__name__).warning(
                    "Failed to delete s3://%s/%s: %s %s",
                    bucket,
                    err.get("Key"),
                    err.get("Code"),
                    err.get("Message"),
                )
            deleted += len(objects) - len(errors)

        return deleted

159 

160 

def get_model_manager(config: GCOConfig | None = None) -> ModelManager:
    """Build a ModelManager for the given configuration.

    Args:
        config: Optional configuration, passed through unchanged to the
            ``ModelManager`` constructor.

    Returns:
        A newly constructed ``ModelManager``.
    """
    manager = ModelManager(config)
    return manager