1use std::process::Command;
4use std::time::{Duration, Instant};
5
6use consortium::dag::{DagContext, DagTask, TaskId, TaskOutcome};
7use consortium_nix::build;
8
/// DAG task that builds a Ray environment from a Nix flake attribute.
///
/// Usually constructed with [`NixBuildRayEnvTask::new`], which derives
/// `flake_attr` from a flake URI and the environment name.
#[derive(Debug, Clone)]
pub struct NixBuildRayEnvTask {
    /// Name of the Ray environment; also embedded in this task's output key
    /// (`build-ray-env:<env_name>`).
    pub env_name: String,
    /// Flake attribute handed to the Nix build, e.g. `.#rayEnvs.train`.
    pub flake_attr: String,
}

impl NixBuildRayEnvTask {
    /// Creates a task that builds `<flake_uri>#rayEnvs.<env_name>`.
    ///
    /// For example, `NixBuildRayEnvTask::new("train", ".")` targets the flake
    /// attribute `.#rayEnvs.train`.
    pub fn new(env_name: &str, flake_uri: &str) -> Self {
        Self {
            env_name: env_name.to_string(),
            flake_attr: format!("{}#rayEnvs.{}", flake_uri, env_name),
        }
    }
}
23
24impl DagTask for NixBuildRayEnvTask {
25 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
26 match build::build_flake_attr(&self.flake_attr, None) {
27 Ok(path) => {
28 ctx.set_output(TaskId(format!("build-ray-env:{}", self.env_name)), path);
29 TaskOutcome::Success
30 }
31 Err(e) => TaskOutcome::Failed(format!("build ray env: {}", e)),
32 }
33 }
34
35 fn describe(&self) -> String {
36 format!("build ray environment '{}'", self.env_name)
37 }
38}
39
/// DAG task that submits a job to a Ray cluster with `ray job submit`.
#[derive(Debug, Clone)]
pub struct RaySubmitTask {
    /// Logical job name; keys this task's output (`ray-submit:<job_name>`).
    pub job_name: String,
    /// Command line the job runs on the cluster, e.g. `python train.py`.
    pub entrypoint: String,
    /// Host of the Ray head node; combined with `dashboard_port` into the
    /// `http://<host>:<port>` address passed to `--address`.
    pub head_address: String,
    /// Port combined with `head_address` into the `--address` URL.
    pub dashboard_port: u16,
    /// Explicit `--working-dir`. When `None`, a `String` output of the
    /// `build-ray-env:<job_name>` task is used instead, if one exists.
    pub working_dir: Option<String>,
}
48
49impl DagTask for RaySubmitTask {
50 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
51 let address = format!("http://{}:{}", self.head_address, self.dashboard_port);
52
53 let mut cmd = Command::new("ray");
54 cmd.args(["job", "submit", "--address", &address]);
55
56 if let Some(ref dir) = self.working_dir {
58 cmd.args(["--working-dir", dir]);
59 } else if let Some(env_path) =
60 ctx.get_output::<String>(&TaskId(format!("build-ray-env:{}", self.job_name)))
61 {
62 cmd.args(["--working-dir", &env_path]);
63 }
64
65 cmd.arg("--").arg(&self.entrypoint);
66
67 let output = match cmd.output() {
68 Ok(o) => o,
69 Err(e) => return TaskOutcome::Failed(format!("ray submit failed: {}", e)),
70 };
71
72 if output.status.success() {
73 let stdout = String::from_utf8_lossy(&output.stdout);
74 let job_id = stdout
76 .lines()
77 .find(|l| l.contains("raysubmit_"))
78 .unwrap_or("unknown")
79 .trim()
80 .to_string();
81
82 ctx.set_output(TaskId(format!("ray-submit:{}", self.job_name)), job_id);
83 TaskOutcome::Success
84 } else {
85 let stderr = String::from_utf8_lossy(&output.stderr);
86 TaskOutcome::Failed(format!("ray submit failed: {}", stderr.trim()))
87 }
88 }
89
90 fn describe(&self) -> String {
91 format!("submit ray job '{}'", self.job_name)
92 }
93}
94
/// DAG task that polls `ray job status` until a previously submitted job ends.
#[derive(Debug, Clone)]
pub struct RayWaitTask {
    /// Job name shared with the matching submit task; selects the
    /// `ray-submit:<job_name>` output holding the job ID.
    pub job_name: String,
    /// Host of the Ray head node; combined with `dashboard_port` into the
    /// `http://<host>:<port>` address passed to `--address`.
    pub head_address: String,
    /// Port combined with `head_address` into the `--address` URL.
    pub dashboard_port: u16,
    /// Delay between consecutive status polls.
    pub poll_interval: Duration,
    /// Maximum total time to wait; `None` waits indefinitely.
    pub timeout: Option<Duration>,
}
103
104impl DagTask for RayWaitTask {
105 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
106 let job_id: String = match ctx.get_output(&TaskId(format!("ray-submit:{}", self.job_name)))
107 {
108 Some(id) => id,
109 None => return TaskOutcome::Failed("no ray job ID from submit".into()),
110 };
111
112 let address = format!("http://{}:{}", self.head_address, self.dashboard_port);
113 let start = Instant::now();
114
115 loop {
116 if let Some(timeout) = self.timeout {
117 if start.elapsed() > timeout {
118 return TaskOutcome::Failed(format!("ray job {} timed out", job_id));
119 }
120 }
121
122 let output = Command::new("ray")
123 .args(["job", "status", "--address", &address, &job_id])
124 .output();
125
126 match output {
127 Ok(o) if o.status.success() => {
128 let stdout = String::from_utf8_lossy(&o.stdout).to_string();
129 if stdout.contains("SUCCEEDED") {
130 ctx.set_output(
131 TaskId(format!("ray-wait:{}", self.job_name)),
132 job_id.clone(),
133 );
134 return TaskOutcome::Success;
135 } else if stdout.contains("FAILED") || stdout.contains("STOPPED") {
136 return TaskOutcome::Failed(format!(
137 "ray job {} ended: {}",
138 job_id,
139 stdout.trim()
140 ));
141 }
142 }
143 _ => {}
144 }
145
146 std::thread::sleep(self.poll_interval);
147 }
148 }
149
150 fn describe(&self) -> String {
151 format!("wait for ray job '{}'", self.job_name)
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Line-based job-ID extraction exercised by the parsing tests below: the
    /// first line containing `raysubmit_`, trimmed, or `"unknown"` if none does.
    fn first_submit_line(output: &str) -> String {
        let found = output.lines().find(|line| line.contains("raysubmit_"));
        match found {
            Some(line) => line.trim().to_owned(),
            None => String::from("unknown"),
        }
    }

    #[test]
    fn test_ray_submit_job_id_parsing() {
        let output = "Job submitted successfully\nraysubmit_abc123def\nDone.";
        assert_eq!(first_submit_line(output), "raysubmit_abc123def");
    }

    #[test]
    fn test_ray_submit_no_job_id() {
        assert_eq!(first_submit_line("Some error output"), "unknown");
    }

    #[test]
    fn test_ray_job_status_detection() {
        let terminal = [
            ("Status: SUCCEEDED", "SUCCEEDED"),
            ("Status: FAILED", "FAILED"),
            ("Status: STOPPED", "STOPPED"),
        ];
        for &(line, marker) in terminal.iter() {
            assert!(line.contains(marker));
        }

        let running = "Status: RUNNING";
        assert!(!running.contains("SUCCEEDED"));
        assert!(!running.contains("FAILED"));
    }

    #[test]
    fn test_describe_methods() {
        let build_task = NixBuildRayEnvTask::new("train", ".");
        assert!(build_task.describe().contains("train"));

        let submit_task = RaySubmitTask {
            job_name: String::from("train"),
            entrypoint: String::from("python train.py"),
            head_address: String::from("localhost"),
            dashboard_port: 8265,
            working_dir: None,
        };
        assert!(submit_task.describe().contains("train"));
    }
}