Coverage for cli/models.py: 100%

79 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 21:47 +0000

1""" 

2Model weight management for GCO CLI. 

3 

4Provides functionality to upload, list, and manage model weights 

5in the central S3 model bucket. Models uploaded here are automatically 

6available to inference endpoints across all regions via init container sync. 

7""" 

8 

9from __future__ import annotations 

10 

11import logging 

12import os 

13from pathlib import Path 

14from typing import Any 

15 

16import boto3 

17 

18from .config import GCOConfig, get_config 

19 

20logger = logging.getLogger(__name__) 

21 

22 

class ModelManager:
    """Manages model weights in the central S3 model bucket.

    The bucket name is discovered once from the SSM parameter
    ``/<project>/model-bucket-name`` in the global region and cached on the
    instance. Model files live under ``<prefix>/<model_name>/`` keys.
    """

    def __init__(self, config: GCOConfig | None = None):
        """
        Args:
            config: Explicit configuration; when None, falls back to
                ``get_config()``.
        """
        self.config = config or get_config()
        # Cached bucket name; populated lazily by _get_bucket_name().
        self._bucket_name: str | None = None

    def _get_bucket_name(self) -> str:
        """Discover the model bucket name from SSM (cached after first call).

        Returns:
            The S3 bucket name for model storage.

        Raises:
            RuntimeError: If the SSM parameter cannot be read — typically
                because the global stack has not been deployed yet.
        """
        if self._bucket_name:
            return self._bucket_name

        ssm = boto3.client("ssm", region_name=self.config.global_region)
        try:
            # Keep the try body minimal: only the call that can fail.
            response = ssm.get_parameter(Name=f"/{self.config.project_name}/model-bucket-name")
        except Exception as e:
            # Collapse all SSM failures (missing parameter, permissions,
            # wrong region) into a single actionable CLI-facing error.
            raise RuntimeError(
                "Model bucket not found. Deploy the global stack first "
                "with 'gco stacks deploy gco-global'."
            ) from e
        self._bucket_name = response["Parameter"]["Value"]
        return self._bucket_name

    def _get_s3_client(self) -> Any:
        """Get an S3 client bound to the global region (where the bucket lives)."""
        return boto3.client("s3", region_name=self.config.global_region)

    def upload(
        self,
        local_path: str,
        model_name: str,
        prefix: str = "models",
    ) -> dict[str, Any]:
        """
        Upload model weights to S3.

        Args:
            local_path: Local file or directory path
            model_name: Name for the model in the bucket
            prefix: S3 prefix (default: "models")

        Returns:
            Upload result with S3 URI and file count. Note that an empty
            directory yields ``files_uploaded == 0``.

        Raises:
            FileNotFoundError: If ``local_path`` does not exist.
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        s3_prefix = f"{prefix}/{model_name}"

        local = Path(local_path)
        uploaded = 0

        if local.is_file():
            key = f"{s3_prefix}/{local.name}"
            s3.upload_file(str(local), bucket, key)
            uploaded = 1
        elif local.is_dir():
            for file_path in local.rglob("*"):
                if not file_path.is_file():
                    continue
                relative = file_path.relative_to(local)
                # as_posix() guarantees "/" separators in S3 keys even on
                # Windows, where str(relative) would contain backslashes.
                key = f"{s3_prefix}/{relative.as_posix()}"
                s3.upload_file(str(file_path), bucket, key)
                uploaded += 1
        else:
            raise FileNotFoundError(f"Path not found: {local_path}")

        s3_uri = f"s3://{bucket}/{s3_prefix}"
        return {
            "model_name": model_name,
            "s3_uri": s3_uri,
            "bucket": bucket,
            "prefix": s3_prefix,
            "files_uploaded": uploaded,
        }

    def list_models(self, prefix: str = "models") -> list[dict[str, Any]]:
        """List all models in the bucket.

        Args:
            prefix: S3 prefix to scan (default: "models").

        Returns:
            One dict per model with name, S3 URI, file count, and total
            size in GiB (rounded to 2 decimals).
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()

        # Delimiter="/" makes S3 return each model folder as a CommonPrefix
        # rather than enumerating every object under the top-level prefix.
        response = s3.list_objects_v2(
            Bucket=bucket,
            Prefix=f"{prefix}/",
            Delimiter="/",
        )

        models = []
        for cp in response.get("CommonPrefixes", []):
            model_prefix = cp["Prefix"]
            model_name = model_prefix.rstrip("/").split("/")[-1]

            # Aggregate total size and file count across all pages, since a
            # single list call is capped at 1000 keys.
            total_size = 0
            file_count = 0
            paginator = s3.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=bucket, Prefix=model_prefix):
                for obj in page.get("Contents", []):
                    total_size += obj.get("Size", 0)
                    file_count += 1

            models.append(
                {
                    "model_name": model_name,
                    "s3_uri": f"s3://{bucket}/{model_prefix.rstrip('/')}",
                    "files": file_count,
                    "total_size_gb": round(total_size / (1024**3), 2),
                }
            )

        return models

    def get_model_uri(self, model_name: str, prefix: str = "models") -> str:
        """Get the S3 URI for a model (no existence check is performed)."""
        bucket = self._get_bucket_name()
        return f"s3://{bucket}/{prefix}/{model_name}"

    def delete_model(self, model_name: str, prefix: str = "models") -> int:
        """Delete a model and all its files from S3.

        Args:
            model_name: Model folder to remove.
            prefix: S3 prefix (default: "models").

        Returns:
            Number of objects deleted (0 if the model did not exist).
        """
        bucket = self._get_bucket_name()
        s3 = self._get_s3_client()
        s3_prefix = f"{prefix}/{model_name}/"

        # Each paginator page holds at most 1000 keys, which matches the
        # per-call limit of delete_objects, so one delete per page is safe.
        deleted = 0
        paginator = s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=s3_prefix):
            objects = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
            if objects:
                s3.delete_objects(Bucket=bucket, Delete={"Objects": objects})
                deleted += len(objects)

        return deleted

156 

157 

def get_model_manager(config: GCOConfig | None = None) -> ModelManager:
    """Factory function for ModelManager.

    Args:
        config: Optional configuration forwarded to the manager; when None,
            the manager resolves configuration itself.

    Returns:
        A newly constructed ModelManager.
    """
    manager = ModelManager(config)
    return manager