Coverage for cli / costs.py: 99%

152 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 21:47 +0000

"""Cost visibility for GCO workloads.

Uses AWS Cost Explorer for historical spend and the Pricing API
for real-time cost estimates on running workloads.
"""

6 

7from __future__ import annotations 

8 

9import logging 

10from dataclasses import dataclass, field 

11from datetime import UTC, datetime, timedelta 

12from typing import Any 

13 

14import boto3 

15 

16from .config import GCOConfig 

17 

# Module-level logger; best-effort failures in this module are reported here.
logger: logging.Logger = logging.getLogger(__name__)

19 

20 

@dataclass
class ResourceCost:
    """Spend attributed to a single service or resource grouping.

    Attributes:
        service: Name of the service (e.g. a Cost Explorer SERVICE key).
        amount: Cost amount, expressed in ``currency``.
        currency: Currency code; "USD" unless stated otherwise.
        region: Region the cost belongs to, when known.
        detail: Optional free-form description.
    """

    service: str
    amount: float
    currency: str = field(default="USD")
    region: str | None = field(default=None)
    detail: str | None = field(default=None)

30 

31 

@dataclass
class CostSummary:
    """Spend aggregated over one reporting window.

    Attributes:
        total: Summed cost for the window.
        currency: Currency code; "USD" unless stated otherwise.
        period_start: Window start as a date string (empty when unset).
        period_end: Window end as a date string (empty when unset).
        by_service: Per-service breakdown; each instance gets its own list.
        by_region: Cost keyed by region name; each instance gets its own dict.
    """

    total: float
    currency: str = field(default="USD")
    period_start: str = field(default="")
    period_end: str = field(default="")
    by_service: list[ResourceCost] = field(default_factory=list)
    by_region: dict[str, float] = field(default_factory=dict)

42 

43 

@dataclass
class WorkloadCost:
    """Estimated spend for one running (or pending) workload.

    Attributes:
        name: Workload (pod) name.
        workload_type: Either "job" or "inference".
        instance_type: Instance type of the hosting node, or "unknown".
        gpu_count: Number of GPUs requested by the workload.
        hourly_rate: On-demand price per hour for ``instance_type``.
        runtime_hours: Hours the workload has been running.
        estimated_cost: Cost accrued so far (hourly_rate x runtime).
        region: Region the workload runs in.
        status: Workload phase, e.g. "Running" or "Pending".
    """

    name: str
    workload_type: str  # "job" | "inference"
    instance_type: str
    gpu_count: int
    hourly_rate: float
    runtime_hours: float
    estimated_cost: float
    region: str
    status: str

57 

58 

59class CostTracker: 

60 """Track and estimate costs for GCO resources.""" 

61 

62 def __init__(self, config: GCOConfig | None = None): 

63 self._config = config 

64 self._session = boto3.Session() 

65 self._pricing_cache: dict[str, float | None] = {} 

66 

67 def get_cost_summary( 

68 self, 

69 days: int = 30, 

70 granularity: str = "MONTHLY", 

71 unfiltered: bool = False, 

72 ) -> CostSummary: 

73 """Get cost summary from Cost Explorer filtered by GCO tags.""" 

74 ce = self._session.client("ce", region_name="us-east-1") 

75 

76 end = datetime.now(UTC).date() 

77 start = end - timedelta(days=days) 

78 

79 kwargs: dict[str, Any] = { 

80 "TimePeriod": { 

81 "Start": start.isoformat(), 

82 "End": end.isoformat(), 

83 }, 

84 "Granularity": granularity, 

85 "Metrics": ["UnblendedCost"], 

86 "GroupBy": [ 

87 {"Type": "DIMENSION", "Key": "SERVICE"}, 

88 ], 

89 } 

90 

91 if not unfiltered: 

92 kwargs["Filter"] = { 

93 "Tags": { 

94 "Key": "Project", 

95 "Values": ["GCO"], 

96 } 

97 } 

98 

99 try: 

100 response = ce.get_cost_and_usage(**kwargs) 

101 except Exception as e: 

102 raise RuntimeError(f"Cost Explorer query failed: {e}") from e 

103 

104 summary = CostSummary( 

105 total=0.0, 

106 period_start=start.isoformat(), 

107 period_end=end.isoformat(), 

108 ) 

109 

110 for result in response.get("ResultsByTime", []): 

111 for group in result.get("Groups", []): 

112 service = group["Keys"][0] 

113 amount = float(group["Metrics"]["UnblendedCost"]["Amount"]) 

114 if amount > 0.001: 

115 summary.by_service.append(ResourceCost(service=service, amount=amount)) 

116 summary.total += amount 

117 

118 # Sort by cost descending 

119 summary.by_service.sort(key=lambda x: x.amount, reverse=True) 

120 

121 return summary 

122 

123 def get_cost_by_region(self, days: int = 30) -> dict[str, float]: 

124 """Get cost breakdown by region.""" 

125 ce = self._session.client("ce", region_name="us-east-1") 

126 

127 end = datetime.now(UTC).date() 

128 start = end - timedelta(days=days) 

129 

130 try: 

131 response = ce.get_cost_and_usage( 

132 TimePeriod={ 

133 "Start": start.isoformat(), 

134 "End": end.isoformat(), 

135 }, 

136 Granularity="MONTHLY", 

137 Metrics=["UnblendedCost"], 

138 Filter={ 

139 "Tags": { 

140 "Key": "Project", 

141 "Values": ["GCO"], 

142 } 

143 }, 

144 GroupBy=[ 

145 {"Type": "DIMENSION", "Key": "REGION"}, 

146 ], 

147 ) 

148 except Exception as e: 

149 raise RuntimeError(f"Cost Explorer query failed: {e}") from e 

150 

151 by_region: dict[str, float] = {} 

152 for result in response.get("ResultsByTime", []): 

153 for group in result.get("Groups", []): 

154 region = group["Keys"][0] 

155 amount = float(group["Metrics"]["UnblendedCost"]["Amount"]) 

156 if amount > 0.001: 

157 by_region[region] = by_region.get(region, 0) + amount 

158 

159 return dict(sorted(by_region.items(), key=lambda x: x[1], reverse=True)) 

160 

161 def get_daily_trend(self, days: int = 14, unfiltered: bool = False) -> list[dict[str, Any]]: 

162 """Get daily cost trend.""" 

163 ce = self._session.client("ce", region_name="us-east-1") 

164 

165 end = datetime.now(UTC).date() 

166 start = end - timedelta(days=days) 

167 

168 kwargs: dict[str, Any] = { 

169 "TimePeriod": { 

170 "Start": start.isoformat(), 

171 "End": end.isoformat(), 

172 }, 

173 "Granularity": "DAILY", 

174 "Metrics": ["UnblendedCost"], 

175 } 

176 

177 if not unfiltered: 

178 kwargs["Filter"] = { 

179 "Tags": { 

180 "Key": "Project", 

181 "Values": ["GCO"], 

182 } 

183 } 

184 

185 try: 

186 response = ce.get_cost_and_usage(**kwargs) 

187 except Exception as e: 

188 raise RuntimeError(f"Cost Explorer query failed: {e}") from e 

189 

190 trend = [] 

191 for result in response.get("ResultsByTime", []): 

192 date = result["TimePeriod"]["Start"] 

193 amount = float(result["Total"]["UnblendedCost"]["Amount"]) 

194 trend.append({"date": date, "amount": amount}) 

195 

196 return trend 

197 

198 def estimate_running_workloads(self, region: str) -> list[WorkloadCost]: 

199 """Estimate costs for currently running workloads in a region.""" 

200 try: 

201 from .capacity import get_capacity_checker 

202 except ImportError: 

203 return [] 

204 

205 checker = get_capacity_checker(self._config) 

206 estimates: list[WorkloadCost] = [] 

207 

208 # Get running pods from EKS 

209 try: 

210 cluster_name = f"gco-{region}" 

211 

212 from .kubectl_helpers import update_kubeconfig 

213 

214 update_kubeconfig(cluster_name, region) 

215 

216 from kubernetes import client as k8s_client 

217 from kubernetes import config as k8s_config 

218 

219 k8s_config.load_kube_config() 

220 v1 = k8s_client.CoreV1Api() 

221 

222 # Check inference namespace 

223 for ns in ["gco-inference", "gco-jobs"]: 

224 try: 

225 pods = v1.list_namespaced_pod(namespace=ns) 

226 except Exception as e: 

227 logger.debug("Failed to list pods in %s: %s", ns, e) 

228 continue 

229 

230 for pod in pods.items: 

231 if pod.status.phase not in ("Running", "Pending"): 

232 continue 

233 

234 name = pod.metadata.name 

235 gpu_count = 0 

236 instance_type = "unknown" 

237 

238 # Get GPU requests 

239 for container in pod.spec.containers or []: 

240 requests = container.resources.requests or {} 

241 gpu_req = requests.get( # nosec B113 - dict.get(), not HTTP requests 

242 "nvidia.com/gpu", "0" 

243 ) 

244 gpu_count += int(gpu_req) 

245 

246 # Get node instance type 

247 if pod.spec.node_name: 

248 try: 

249 node = v1.read_node(pod.spec.node_name) 

250 instance_type = node.metadata.labels.get( 

251 "node.kubernetes.io/instance-type", "unknown" 

252 ) 

253 except Exception as e: 

254 logger.debug( 

255 "Failed to get node info for %s: %s", pod.spec.node_name, e 

256 ) 

257 

258 # Calculate cost 

259 hourly_rate = checker.get_on_demand_price(instance_type, region) or 0.0 

260 

261 # Calculate runtime 

262 start_time = pod.status.start_time 

263 if start_time: 

264 runtime = datetime.now(UTC) - start_time 

265 runtime_hours = runtime.total_seconds() / 3600 

266 else: 

267 runtime_hours = 0.0 

268 

269 workload_type = "inference" if ns == "gco-inference" else "job" 

270 

271 estimates.append( 

272 WorkloadCost( 

273 name=name, 

274 workload_type=workload_type, 

275 instance_type=instance_type, 

276 gpu_count=gpu_count, 

277 hourly_rate=hourly_rate, 

278 runtime_hours=round(runtime_hours, 2), 

279 estimated_cost=round(hourly_rate * runtime_hours, 4), 

280 region=region, 

281 status=pod.status.phase, 

282 ) 

283 ) 

284 

285 except Exception as e: 

286 logger.debug("Failed to estimate workload costs: %s", e) 

287 

288 return estimates 

289 

290 def get_forecast(self, days_ahead: int = 30) -> dict[str, Any]: 

291 """Get cost forecast for the next N days.""" 

292 ce = self._session.client("ce", region_name="us-east-1") 

293 

294 start = datetime.now(UTC).date() 

295 end = start + timedelta(days=days_ahead) 

296 

297 try: 

298 response = ce.get_cost_forecast( 

299 TimePeriod={ 

300 "Start": start.isoformat(), 

301 "End": end.isoformat(), 

302 }, 

303 Metric="UNBLENDED_COST", 

304 Granularity="MONTHLY", 

305 Filter={ 

306 "Tags": { 

307 "Key": "Project", 

308 "Values": ["GCO"], 

309 } 

310 }, 

311 ) 

312 

313 return { 

314 "forecast_total": float(response.get("Total", {}).get("Amount", 0)), 

315 "period_start": start.isoformat(), 

316 "period_end": end.isoformat(), 

317 } 

318 except Exception as e: 

319 return {"error": str(e)} 

320 

321 

def get_cost_tracker(config: GCOConfig | None = None) -> CostTracker:
    """Build a CostTracker.

    Args:
        config: Optional GCO configuration handed to the tracker.

    Returns:
        A newly constructed CostTracker.
    """
    tracker = CostTracker(config=config)
    return tracker