Coverage for cli/dag.py: 94% — 144 statements (coverage.py v7.13.5, created 2026-04-30 21:47 +0000)

"""Job DAG (Directed Acyclic Graph) runner for GCO.

Allows defining multi-step ML pipelines where jobs run in dependency
order. Each step can reference the output of a previous step via
shared EFS storage.

DAG definition format (YAML):

    name: my-pipeline
    region: us-east-1
    namespace: gco-jobs
    steps:
      - name: preprocess
        manifest: examples/preprocess-job.yaml
      - name: train
        manifest: examples/train-job.yaml
        depends_on: [preprocess]
      - name: evaluate
        manifest: examples/evaluate-job.yaml
        depends_on: [train]
"""

21 

from __future__ import annotations

import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import UTC, datetime
from pathlib import Path

import yaml

from .config import GCOConfig, get_config
from .jobs import JobManager, get_job_manager

34 

35logger = logging.getLogger(__name__) 

36 

37 

@dataclass
class DagStep:
    """A single step in a DAG."""

    # Unique identifier for this step; other steps reference it in depends_on.
    name: str
    # Path to the job manifest (YAML) submitted for this step.
    manifest: str
    # Names of steps that must succeed before this one may run.
    depends_on: list[str] = field(default_factory=list)
    # Lifecycle: pending -> running -> succeeded | failed | skipped.
    status: str = "pending"
    # Runtime bookkeeping, filled in by the runner as the step executes.
    job_name: str | None = None
    started_at: str | None = None
    completed_at: str | None = None
    error: str | None = None

50 

51 

@dataclass
class DagDefinition:
    """A DAG pipeline definition."""

    name: str
    steps: list[DagStep]
    region: str | None = None
    namespace: str = "gco-jobs"

    def validate(self) -> list[str]:
        """Validate the DAG structure. Returns list of errors."""
        problems: list[str] = []
        known = {s.name for s in self.steps}

        # Duplicate names collapse in the set, so a size mismatch means dupes.
        if len(known) != len(self.steps):
            problems.append("Duplicate step names found")

        # Every dependency must reference a declared step.
        for s in self.steps:
            problems.extend(
                f"Step '{s.name}' depends on unknown step '{dep}'"
                for dep in s.depends_on
                if dep not in known
            )

        # Cycle detection, only meaningful once names and deps are sane.
        # Kahn-style peeling: repeatedly remove steps with no remaining
        # dependencies; anything left over must sit on a cycle.
        if not problems:
            remaining = {s.name: set(s.depends_on) for s in self.steps}
            while remaining:
                free = [n for n, deps in remaining.items() if not deps]
                if not free:
                    problems.append("Cycle detected in DAG dependencies")
                    break
                for n in free:
                    del remaining[n]
                for deps in remaining.values():
                    deps.difference_update(free)

        # Each step's manifest file must exist on disk.
        for s in self.steps:
            if not Path(s.manifest).exists():
                problems.append(f"Manifest not found for step '{s.name}': {s.manifest}")

        return problems

    def get_ready_steps(self) -> list[DagStep]:
        """Get steps whose dependencies are all satisfied."""
        done = {s.name for s in self.steps if s.status == "succeeded"}
        return [
            s
            for s in self.steps
            if s.status == "pending" and all(dep in done for dep in s.depends_on)
        ]

    def is_complete(self) -> bool:
        """Check if all steps are done (succeeded, failed, or skipped)."""
        terminal = ("succeeded", "failed", "skipped")
        return not any(s.status not in terminal for s in self.steps)

    def has_failures(self) -> bool:
        """Check if any step failed."""
        return "failed" in {s.status for s in self.steps}

124 

125 

def load_dag(path: str) -> DagDefinition:
    """Load a DAG definition from a YAML file.

    Args:
        path: Path to the YAML DAG definition file.

    Returns:
        The parsed DagDefinition. Structural validation is the caller's
        job (see DagDefinition.validate()).

    Raises:
        ValueError: If the document is not a YAML mapping.
        KeyError: If a step is missing its required 'name' or 'manifest' key.
        yaml.YAMLError: If the file is not valid YAML.
    """
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # BUG FIX: yaml.safe_load returns None for an empty document, which
    # previously crashed with AttributeError on data.get(...). Treat an
    # empty file as an empty DAG, and reject non-mapping documents with a
    # clear error instead of an opaque AttributeError.
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise ValueError(f"DAG definition must be a mapping, got {type(data).__name__}: {path}")

    steps = [
        DagStep(
            name=step_data["name"],
            manifest=step_data["manifest"],
            depends_on=step_data.get("depends_on", []),
        )
        for step_data in data.get("steps", [])
    ]

    return DagDefinition(
        # Fall back to the file's stem when the DAG has no explicit name.
        name=data.get("name", Path(path).stem),
        steps=steps,
        region=data.get("region"),
        namespace=data.get("namespace", "gco-jobs"),
    )

147 

148 

149class DagRunner: 

150 """Executes a DAG by submitting jobs in dependency order.""" 

151 

152 def __init__( 

153 self, 

154 config: GCOConfig | None = None, 

155 job_manager: JobManager | None = None, 

156 ): 

157 self.config = config or get_config() 

158 self.job_manager = job_manager or get_job_manager(config) 

159 

160 def run( 

161 self, 

162 dag: DagDefinition, 

163 region: str | None = None, 

164 timeout_per_step: int = 3600, 

165 poll_interval: int = 10, 

166 progress_callback: Callable[[str, str, str], None] | None = None, 

167 ) -> DagDefinition: 

168 """Execute a DAG, submitting steps as dependencies complete. 

169 

170 Args: 

171 dag: The DAG definition to execute 

172 region: Override region (default: dag.region or first deployed region) 

173 timeout_per_step: Max seconds to wait per step 

174 poll_interval: Seconds between status checks 

175 progress_callback: Optional callable(step_name, status, message) 

176 

177 Returns: 

178 The DAG with updated step statuses 

179 """ 

180 target_region = region or dag.region 

181 if not target_region: 

182 stacks = self.job_manager._aws_client.discover_regional_stacks() 

183 regions = list(stacks.keys()) 

184 if not regions: 184 ↛ 186line 184 didn't jump to line 186 because the condition on line 184 was always true

185 raise ValueError("No deployed regions found") 

186 target_region = regions[0] 

187 

188 def _notify(step_name: str, status: str, msg: str) -> None: 

189 if progress_callback: 

190 progress_callback(step_name, status, msg) 

191 

192 _notify(dag.name, "started", f"Running DAG '{dag.name}' with {len(dag.steps)} steps") 

193 

194 while not dag.is_complete(): 

195 ready = dag.get_ready_steps() 

196 

197 if not ready: 

198 # Check if we're stuck (all remaining steps have failed deps) 

199 pending = [s for s in dag.steps if s.status == "pending"] 

200 if pending: 200 ↛ 208line 200 didn't jump to line 208 because the condition on line 200 was always true

201 failed_names = {s.name for s in dag.steps if s.status == "failed"} 

202 for step in pending: 

203 if any(dep in failed_names for dep in step.depends_on): 203 ↛ 202line 203 didn't jump to line 202 because the condition on line 203 was always true

204 step.status = "skipped" 

205 step.error = "Dependency failed" 

206 _notify(step.name, "skipped", "Skipped (dependency failed)") 

207 continue 

208 break 

209 

210 # Submit all ready steps 

211 for step in ready: 

212 try: 

213 step.status = "running" 

214 step.started_at = datetime.now(UTC).isoformat() 

215 _notify(step.name, "running", f"Submitting {step.manifest}") 

216 

217 # Submit the job via API Gateway 

218 self.job_manager.submit_job( 

219 manifests=step.manifest, 

220 namespace=dag.namespace, 

221 target_region=target_region, 

222 ) 

223 

224 # Extract job name from manifest 

225 manifests = self.job_manager.load_manifests(step.manifest) 

226 if manifests: 226 ↛ 229line 226 didn't jump to line 229 because the condition on line 226 was always true

227 step.job_name = manifests[0].get("metadata", {}).get("name", step.name) 

228 

229 _notify(step.name, "running", f"Job '{step.job_name}' submitted") 

230 

231 # Wait for completion 

232 job_info = self.job_manager.wait_for_job( 

233 job_name=step.job_name or step.name, 

234 namespace=dag.namespace, 

235 region=target_region, 

236 timeout_seconds=timeout_per_step, 

237 poll_interval=poll_interval, 

238 ) 

239 

240 if job_info.status in ("Complete", "Succeeded", "succeeded"): 240 ↛ 245line 240 didn't jump to line 245 because the condition on line 240 was always true

241 step.status = "succeeded" 

242 step.completed_at = datetime.now(UTC).isoformat() 

243 _notify(step.name, "succeeded", f"Step '{step.name}' completed") 

244 else: 

245 step.status = "failed" 

246 step.completed_at = datetime.now(UTC).isoformat() 

247 step.error = f"Job ended with status: {job_info.status}" 

248 _notify( 

249 step.name, "failed", f"Step '{step.name}' failed: {job_info.status}" 

250 ) 

251 

252 except TimeoutError as e: 

253 step.status = "failed" 

254 step.completed_at = datetime.now(UTC).isoformat() 

255 step.error = str(e) 

256 _notify(step.name, "failed", f"Step '{step.name}' timed out") 

257 

258 except Exception as e: 

259 step.status = "failed" 

260 step.completed_at = datetime.now(UTC).isoformat() 

261 step.error = str(e) 

262 _notify(step.name, "failed", f"Step '{step.name}' error: {e}") 

263 

264 status = "completed" if not dag.has_failures() else "completed with failures" 

265 succeeded = sum(1 for s in dag.steps if s.status == "succeeded") 

266 failed = sum(1 for s in dag.steps if s.status == "failed") 

267 skipped = sum(1 for s in dag.steps if s.status == "skipped") 

268 _notify( 

269 dag.name, 

270 status, 

271 f"DAG '{dag.name}': {succeeded} succeeded, {failed} failed, {skipped} skipped", 

272 ) 

273 

274 return dag 

275 

276 

def get_dag_runner(config: GCOConfig | None = None) -> DagRunner:
    """Factory: build a DagRunner bound to *config* (or the global config)."""
    runner = DagRunner(config)
    return runner