Coverage for gco / models / inference_models.py: 97%

110 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 21:47 +0000

1""" 

2Data models for inference endpoint management. 

3 

4Defines the schema for inference endpoints stored in DynamoDB and 

5used by the inference_monitor reconciliation loop. 

6""" 

7 

8from __future__ import annotations 

9 

10from dataclasses import dataclass, field 

11from enum import StrEnum 

12from typing import Any 

13 

14 

15class EndpointState(StrEnum): 

16 """Desired state for an inference endpoint.""" 

17 

18 DEPLOYING = "deploying" 

19 RUNNING = "running" 

20 STOPPED = "stopped" 

21 DELETED = "deleted" 

22 

23 

24class RegionSyncState(StrEnum): 

25 """Sync state of an endpoint in a specific region.""" 

26 

27 PENDING = "pending" 

28 CREATING = "creating" 

29 RUNNING = "running" 

30 UPDATING = "updating" 

31 STOPPING = "stopping" 

32 STOPPED = "stopped" 

33 DELETING = "deleting" 

34 DELETED = "deleted" 

35 ERROR = "error" 

36 

37 

@dataclass
class InferenceEndpointSpec:
    """Specification for an inference endpoint deployment.

    Captures the container, scheduling, and scaling configuration for one
    endpoint. ``to_dict``/``from_dict`` round-trip the spec through the
    plain-dict form stored in DynamoDB; unset optional fields are omitted
    from the serialized form.
    """

    image: str
    port: int = 8000
    replicas: int = 1
    gpu_count: int = 1
    gpu_type: str | None = None  # e.g. "g5.xlarge" — used for nodeSelector
    model_path: str | None = None  # EFS/FSx path for model weights
    model_source: str | None = None  # S3 URI (s3://bucket/path) or HuggingFace repo ID
    health_check_path: str = "/health"
    env: dict[str, str] = field(default_factory=dict)
    resources: dict[str, Any] = field(default_factory=dict)
    command: list[str] | None = None
    args: list[str] | None = None
    tolerations: list[dict[str, Any]] | None = None
    node_selector: dict[str, str] | None = None
    autoscaling: dict[str, Any] | None = (
        None  # {enabled, min_replicas, max_replicas, metrics: [{type, target}]}
    )
    # Canary deployment fields
    canary: dict[str, Any] | None = None  # {image, weight, replicas}
    # Capacity type: "on-demand", "spot", or "mixed" (base on-demand, overflow spot)
    capacity_type: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, dropping optional fields left unset.

        Required fields are always present; optional fields are included
        only when truthy (so empty dicts/lists are omitted as well).
        """
        payload: dict[str, Any] = {
            "image": self.image,
            "port": self.port,
            "replicas": self.replicas,
            "gpu_count": self.gpu_count,
            "health_check_path": self.health_check_path,
        }
        # Ordered table of optional fields; order matters only for dict
        # iteration stability, which we keep identical to field order.
        optionals: list[tuple[str, Any]] = [
            ("gpu_type", self.gpu_type),
            ("model_path", self.model_path),
            ("model_source", self.model_source),
            ("env", self.env),
            ("resources", self.resources),
            ("command", self.command),
            ("args", self.args),
            ("tolerations", self.tolerations),
            ("node_selector", self.node_selector),
            ("autoscaling", self.autoscaling),
            ("canary", self.canary),
            ("capacity_type", self.capacity_type),
        ]
        payload.update((key, value) for key, value in optionals if value)
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InferenceEndpointSpec:
        """Build a spec from a plain dict, applying defaults for missing keys.

        Raises:
            KeyError: if the required ``image`` key is absent.
        """
        get = data.get
        return cls(
            image=data["image"],
            port=get("port", 8000),
            replicas=get("replicas", 1),
            gpu_count=get("gpu_count", 1),
            gpu_type=get("gpu_type"),
            model_path=get("model_path"),
            model_source=get("model_source"),
            health_check_path=get("health_check_path", "/health"),
            env=get("env", {}),
            resources=get("resources", {}),
            command=get("command"),
            args=get("args"),
            tolerations=get("tolerations"),
            node_selector=get("node_selector"),
            autoscaling=get("autoscaling"),
            canary=get("canary"),
            capacity_type=get("capacity_type"),
        )

119 

120 

@dataclass
class RegionStatus:
    """Status of an endpoint in a specific region.

    One record per target region; ``state`` holds a ``RegionSyncState``
    string value. ``to_dict`` omits the optional fields when unset.
    """

    region: str
    state: str = RegionSyncState.PENDING.value
    replicas_ready: int = 0
    replicas_desired: int = 0
    last_sync: str | None = None  # ISO timestamp of last reconcile, if any
    error: str | None = None      # last error message, if in an error state
    endpoint_url: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, dropping unset optional fields."""
        payload: dict[str, Any] = {
            "region": self.region,
            "state": self.state,
            "replicas_ready": self.replicas_ready,
            "replicas_desired": self.replicas_desired,
        }
        optionals = (
            ("last_sync", self.last_sync),
            ("error", self.error),
            ("endpoint_url", self.endpoint_url),
        )
        payload.update((key, value) for key, value in optionals if value)
        return payload

147 

148 

@dataclass
class InferenceEndpoint:
    """An inference endpoint managed by GCO.

    The top-level record stored in DynamoDB. ``spec`` may be supplied as a
    raw dict (e.g. straight from storage) and is normalized to an
    ``InferenceEndpointSpec`` in ``__post_init__``; an empty
    ``ingress_path`` is defaulted from the endpoint name.
    """

    endpoint_name: str
    desired_state: str = EndpointState.DEPLOYING.value
    target_regions: list[str] = field(default_factory=list)
    namespace: str = "gco-inference"
    spec: InferenceEndpointSpec | dict[str, Any] = field(
        default_factory=lambda: InferenceEndpointSpec(image="")
    )
    ingress_path: str = ""
    created_at: str | None = None
    updated_at: str | None = None
    created_by: str | None = None
    region_status: dict[str, Any] = field(default_factory=dict)  # region -> status dict
    labels: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Normalize storage-shaped inputs into their model forms."""
        # Promote a raw spec dict to the dataclass; leave a real spec untouched.
        self.spec = (
            InferenceEndpointSpec.from_dict(self.spec)
            if isinstance(self.spec, dict)
            else self.spec
        )
        # Derive the default ingress path from the endpoint name when unset.
        self.ingress_path = self.ingress_path or f"/inference/{self.endpoint_name}"

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; all keys are always present."""
        if isinstance(self.spec, InferenceEndpointSpec):
            spec_payload = self.spec.to_dict()
        else:
            spec_payload = self.spec
        return {
            "endpoint_name": self.endpoint_name,
            "desired_state": self.desired_state,
            "target_regions": self.target_regions,
            "namespace": self.namespace,
            "spec": spec_payload,
            "ingress_path": self.ingress_path,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "created_by": self.created_by,
            "region_status": self.region_status,
            "labels": self.labels,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InferenceEndpoint:
        """Build an endpoint from a plain dict, applying defaults for missing keys.

        Raises:
            KeyError: if the required ``endpoint_name`` key is absent.
        """
        get = data.get
        return cls(
            endpoint_name=data["endpoint_name"],
            desired_state=get("desired_state", EndpointState.DEPLOYING.value),
            target_regions=get("target_regions", []),
            namespace=get("namespace", "gco-inference"),
            spec=get("spec", {}),
            ingress_path=get("ingress_path", ""),
            created_at=get("created_at"),
            updated_at=get("updated_at"),
            created_by=get("created_by"),
            region_status=get("region_status", {}),
            labels=get("labels", {}),
        )