// consortium_slurm/tasks.rs

1//! DagTask implementations for Slurm job orchestration.
2//!
3//! Pipeline: build-job-env → copy-to-submit → submit → wait → collect
4
5use std::process::Command;
6use std::time::{Duration, Instant};
7
8use consortium::dag::{DagContext, DagTask, TaskId, TaskOutcome};
9use consortium_nix::{build, copy};
10
/// Build a hermetic job environment via nix.
pub struct NixBuildJobEnvTask {
    /// Job name; selects the `slurmEnvs.<job_name>` flake output and keys this task's DAG output.
    pub job_name: String,
    /// Fully-qualified flake attribute, `<flake_uri>#slurmEnvs.<job_name>`.
    pub flake_attr: String,
}

impl NixBuildJobEnvTask {
    /// Create a build task for `job_name`, targeting `slurmEnvs.<job_name>` in the flake at `flake_uri`.
    pub fn new(job_name: &str, flake_uri: &str) -> Self {
        let flake_attr = [flake_uri, "#slurmEnvs.", job_name].concat();
        Self {
            job_name: job_name.to_owned(),
            flake_attr,
        }
    }
}
25
26impl DagTask for NixBuildJobEnvTask {
27    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
28        match build::build_flake_attr(&self.flake_attr, None) {
29            Ok(path) => {
30                ctx.set_output(TaskId(format!("build-job-env:{}", self.job_name)), path);
31                TaskOutcome::Success
32            }
33            Err(e) => TaskOutcome::Failed(format!("build job env: {}", e)),
34        }
35    }
36
37    fn describe(&self) -> String {
38        format!("build slurm job env '{}'", self.job_name)
39    }
40}
41
/// Copy the job environment to the submit node.
pub struct NixCopyToSubmitTask {
    /// Job whose `build-job-env:<job_name>` output is copied.
    pub job_name: String,
    /// Submit node hostname, used in the `ssh-ng://` store URI.
    pub submit_host: String,
    /// SSH user on the submit node.
    pub submit_user: String,
}
48
49impl DagTask for NixCopyToSubmitTask {
50    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
51        let store_path: String =
52            match ctx.get_output(&TaskId(format!("build-job-env:{}", self.job_name))) {
53                Some(p) => p,
54                None => return TaskOutcome::Failed("no job env build output".into()),
55            };
56
57        let store_uri = format!("ssh-ng://{}@{}", self.submit_user, self.submit_host);
58        match copy::copy_closure(&store_path, &store_uri) {
59            Ok(()) => {
60                ctx.set_output(
61                    TaskId(format!("copy-job-env:{}", self.job_name)),
62                    store_path,
63                );
64                TaskOutcome::Success
65            }
66            Err(e) => TaskOutcome::Failed(format!("copy job env: {}", e)),
67        }
68    }
69
70    fn describe(&self) -> String {
71        format!("copy job env '{}' to {}", self.job_name, self.submit_host)
72    }
73}
74
/// Submit a slurm job via sbatch.
pub struct SlurmSubmitTask {
    /// Passed as `--job-name`; also keys the `copy-job-env:`/`slurm-submit:` outputs.
    pub job_name: String,
    /// Appended verbatim as the final `sbatch` argument.
    pub script: String,
    /// Passed as `--partition` when set.
    pub partition: Option<String>,
    /// Host on which `sbatch` is run over SSH.
    pub submit_host: String,
    /// SSH login user on the submit host.
    pub submit_user: String,
}
83
84impl DagTask for SlurmSubmitTask {
85    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
86        // Build sbatch command
87        let mut sbatch_args = vec!["sbatch".to_string()];
88        sbatch_args.push(format!("--job-name={}", self.job_name));
89
90        if let Some(ref partition) = self.partition {
91            sbatch_args.push(format!("--partition={}", partition));
92        }
93
94        // If we have a nix env, set PATH in the job
95        if let Some(env_path) =
96            ctx.get_output::<String>(&TaskId(format!("copy-job-env:{}", self.job_name)))
97        {
98            sbatch_args.push(format!("--export=ALL,PATH={}/bin:$PATH", env_path));
99        }
100
101        sbatch_args.push(self.script.clone());
102
103        let ssh_cmd = sbatch_args.join(" ");
104        let output = Command::new("ssh")
105            .args([
106                "-oStrictHostKeyChecking=no",
107                "-oPasswordAuthentication=no",
108                "-l",
109                &self.submit_user,
110                &self.submit_host,
111                &ssh_cmd,
112            ])
113            .output();
114
115        match output {
116            Ok(o) if o.status.success() => {
117                let stdout = String::from_utf8_lossy(&o.stdout);
118                // sbatch output: "Submitted batch job 12345"
119                let job_id: u64 = stdout
120                    .trim()
121                    .rsplit_once(' ')
122                    .and_then(|(_, id)| id.parse().ok())
123                    .unwrap_or(0);
124
125                ctx.set_output(TaskId(format!("slurm-submit:{}", self.job_name)), job_id);
126                TaskOutcome::Success
127            }
128            Ok(o) => {
129                let stderr = String::from_utf8_lossy(&o.stderr);
130                TaskOutcome::Failed(format!("sbatch failed: {}", stderr.trim()))
131            }
132            Err(e) => TaskOutcome::Failed(format!("sbatch exec failed: {}", e)),
133        }
134    }
135
136    fn describe(&self) -> String {
137        format!("submit slurm job '{}'", self.job_name)
138    }
139}
140
/// Wait for a slurm job to complete by polling sacct.
pub struct SlurmWaitTask {
    /// Job whose `slurm-submit:<job_name>` ID is polled.
    pub job_name: String,
    /// Host on which `sacct` is run over SSH.
    pub submit_host: String,
    /// SSH login user on the submit host.
    pub submit_user: String,
    /// Delay between consecutive `sacct` polls.
    pub poll_interval: Duration,
    /// Overall limit on waiting; `None` waits indefinitely.
    pub timeout: Option<Duration>,
}

impl SlurmWaitTask {
    /// Create a wait task that polls every 10 seconds with no overall timeout.
    pub fn new(job_name: &str, submit_host: &str, submit_user: &str) -> Self {
        const DEFAULT_POLL_SECS: u64 = 10;
        Self {
            job_name: String::from(job_name),
            submit_host: String::from(submit_host),
            submit_user: String::from(submit_user),
            poll_interval: Duration::from_secs(DEFAULT_POLL_SECS),
            timeout: None,
        }
    }
}
161
162impl DagTask for SlurmWaitTask {
163    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
164        let job_id: u64 = match ctx.get_output(&TaskId(format!("slurm-submit:{}", self.job_name))) {
165            Some(id) => id,
166            None => return TaskOutcome::Failed("no job ID from submit".into()),
167        };
168
169        let start = Instant::now();
170
171        loop {
172            // Check timeout
173            if let Some(timeout) = self.timeout {
174                if start.elapsed() > timeout {
175                    return TaskOutcome::Failed(format!(
176                        "job {} timed out after {}s",
177                        job_id,
178                        start.elapsed().as_secs()
179                    ));
180                }
181            }
182
183            // Poll sacct for job state
184            let sacct_cmd = format!(
185                "sacct -j {} --format=State --noheader --parsable2 | head -1",
186                job_id
187            );
188            let output = Command::new("ssh")
189                .args([
190                    "-oStrictHostKeyChecking=no",
191                    "-oPasswordAuthentication=no",
192                    "-l",
193                    &self.submit_user,
194                    &self.submit_host,
195                    &sacct_cmd,
196                ])
197                .output();
198
199            match output {
200                Ok(o) if o.status.success() => {
201                    let state = String::from_utf8_lossy(&o.stdout).trim().to_string();
202                    match state.as_str() {
203                        "COMPLETED" => {
204                            ctx.set_output(TaskId(format!("slurm-wait:{}", self.job_name)), job_id);
205                            return TaskOutcome::Success;
206                        }
207                        "FAILED" | "CANCELLED" | "TIMEOUT" | "OUT_OF_MEMORY" | "NODE_FAIL" => {
208                            return TaskOutcome::Failed(format!(
209                                "job {} ended with state: {}",
210                                job_id, state
211                            ));
212                        }
213                        // PENDING, RUNNING, etc. — keep polling
214                        _ => {}
215                    }
216                }
217                _ => {} // SSH error — retry next poll
218            }
219
220            std::thread::sleep(self.poll_interval);
221        }
222    }
223
224    fn describe(&self) -> String {
225        format!("wait for slurm job '{}'", self.job_name)
226    }
227}
228
/// Collect results from a completed slurm job.
pub struct SlurmCollectTask {
    /// Job whose `slurm-wait:<job_name>` output gates collection.
    pub job_name: String,
    /// Passed unquoted to a remote `cat`, so shell globs expand on the submit host.
    pub output_pattern: String,
    /// Host from which output files are read over SSH.
    pub submit_host: String,
    /// SSH login user on the submit host.
    pub submit_user: String,
}
236
237impl DagTask for SlurmCollectTask {
238    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
239        let _job_id: u64 = match ctx.get_output(&TaskId(format!("slurm-wait:{}", self.job_name))) {
240            Some(id) => id,
241            None => return TaskOutcome::Failed("job not completed".into()),
242        };
243
244        // Collect output files via SCP or SSH cat
245        let output = Command::new("ssh")
246            .args([
247                "-oStrictHostKeyChecking=no",
248                "-l",
249                &self.submit_user,
250                &self.submit_host,
251                &format!("cat {}", self.output_pattern),
252            ])
253            .output();
254
255        match output {
256            Ok(o) if o.status.success() => {
257                let content = String::from_utf8_lossy(&o.stdout).to_string();
258                ctx.set_output(TaskId(format!("slurm-collect:{}", self.job_name)), content);
259                TaskOutcome::Success
260            }
261            Ok(o) => {
262                let stderr = String::from_utf8_lossy(&o.stderr);
263                TaskOutcome::Failed(format!("collect failed: {}", stderr.trim()))
264            }
265            Err(e) => TaskOutcome::Failed(format!("collect failed: {}", e)),
266        }
267    }
268
269    fn describe(&self) -> String {
270        format!("collect results for '{}'", self.job_name)
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse an sbatch reply by taking its last space-separated word as the job
    /// ID, yielding 0 when there is no space or the word is not a `u64`.
    fn last_word_as_job_id(reply: &str) -> u64 {
        match reply.trim().rsplit_once(' ') {
            Some((_, word)) => word.parse().unwrap_or(0),
            None => 0,
        }
    }

    #[test]
    fn test_sbatch_output_parsing() {
        // sbatch outputs: "Submitted batch job 12345"
        assert_eq!(last_word_as_job_id("Submitted batch job 12345"), 12345);
    }

    #[test]
    fn test_sbatch_output_parsing_with_whitespace() {
        assert_eq!(last_word_as_job_id("Submitted batch job 99999\n"), 99999);
    }

    #[test]
    fn test_sbatch_output_parsing_unexpected() {
        // Unparseable replies fall back to 0.
        assert_eq!(last_word_as_job_id("Error: something went wrong"), 0);
    }

    #[test]
    fn test_slurm_job_states() {
        // NOTE(review): this checks local lists against literal patterns only; it
        // does not exercise SlurmWaitTask itself.
        let terminal_failure = ["FAILED", "CANCELLED", "TIMEOUT", "OUT_OF_MEMORY", "NODE_FAIL"];
        let running = ["PENDING", "RUNNING", "COMPLETING"];

        for &state in terminal_failure.iter() {
            let is_failure = matches!(
                state,
                "FAILED" | "CANCELLED" | "TIMEOUT" | "OUT_OF_MEMORY" | "NODE_FAIL"
            );
            assert!(is_failure, "{} should be terminal failure", state);
        }

        for &state in running.iter() {
            let is_terminal = matches!(
                state,
                "COMPLETED" | "FAILED" | "CANCELLED" | "TIMEOUT" | "OUT_OF_MEMORY" | "NODE_FAIL"
            );
            assert!(!is_terminal, "{} should continue polling", state);
        }
    }

    #[test]
    fn test_describe_methods() {
        let build_task = NixBuildJobEnvTask::new("rnaseq", ".");
        assert!(build_task.describe().contains("rnaseq"));

        let submit_task = SlurmSubmitTask {
            job_name: String::from("test"),
            script: String::from("test.sh"),
            partition: Some(String::from("gpu")),
            submit_host: String::from("ctrl"),
            submit_user: String::from("root"),
        };
        assert!(submit_task.describe().contains("test"));

        let wait_task = SlurmWaitTask::new("myjob", "ctrl", "root");
        assert!(wait_task.describe().contains("myjob"));
        assert_eq!(wait_task.poll_interval, Duration::from_secs(10));
    }
}
370}