Skip to main content

consortium_ray/
tasks.rs

//! DagTask implementations for Ray job orchestration.
2
3use std::process::Command;
4use std::time::{Duration, Instant};
5
6use consortium::dag::{DagContext, DagTask, TaskId, TaskOutcome};
7use consortium_nix::build;
8
/// DAG task that builds a Ray job environment via nix.
///
/// Derives `Debug` and `Clone` so a task can be logged and duplicated when a
/// DAG is assembled.
#[derive(Debug, Clone)]
pub struct NixBuildRayEnvTask {
    /// Environment name; also the suffix of the `build-ray-env:<env_name>`
    /// output key written when the build succeeds.
    pub env_name: String,
    /// Flake attribute handed to the nix build.
    pub flake_attr: String,
}
14
15impl NixBuildRayEnvTask {
16    pub fn new(env_name: &str, flake_uri: &str) -> Self {
17        Self {
18            env_name: env_name.to_string(),
19            flake_attr: format!("{}#rayEnvs.{}", flake_uri, env_name),
20        }
21    }
22}
23
24impl DagTask for NixBuildRayEnvTask {
25    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
26        match build::build_flake_attr(&self.flake_attr, None) {
27            Ok(path) => {
28                ctx.set_output(TaskId(format!("build-ray-env:{}", self.env_name)), path);
29                TaskOutcome::Success
30            }
31            Err(e) => TaskOutcome::Failed(format!("build ray env: {}", e)),
32        }
33    }
34
35    fn describe(&self) -> String {
36        format!("build ray environment '{}'", self.env_name)
37    }
38}
39
/// DAG task that submits a Ray job via the Ray Jobs API (`ray job submit`).
///
/// Derives `Debug` and `Clone` so a task can be logged and duplicated when a
/// DAG is assembled.
#[derive(Debug, Clone)]
pub struct RaySubmitTask {
    /// Job name; used as the suffix of the `ray-submit:<job_name>` output key
    /// and to look up a `build-ray-env:<job_name>` working directory.
    pub job_name: String,
    /// Entrypoint passed to `ray job submit` after the `--` separator.
    pub entrypoint: String,
    /// Host part of the `http://<head_address>:<dashboard_port>` address.
    pub head_address: String,
    /// Port part of the `http://<head_address>:<dashboard_port>` address.
    pub dashboard_port: u16,
    /// Explicit `--working-dir`; when `None`, a previously built environment
    /// is used if one is available in the DAG context.
    pub working_dir: Option<String>,
}
48
49impl DagTask for RaySubmitTask {
50    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
51        let address = format!("http://{}:{}", self.head_address, self.dashboard_port);
52
53        let mut cmd = Command::new("ray");
54        cmd.args(["job", "submit", "--address", &address]);
55
56        // Use nix-built working dir if available
57        if let Some(ref dir) = self.working_dir {
58            cmd.args(["--working-dir", dir]);
59        } else if let Some(env_path) =
60            ctx.get_output::<String>(&TaskId(format!("build-ray-env:{}", self.job_name)))
61        {
62            cmd.args(["--working-dir", &env_path]);
63        }
64
65        cmd.arg("--").arg(&self.entrypoint);
66
67        let output = match cmd.output() {
68            Ok(o) => o,
69            Err(e) => return TaskOutcome::Failed(format!("ray submit failed: {}", e)),
70        };
71
72        if output.status.success() {
73            let stdout = String::from_utf8_lossy(&output.stdout);
74            // Parse job ID from output
75            let job_id = stdout
76                .lines()
77                .find(|l| l.contains("raysubmit_"))
78                .unwrap_or("unknown")
79                .trim()
80                .to_string();
81
82            ctx.set_output(TaskId(format!("ray-submit:{}", self.job_name)), job_id);
83            TaskOutcome::Success
84        } else {
85            let stderr = String::from_utf8_lossy(&output.stderr);
86            TaskOutcome::Failed(format!("ray submit failed: {}", stderr.trim()))
87        }
88    }
89
90    fn describe(&self) -> String {
91        format!("submit ray job '{}'", self.job_name)
92    }
93}
94
/// DAG task that waits for a Ray job to complete by polling the Jobs API
/// (`ray job status`).
///
/// Derives `Debug` and `Clone` so a task can be logged and duplicated when a
/// DAG is assembled.
#[derive(Debug, Clone)]
pub struct RayWaitTask {
    /// Job name; the submission ID is read from `ray-submit:<job_name>` and
    /// the result is published under `ray-wait:<job_name>`.
    pub job_name: String,
    /// Host part of the `http://<head_address>:<dashboard_port>` address.
    pub head_address: String,
    /// Port part of the `http://<head_address>:<dashboard_port>` address.
    pub dashboard_port: u16,
    /// Pause between two consecutive status queries.
    pub poll_interval: Duration,
    /// Maximum total wait; `None` means poll without a deadline.
    pub timeout: Option<Duration>,
}
103
104impl DagTask for RayWaitTask {
105    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
106        let job_id: String = match ctx.get_output(&TaskId(format!("ray-submit:{}", self.job_name)))
107        {
108            Some(id) => id,
109            None => return TaskOutcome::Failed("no ray job ID from submit".into()),
110        };
111
112        let address = format!("http://{}:{}", self.head_address, self.dashboard_port);
113        let start = Instant::now();
114
115        loop {
116            if let Some(timeout) = self.timeout {
117                if start.elapsed() > timeout {
118                    return TaskOutcome::Failed(format!("ray job {} timed out", job_id));
119                }
120            }
121
122            let output = Command::new("ray")
123                .args(["job", "status", "--address", &address, &job_id])
124                .output();
125
126            match output {
127                Ok(o) if o.status.success() => {
128                    let stdout = String::from_utf8_lossy(&o.stdout).to_string();
129                    if stdout.contains("SUCCEEDED") {
130                        ctx.set_output(
131                            TaskId(format!("ray-wait:{}", self.job_name)),
132                            job_id.clone(),
133                        );
134                        return TaskOutcome::Success;
135                    } else if stdout.contains("FAILED") || stdout.contains("STOPPED") {
136                        return TaskOutcome::Failed(format!(
137                            "ray job {} ended: {}",
138                            job_id,
139                            stdout.trim()
140                        ));
141                    }
142                }
143                _ => {}
144            }
145
146            std::thread::sleep(self.poll_interval);
147        }
148    }
149
150    fn describe(&self) -> String {
151        format!("wait for ray job '{}'", self.job_name)
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Picks the first line mentioning `raysubmit_` (trimmed), or `"unknown"`
    /// when no line does.
    fn scan_for_job_id(stdout: &str) -> String {
        match stdout.lines().find(|line| line.contains("raysubmit_")) {
            Some(line) => line.trim().to_string(),
            None => "unknown".to_string(),
        }
    }

    #[test]
    fn test_ray_submit_job_id_parsing() {
        // The ID sits on its own line between two unrelated lines.
        let stdout = "Job submitted successfully\nraysubmit_abc123def\nDone.";
        assert_eq!(scan_for_job_id(stdout), "raysubmit_abc123def");
    }

    #[test]
    fn test_ray_submit_no_job_id() {
        let stdout = "Some error output";
        assert_eq!(scan_for_job_id(stdout), "unknown");
    }

    #[test]
    fn test_ray_job_status_detection() {
        // Each status line must contain its own keyword...
        let matching = [
            ("Status: SUCCEEDED", "SUCCEEDED"),
            ("Status: FAILED", "FAILED"),
            ("Status: STOPPED", "STOPPED"),
        ];
        for (line, keyword) in matching {
            assert!(line.contains(keyword));
        }

        // ...and a running job must not look terminal.
        for keyword in ["SUCCEEDED", "FAILED"] {
            assert!(!"Status: RUNNING".contains(keyword));
        }
    }

    #[test]
    fn test_describe_methods() {
        let build_task = NixBuildRayEnvTask::new("train", ".");
        let build_label = build_task.describe();
        assert!(build_label.contains("train"));

        let submit_task = RaySubmitTask {
            job_name: "train".to_string(),
            entrypoint: "python train.py".to_string(),
            head_address: "localhost".to_string(),
            dashboard_port: 8265,
            working_dir: None,
        };
        let submit_label = submit_task.describe();
        assert!(submit_label.contains("train"));
    }
}