1use std::process::Command;
6use std::time::{Duration, Instant};
7
8use consortium::dag::{DagContext, DagTask, TaskId, TaskOutcome};
9use consortium_nix::{build, copy};
10
/// DAG task that builds the Nix environment for a Slurm job.
///
/// When executed, the task passes `flake_attr` to `build::build_flake_attr` and, on
/// success, stores the returned value under the task ID `build-job-env:<job_name>`.
#[derive(Debug, Clone)]
pub struct NixBuildJobEnvTask {
    /// Job name; used in the output task ID and in the task description.
    pub job_name: String,
    /// Flake attribute that is handed to the Nix build.
    pub flake_attr: String,
}
16
17impl NixBuildJobEnvTask {
18 pub fn new(job_name: &str, flake_uri: &str) -> Self {
19 Self {
20 job_name: job_name.to_string(),
21 flake_attr: format!("{}#slurmEnvs.{}", flake_uri, job_name),
22 }
23 }
24}
25
26impl DagTask for NixBuildJobEnvTask {
27 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
28 match build::build_flake_attr(&self.flake_attr, None) {
29 Ok(path) => {
30 ctx.set_output(TaskId(format!("build-job-env:{}", self.job_name)), path);
31 TaskOutcome::Success
32 }
33 Err(e) => TaskOutcome::Failed(format!("build job env: {}", e)),
34 }
35 }
36
37 fn describe(&self) -> String {
38 format!("build slurm job env '{}'", self.job_name)
39 }
40}
41
/// DAG task that copies a previously built job environment to the submit host.
///
/// When executed, the task reads the output stored under `build-job-env:<job_name>`
/// and copies it to the store URI `ssh-ng://<submit_user>@<submit_host>`.
#[derive(Debug, Clone)]
pub struct NixCopyToSubmitTask {
    /// Job name; selects the build output to copy and names this task's own output.
    pub job_name: String,
    /// Host part of the `ssh-ng://` destination.
    pub submit_host: String,
    /// User part of the `ssh-ng://` destination.
    pub submit_user: String,
}
48
49impl DagTask for NixCopyToSubmitTask {
50 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
51 let store_path: String =
52 match ctx.get_output(&TaskId(format!("build-job-env:{}", self.job_name))) {
53 Some(p) => p,
54 None => return TaskOutcome::Failed("no job env build output".into()),
55 };
56
57 let store_uri = format!("ssh-ng://{}@{}", self.submit_user, self.submit_host);
58 match copy::copy_closure(&store_path, &store_uri) {
59 Ok(()) => {
60 ctx.set_output(
61 TaskId(format!("copy-job-env:{}", self.job_name)),
62 store_path,
63 );
64 TaskOutcome::Success
65 }
66 Err(e) => TaskOutcome::Failed(format!("copy job env: {}", e)),
67 }
68 }
69
70 fn describe(&self) -> String {
71 format!("copy job env '{}' to {}", self.job_name, self.submit_host)
72 }
73}
74
/// DAG task that submits a batch script with `sbatch` on a remote host over `ssh`.
#[derive(Debug, Clone)]
pub struct SlurmSubmitTask {
    /// Job name; passed as `--job-name` and used in the task IDs read and written.
    pub job_name: String,
    /// Script given to `sbatch` as the final part of the command line.
    pub script: String,
    /// Optional partition; passed as `--partition` when present.
    pub partition: Option<String>,
    /// Host that `ssh` connects to.
    pub submit_host: String,
    /// Login name passed to `ssh -l`.
    pub submit_user: String,
}
83
84impl DagTask for SlurmSubmitTask {
85 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
86 let mut sbatch_args = vec!["sbatch".to_string()];
88 sbatch_args.push(format!("--job-name={}", self.job_name));
89
90 if let Some(ref partition) = self.partition {
91 sbatch_args.push(format!("--partition={}", partition));
92 }
93
94 if let Some(env_path) =
96 ctx.get_output::<String>(&TaskId(format!("copy-job-env:{}", self.job_name)))
97 {
98 sbatch_args.push(format!("--export=ALL,PATH={}/bin:$PATH", env_path));
99 }
100
101 sbatch_args.push(self.script.clone());
102
103 let ssh_cmd = sbatch_args.join(" ");
104 let output = Command::new("ssh")
105 .args([
106 "-oStrictHostKeyChecking=no",
107 "-oPasswordAuthentication=no",
108 "-l",
109 &self.submit_user,
110 &self.submit_host,
111 &ssh_cmd,
112 ])
113 .output();
114
115 match output {
116 Ok(o) if o.status.success() => {
117 let stdout = String::from_utf8_lossy(&o.stdout);
118 let job_id: u64 = stdout
120 .trim()
121 .rsplit_once(' ')
122 .and_then(|(_, id)| id.parse().ok())
123 .unwrap_or(0);
124
125 ctx.set_output(TaskId(format!("slurm-submit:{}", self.job_name)), job_id);
126 TaskOutcome::Success
127 }
128 Ok(o) => {
129 let stderr = String::from_utf8_lossy(&o.stderr);
130 TaskOutcome::Failed(format!("sbatch failed: {}", stderr.trim()))
131 }
132 Err(e) => TaskOutcome::Failed(format!("sbatch exec failed: {}", e)),
133 }
134 }
135
136 fn describe(&self) -> String {
137 format!("submit slurm job '{}'", self.job_name)
138 }
139}
140
/// DAG task that polls `sacct` on a remote host until a submitted job finishes.
#[derive(Debug, Clone)]
pub struct SlurmWaitTask {
    /// Job name; selects the `slurm-submit:<job_name>` output that holds the job ID.
    pub job_name: String,
    /// Host that `ssh` connects to.
    pub submit_host: String,
    /// Login name passed to `ssh -l`.
    pub submit_user: String,
    /// Pause between two consecutive polls.
    pub poll_interval: Duration,
    /// Maximum total wait; `None` means the task polls without a time limit.
    pub timeout: Option<Duration>,
}
149
150impl SlurmWaitTask {
151 pub fn new(job_name: &str, submit_host: &str, submit_user: &str) -> Self {
152 Self {
153 job_name: job_name.to_string(),
154 submit_host: submit_host.to_string(),
155 submit_user: submit_user.to_string(),
156 poll_interval: Duration::from_secs(10),
157 timeout: None,
158 }
159 }
160}
161
162impl DagTask for SlurmWaitTask {
163 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
164 let job_id: u64 = match ctx.get_output(&TaskId(format!("slurm-submit:{}", self.job_name))) {
165 Some(id) => id,
166 None => return TaskOutcome::Failed("no job ID from submit".into()),
167 };
168
169 let start = Instant::now();
170
171 loop {
172 if let Some(timeout) = self.timeout {
174 if start.elapsed() > timeout {
175 return TaskOutcome::Failed(format!(
176 "job {} timed out after {}s",
177 job_id,
178 start.elapsed().as_secs()
179 ));
180 }
181 }
182
183 let sacct_cmd = format!(
185 "sacct -j {} --format=State --noheader --parsable2 | head -1",
186 job_id
187 );
188 let output = Command::new("ssh")
189 .args([
190 "-oStrictHostKeyChecking=no",
191 "-oPasswordAuthentication=no",
192 "-l",
193 &self.submit_user,
194 &self.submit_host,
195 &sacct_cmd,
196 ])
197 .output();
198
199 match output {
200 Ok(o) if o.status.success() => {
201 let state = String::from_utf8_lossy(&o.stdout).trim().to_string();
202 match state.as_str() {
203 "COMPLETED" => {
204 ctx.set_output(TaskId(format!("slurm-wait:{}", self.job_name)), job_id);
205 return TaskOutcome::Success;
206 }
207 "FAILED" | "CANCELLED" | "TIMEOUT" | "OUT_OF_MEMORY" | "NODE_FAIL" => {
208 return TaskOutcome::Failed(format!(
209 "job {} ended with state: {}",
210 job_id, state
211 ));
212 }
213 _ => {}
215 }
216 }
217 _ => {} }
219
220 std::thread::sleep(self.poll_interval);
221 }
222 }
223
224 fn describe(&self) -> String {
225 format!("wait for slurm job '{}'", self.job_name)
226 }
227}
228
/// DAG task that fetches job output from the submit host by running `cat` over `ssh`.
#[derive(Debug, Clone)]
pub struct SlurmCollectTask {
    /// Job name; selects the `slurm-wait:<job_name>` output that gates collection.
    pub job_name: String,
    /// Path or pattern placed after `cat` in the remote command.
    pub output_pattern: String,
    /// Host that `ssh` connects to.
    pub submit_host: String,
    /// Login name passed to `ssh -l`.
    pub submit_user: String,
}
236
237impl DagTask for SlurmCollectTask {
238 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
239 let _job_id: u64 = match ctx.get_output(&TaskId(format!("slurm-wait:{}", self.job_name))) {
240 Some(id) => id,
241 None => return TaskOutcome::Failed("job not completed".into()),
242 };
243
244 let output = Command::new("ssh")
246 .args([
247 "-oStrictHostKeyChecking=no",
248 "-l",
249 &self.submit_user,
250 &self.submit_host,
251 &format!("cat {}", self.output_pattern),
252 ])
253 .output();
254
255 match output {
256 Ok(o) if o.status.success() => {
257 let content = String::from_utf8_lossy(&o.stdout).to_string();
258 ctx.set_output(TaskId(format!("slurm-collect:{}", self.job_name)), content);
259 TaskOutcome::Success
260 }
261 Ok(o) => {
262 let stderr = String::from_utf8_lossy(&o.stderr);
263 TaskOutcome::Failed(format!("collect failed: {}", stderr.trim()))
264 }
265 Err(e) => TaskOutcome::Failed(format!("collect failed: {}", e)),
266 }
267 }
268
269 fn describe(&self) -> String {
270 format!("collect results for '{}'", self.job_name)
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Job-ID extraction as exercised by these tests: the last space-separated token
    /// of the trimmed text parsed as `u64`, or 0 when that is not possible.
    fn parse_job_id(output: &str) -> u64 {
        match output.trim().rsplit_once(' ') {
            Some((_, id)) => id.parse().unwrap_or(0),
            None => 0,
        }
    }

    /// True for the states that are expected to end polling with a failure.
    fn is_terminal_failure(state: &str) -> bool {
        matches!(
            state,
            "FAILED" | "CANCELLED" | "TIMEOUT" | "OUT_OF_MEMORY" | "NODE_FAIL"
        )
    }

    #[test]
    fn test_sbatch_output_parsing() {
        assert_eq!(parse_job_id("Submitted batch job 12345"), 12345);
    }

    #[test]
    fn test_sbatch_output_parsing_with_whitespace() {
        assert_eq!(parse_job_id("Submitted batch job 99999\n"), 99999);
    }

    #[test]
    fn test_sbatch_output_parsing_unexpected() {
        assert_eq!(parse_job_id("Error: something went wrong"), 0);
    }

    #[test]
    fn test_slurm_job_states() {
        let terminal_failure = [
            "FAILED",
            "CANCELLED",
            "TIMEOUT",
            "OUT_OF_MEMORY",
            "NODE_FAIL",
        ];
        let running = ["PENDING", "RUNNING", "COMPLETING"];

        for state in terminal_failure.iter().copied() {
            assert!(
                is_terminal_failure(state),
                "{} should be terminal failure",
                state
            );
        }

        for state in running.iter().copied() {
            let ends_polling = state == "COMPLETED" || is_terminal_failure(state);
            assert!(!ends_polling, "{} should continue polling", state);
        }
    }

    #[test]
    fn test_describe_methods() {
        let build = NixBuildJobEnvTask::new("rnaseq", ".");
        assert!(build.describe().contains("rnaseq"));

        let submit = SlurmSubmitTask {
            job_name: String::from("test"),
            script: String::from("test.sh"),
            partition: Some(String::from("gpu")),
            submit_host: String::from("ctrl"),
            submit_user: String::from("root"),
        };
        assert!(submit.describe().contains("test"));

        let wait = SlurmWaitTask::new("myjob", "ctrl", "root");
        assert!(wait.describe().contains("myjob"));
        assert_eq!(wait.poll_interval, Duration::from_secs(10));
    }
}