Coverage for gco / models / inference_models.py: 97%
110 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
1"""
2Data models for inference endpoint management.
4Defines the schema for inference endpoints stored in DynamoDB and
5used by the inference_monitor reconciliation loop.
6"""
8from __future__ import annotations
10from dataclasses import dataclass, field
11from enum import StrEnum
12from typing import Any
15class EndpointState(StrEnum):
16 """Desired state for an inference endpoint."""
18 DEPLOYING = "deploying"
19 RUNNING = "running"
20 STOPPED = "stopped"
21 DELETED = "deleted"
24class RegionSyncState(StrEnum):
25 """Sync state of an endpoint in a specific region."""
27 PENDING = "pending"
28 CREATING = "creating"
29 RUNNING = "running"
30 UPDATING = "updating"
31 STOPPING = "stopping"
32 STOPPED = "stopped"
33 DELETING = "deleting"
34 DELETED = "deleted"
35 ERROR = "error"
@dataclass
class InferenceEndpointSpec:
    """Specification for an inference endpoint deployment.

    Serialized with :meth:`to_dict` and rebuilt with :meth:`from_dict`.
    Optional fields that are unset or empty are omitted from the serialized
    form.
    """

    image: str
    port: int = 8000
    replicas: int = 1
    gpu_count: int = 1
    gpu_type: str | None = None  # e.g. "g5.xlarge" — used for nodeSelector
    model_path: str | None = None  # EFS/FSx path for model weights
    model_source: str | None = None  # S3 URI (s3://bucket/path) or HuggingFace repo ID
    health_check_path: str = "/health"
    env: dict[str, str] = field(default_factory=dict)
    resources: dict[str, Any] = field(default_factory=dict)
    command: list[str] | None = None
    args: list[str] | None = None
    tolerations: list[dict[str, Any]] | None = None
    node_selector: dict[str, str] | None = None
    # {enabled, min_replicas, max_replicas, metrics: [{type, target}]}
    autoscaling: dict[str, Any] | None = None
    # Canary deployment fields: {image, weight, replicas}
    canary: dict[str, Any] | None = None
    # Capacity type: "on-demand", "spot", or "mixed" (base on-demand, overflow spot)
    capacity_type: str | None = None

    # Optional fields that to_dict() emits only when truthy, in output order.
    # Unannotated, so @dataclass does not treat it as a field.
    _OPTIONAL_FIELDS = (
        "gpu_type",
        "model_path",
        "model_source",
        "env",
        "resources",
        "command",
        "args",
        "tolerations",
        "node_selector",
        "autoscaling",
        "canary",
        "capacity_type",
    )

    def to_dict(self) -> dict[str, Any]:
        """Serialize the spec to a plain dict.

        ``image``, ``port``, ``replicas``, ``gpu_count`` and
        ``health_check_path`` are always present. Every other field is included
        only when truthy, so ``None``, ``""`` and empty containers are dropped.
        """
        result: dict[str, Any] = {
            "image": self.image,
            "port": self.port,
            "replicas": self.replicas,
            "gpu_count": self.gpu_count,
            "health_check_path": self.health_check_path,
        }
        for name in self._OPTIONAL_FIELDS:
            value = getattr(self, name)
            if value:
                result[name] = value
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InferenceEndpointSpec:
        """Build a spec from a dict such as one produced by :meth:`to_dict`.

        A key that is missing and a key explicitly set to ``None`` are treated
        the same: both fall back to the field default. This keeps a stored
        ``null`` out of typed fields like ``env``. ``env`` and ``resources`` are
        shallow-copied, so mutating the returned spec does not mutate ``data``.

        Raises:
            KeyError: if ``data`` has no ``"image"`` key.
        """

        def pick(key: str, default: Any) -> Any:
            # Use an ``is None`` check rather than ``or`` so that legitimate
            # falsy values such as ``replicas=0`` or ``gpu_count=0`` survive.
            value = data.get(key)
            return default if value is None else value

        return cls(
            image=data["image"],
            port=pick("port", 8000),
            replicas=pick("replicas", 1),
            gpu_count=pick("gpu_count", 1),
            gpu_type=data.get("gpu_type"),
            model_path=data.get("model_path"),
            model_source=data.get("model_source"),
            health_check_path=pick("health_check_path", "/health"),
            env=dict(pick("env", {})),
            resources=dict(pick("resources", {})),
            command=data.get("command"),
            args=data.get("args"),
            tolerations=data.get("tolerations"),
            node_selector=data.get("node_selector"),
            autoscaling=data.get("autoscaling"),
            canary=data.get("canary"),
            capacity_type=data.get("capacity_type"),
        )
@dataclass
class RegionStatus:
    """Sync status of an inference endpoint in one region.

    ``region``, ``state`` and the two replica counts are always serialized.
    ``last_sync``, ``error`` and ``endpoint_url`` appear in :meth:`to_dict`
    output only when they are truthy.
    """

    region: str
    state: str = RegionSyncState.PENDING.value
    replicas_ready: int = 0
    replicas_desired: int = 0
    last_sync: str | None = None
    error: str | None = None
    endpoint_url: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, dropping empty optional fields."""
        payload: dict[str, Any] = dict(
            region=self.region,
            state=self.state,
            replicas_ready=self.replicas_ready,
            replicas_desired=self.replicas_desired,
        )
        extras = {
            "last_sync": self.last_sync,
            "error": self.error,
            "endpoint_url": self.endpoint_url,
        }
        # Insertion order of ``extras`` matches the output key order.
        payload.update({key: value for key, value in extras.items() if value})
        return payload
@dataclass
class InferenceEndpoint:
    """An inference endpoint managed by GCO.

    ``spec`` may be given as an :class:`InferenceEndpointSpec` or as a plain
    dict. A dict is converted in ``__post_init__``, and ``None`` is replaced by
    the same default spec the field uses. An empty ``ingress_path`` defaults to
    ``/inference/<endpoint_name>``.
    """

    endpoint_name: str
    desired_state: str = EndpointState.DEPLOYING.value
    target_regions: list[str] = field(default_factory=list)
    namespace: str = "gco-inference"
    spec: InferenceEndpointSpec | dict[str, Any] = field(
        default_factory=lambda: InferenceEndpointSpec(image="")
    )
    ingress_path: str = ""
    created_at: str | None = None
    updated_at: str | None = None
    created_by: str | None = None
    # NOTE(review): presumably keyed by region name — confirm against writers.
    region_status: dict[str, Any] = field(default_factory=dict)
    labels: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Normalize ``spec`` so that instances always hold an InferenceEndpointSpec.
        if self.spec is None:
            # Same fallback as the field's default_factory.
            self.spec = InferenceEndpointSpec(image="")
        elif isinstance(self.spec, dict):
            self.spec = InferenceEndpointSpec.from_dict(self.spec)
        if not self.ingress_path:
            self.ingress_path = f"/inference/{self.endpoint_name}"

    def to_dict(self) -> dict[str, Any]:
        """Serialize the endpoint to a plain dict.

        All keys are always present, even when their value is ``None``. The
        spec is serialized via ``InferenceEndpointSpec.to_dict``; any other
        ``spec`` value is emitted as-is.
        """
        spec_dict = (
            self.spec.to_dict() if isinstance(self.spec, InferenceEndpointSpec) else self.spec
        )
        return {
            "endpoint_name": self.endpoint_name,
            "desired_state": self.desired_state,
            "target_regions": self.target_regions,
            "namespace": self.namespace,
            "spec": spec_dict,
            "ingress_path": self.ingress_path,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "created_by": self.created_by,
            "region_status": self.region_status,
            "labels": self.labels,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InferenceEndpoint:
        """Build an endpoint from a dict such as one produced by :meth:`to_dict`.

        A key that is missing and a key explicitly set to ``None`` both fall
        back to the field default. In particular, a missing ``"spec"`` yields
        the default spec instead of failing inside
        ``InferenceEndpointSpec.from_dict``. ``target_regions``,
        ``region_status`` and ``labels`` are shallow-copied, so the result does
        not alias ``data``.

        Raises:
            KeyError: if ``data`` has no ``"endpoint_name"`` key.
        """

        def pick(key: str, default: Any) -> Any:
            value = data.get(key)
            return default if value is None else value

        spec_data = data.get("spec")
        spec: InferenceEndpointSpec | dict[str, Any] = (
            InferenceEndpointSpec(image="") if spec_data is None else spec_data
        )
        return cls(
            endpoint_name=data["endpoint_name"],
            desired_state=pick("desired_state", EndpointState.DEPLOYING.value),
            target_regions=list(pick("target_regions", [])),
            namespace=pick("namespace", "gco-inference"),
            spec=spec,
            ingress_path=pick("ingress_path", ""),
            created_at=data.get("created_at"),
            updated_at=data.get("updated_at"),
            created_by=data.get("created_by"),
            region_status=dict(pick("region_status", {})),
            labels=dict(pick("labels", {})),
        )