Coverage for cli / dag.py: 94%
144 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
1"""Job DAG (Directed Acyclic Graph) runner for GCO.
3Allows defining multi-step ML pipelines where jobs run in dependency
4order. Each step can reference the output of a previous step via
5shared EFS storage.
7DAG definition format (YAML):
8 name: my-pipeline
9 region: us-east-1
10 namespace: gco-jobs
11 steps:
12 - name: preprocess
13 manifest: examples/preprocess-job.yaml
14 - name: train
15 manifest: examples/train-job.yaml
16 depends_on: [preprocess]
17 - name: evaluate
18 manifest: examples/evaluate-job.yaml
19 depends_on: [train]
20"""
from __future__ import annotations

import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import UTC, datetime
from pathlib import Path

import yaml

from .config import GCOConfig, get_config
from .jobs import JobManager, get_job_manager

logger = logging.getLogger(__name__)
@dataclass
class DagStep:
    """A single step in a DAG.

    Holds both the static definition (name, manifest path, dependencies)
    and the mutable execution state that the runner fills in as the step
    progresses.
    """

    name: str  # unique step name within the DAG
    manifest: str  # path to the job manifest YAML for this step
    depends_on: list[str] = field(default_factory=list)  # names of upstream steps
    status: str = "pending"  # pending, running, succeeded, failed, skipped
    job_name: str | None = None  # job name taken from the manifest metadata after submission
    started_at: str | None = None  # ISO-8601 UTC timestamp set when the step is submitted
    completed_at: str | None = None  # ISO-8601 UTC timestamp set when the step finishes
    error: str | None = None  # failure or skip reason, if any
@dataclass
class DagDefinition:
    """A DAG pipeline definition: a named collection of steps with edges."""

    name: str
    steps: list[DagStep]
    region: str | None = None
    namespace: str = "gco-jobs"

    def validate(self) -> list[str]:
        """Validate the DAG structure. Returns list of errors."""
        problems: list[str] = []
        known_names = {step.name for step in self.steps}

        # Duplicate names collapse in the set, so a size mismatch means
        # at least two steps share a name.
        if len(self.steps) != len(known_names):
            problems.append("Duplicate step names found")

        # Every declared dependency must refer to a defined step.
        for step in self.steps:
            problems.extend(
                f"Step '{step.name}' depends on unknown step '{dep}'"
                for dep in step.depends_on
                if dep not in known_names
            )

        # Only look for cycles once names and edges are known to be sane.
        if not problems:
            edges = {step.name: step.depends_on for step in self.steps}
            seen: set[str] = set()
            path: set[str] = set()

            def _dfs(node: str) -> bool:
                # Depth-first walk; a dependency already on the current
                # recursion path is a back-edge, i.e. a cycle.
                seen.add(node)
                path.add(node)
                for dep in edges.get(node, []):
                    if dep in path:
                        return True
                    if dep not in seen and _dfs(dep):
                        return True
                path.discard(node)
                return False

            for step in self.steps:
                if step.name not in seen and _dfs(step.name):
                    problems.append("Cycle detected in DAG dependencies")
                    break

        # Each step's manifest file must exist on disk.
        problems.extend(
            f"Manifest not found for step '{step.name}': {step.manifest}"
            for step in self.steps
            if not Path(step.manifest).exists()
        )

        return problems

    def get_ready_steps(self) -> list[DagStep]:
        """Get steps whose dependencies are all satisfied."""
        done = {step.name for step in self.steps if step.status == "succeeded"}
        return [
            step
            for step in self.steps
            if step.status == "pending"
            and all(dep in done for dep in step.depends_on)
        ]

    def is_complete(self) -> bool:
        """Check if all steps are done (succeeded, failed, or skipped)."""
        terminal = ("succeeded", "failed", "skipped")
        return all(step.status in terminal for step in self.steps)

    def has_failures(self) -> bool:
        """Check if any step failed."""
        return any(step.status == "failed" for step in self.steps)
def load_dag(path: str) -> DagDefinition:
    """Load a DAG definition from a YAML file.

    Missing optional keys fall back to defaults: the DAG name defaults to
    the file's stem, the namespace to "gco-jobs", and the region to None.
    """
    raw = yaml.safe_load(Path(path).read_text(encoding="utf-8"))

    parsed_steps = [
        DagStep(
            name=entry["name"],
            manifest=entry["manifest"],
            depends_on=entry.get("depends_on", []),
        )
        for entry in raw.get("steps", [])
    ]

    return DagDefinition(
        name=raw.get("name", Path(path).stem),
        steps=parsed_steps,
        region=raw.get("region"),
        namespace=raw.get("namespace", "gco-jobs"),
    )
149class DagRunner:
150 """Executes a DAG by submitting jobs in dependency order."""
152 def __init__(
153 self,
154 config: GCOConfig | None = None,
155 job_manager: JobManager | None = None,
156 ):
157 self.config = config or get_config()
158 self.job_manager = job_manager or get_job_manager(config)
160 def run(
161 self,
162 dag: DagDefinition,
163 region: str | None = None,
164 timeout_per_step: int = 3600,
165 poll_interval: int = 10,
166 progress_callback: Callable[[str, str, str], None] | None = None,
167 ) -> DagDefinition:
168 """Execute a DAG, submitting steps as dependencies complete.
170 Args:
171 dag: The DAG definition to execute
172 region: Override region (default: dag.region or first deployed region)
173 timeout_per_step: Max seconds to wait per step
174 poll_interval: Seconds between status checks
175 progress_callback: Optional callable(step_name, status, message)
177 Returns:
178 The DAG with updated step statuses
179 """
180 target_region = region or dag.region
181 if not target_region:
182 stacks = self.job_manager._aws_client.discover_regional_stacks()
183 regions = list(stacks.keys())
184 if not regions: 184 ↛ 186line 184 didn't jump to line 186 because the condition on line 184 was always true
185 raise ValueError("No deployed regions found")
186 target_region = regions[0]
188 def _notify(step_name: str, status: str, msg: str) -> None:
189 if progress_callback:
190 progress_callback(step_name, status, msg)
192 _notify(dag.name, "started", f"Running DAG '{dag.name}' with {len(dag.steps)} steps")
194 while not dag.is_complete():
195 ready = dag.get_ready_steps()
197 if not ready:
198 # Check if we're stuck (all remaining steps have failed deps)
199 pending = [s for s in dag.steps if s.status == "pending"]
200 if pending: 200 ↛ 208line 200 didn't jump to line 208 because the condition on line 200 was always true
201 failed_names = {s.name for s in dag.steps if s.status == "failed"}
202 for step in pending:
203 if any(dep in failed_names for dep in step.depends_on): 203 ↛ 202line 203 didn't jump to line 202 because the condition on line 203 was always true
204 step.status = "skipped"
205 step.error = "Dependency failed"
206 _notify(step.name, "skipped", "Skipped (dependency failed)")
207 continue
208 break
210 # Submit all ready steps
211 for step in ready:
212 try:
213 step.status = "running"
214 step.started_at = datetime.now(UTC).isoformat()
215 _notify(step.name, "running", f"Submitting {step.manifest}")
217 # Submit the job via API Gateway
218 self.job_manager.submit_job(
219 manifests=step.manifest,
220 namespace=dag.namespace,
221 target_region=target_region,
222 )
224 # Extract job name from manifest
225 manifests = self.job_manager.load_manifests(step.manifest)
226 if manifests: 226 ↛ 229line 226 didn't jump to line 229 because the condition on line 226 was always true
227 step.job_name = manifests[0].get("metadata", {}).get("name", step.name)
229 _notify(step.name, "running", f"Job '{step.job_name}' submitted")
231 # Wait for completion
232 job_info = self.job_manager.wait_for_job(
233 job_name=step.job_name or step.name,
234 namespace=dag.namespace,
235 region=target_region,
236 timeout_seconds=timeout_per_step,
237 poll_interval=poll_interval,
238 )
240 if job_info.status in ("Complete", "Succeeded", "succeeded"): 240 ↛ 245line 240 didn't jump to line 245 because the condition on line 240 was always true
241 step.status = "succeeded"
242 step.completed_at = datetime.now(UTC).isoformat()
243 _notify(step.name, "succeeded", f"Step '{step.name}' completed")
244 else:
245 step.status = "failed"
246 step.completed_at = datetime.now(UTC).isoformat()
247 step.error = f"Job ended with status: {job_info.status}"
248 _notify(
249 step.name, "failed", f"Step '{step.name}' failed: {job_info.status}"
250 )
252 except TimeoutError as e:
253 step.status = "failed"
254 step.completed_at = datetime.now(UTC).isoformat()
255 step.error = str(e)
256 _notify(step.name, "failed", f"Step '{step.name}' timed out")
258 except Exception as e:
259 step.status = "failed"
260 step.completed_at = datetime.now(UTC).isoformat()
261 step.error = str(e)
262 _notify(step.name, "failed", f"Step '{step.name}' error: {e}")
264 status = "completed" if not dag.has_failures() else "completed with failures"
265 succeeded = sum(1 for s in dag.steps if s.status == "succeeded")
266 failed = sum(1 for s in dag.steps if s.status == "failed")
267 skipped = sum(1 for s in dag.steps if s.status == "skipped")
268 _notify(
269 dag.name,
270 status,
271 f"DAG '{dag.name}': {succeeded} succeeded, {failed} failed, {skipped} skipped",
272 )
274 return dag
def get_dag_runner(config: GCOConfig | None = None) -> DagRunner:
    """Factory function for DagRunner.

    Args:
        config: Optional configuration forwarded to the runner.
    """
    runner = DagRunner(config)
    return runner