Skip to main content

consortium_skypilot/
tasks.rs

1//! DagTask implementations for SkyPilot orchestration.
2
3use std::process::Command;
4
5use consortium::dag::{DagContext, DagTask, TaskId, TaskOutcome};
6use consortium_nix::build;
7
/// Build a SkyPilot task environment via Nix.
///
/// The environment is addressed as the flake attribute
/// `<flake_uri>#skyEnvs.<env_name>`; construct it with [`NixBuildSkyEnvTask::new`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NixBuildSkyEnvTask {
    /// Environment name, i.e. the last component of the `skyEnvs.<name>` attribute path.
    pub env_name: String,
    /// Flake reference of the form `<flake_uri>#skyEnvs.<env_name>`,
    /// e.g. `github:user/repo#skyEnvs.train-gpt`.
    pub flake_attr: String,
}

impl NixBuildSkyEnvTask {
    /// Creates a task that builds `<flake_uri>#skyEnvs.<env_name>`.
    ///
    /// `flake_uri` is used verbatim (for example `.` or `github:user/repo`);
    /// no validation or normalisation is performed on either argument.
    #[must_use]
    pub fn new(env_name: &str, flake_uri: &str) -> Self {
        Self {
            env_name: env_name.to_owned(),
            flake_attr: format!("{}#skyEnvs.{}", flake_uri, env_name),
        }
    }
}
22
23impl DagTask for NixBuildSkyEnvTask {
24    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
25        match build::build_flake_attr(&self.flake_attr, None) {
26            Ok(path) => {
27                ctx.set_output(TaskId(format!("build-sky-env:{}", self.env_name)), path);
28                TaskOutcome::Success
29            }
30            Err(e) => TaskOutcome::Failed(format!("build sky env: {}", e)),
31        }
32    }
33
34    fn describe(&self) -> String {
35        format!("build skypilot environment '{}'", self.env_name)
36    }
37}
38
/// Launch a SkyPilot cluster.
///
/// Executing this task runs `sky launch` for `task_yaml` on the cluster named
/// `cluster_name`. `cloud` and `region` are forwarded as `--cloud` / `--region`
/// only when they are `Some`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SkyLaunchTask {
    /// Cluster name passed to `sky launch -c`.
    pub cluster_name: String,
    /// Task entrypoint passed positionally to `sky launch` (typically a task YAML path).
    pub task_yaml: String,
    /// Optional value for `--cloud`; the flag is omitted when `None`.
    pub cloud: Option<String>,
    /// Optional value for `--region`; the flag is omitted when `None`.
    pub region: Option<String>,
}
46
47impl DagTask for SkyLaunchTask {
48    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
49        let mut cmd = Command::new("sky");
50        cmd.args(["launch", "-c", &self.cluster_name, &self.task_yaml, "-y"]);
51
52        if let Some(ref cloud) = self.cloud {
53            cmd.args(["--cloud", cloud]);
54        }
55        if let Some(ref region) = self.region {
56            cmd.args(["--region", region]);
57        }
58
59        let output = match cmd.output() {
60            Ok(o) => o,
61            Err(e) => return TaskOutcome::Failed(format!("sky launch failed: {}", e)),
62        };
63
64        if output.status.success() {
65            ctx.set_output(
66                TaskId(format!("sky-launch:{}", self.cluster_name)),
67                self.cluster_name.clone(),
68            );
69            TaskOutcome::Success
70        } else {
71            let stderr = String::from_utf8_lossy(&output.stderr);
72            TaskOutcome::Failed(format!("sky launch failed: {}", stderr.trim()))
73        }
74    }
75
76    fn describe(&self) -> String {
77        format!("launch sky cluster '{}'", self.cluster_name)
78    }
79}
80
/// Execute a command on a SkyPilot cluster.
///
/// Executing this task runs `sky exec <cluster_name> -- <command>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SkyExecTask {
    /// Name of the existing cluster to run on.
    pub cluster_name: String,
    /// Command line handed to `sky exec` as a single argument after `--`.
    pub command: String,
}
86
87impl DagTask for SkyExecTask {
88    fn execute(&self, ctx: &DagContext) -> TaskOutcome {
89        let output = Command::new("sky")
90            .args(["exec", &self.cluster_name, "--", &self.command])
91            .output();
92
93        match output {
94            Ok(o) if o.status.success() => {
95                let stdout = String::from_utf8_lossy(&o.stdout).to_string();
96                ctx.set_output(TaskId(format!("sky-exec:{}", self.cluster_name)), stdout);
97                TaskOutcome::Success
98            }
99            Ok(o) => {
100                let stderr = String::from_utf8_lossy(&o.stderr);
101                TaskOutcome::Failed(format!("sky exec failed: {}", stderr.trim()))
102            }
103            Err(e) => TaskOutcome::Failed(format!("sky exec failed: {}", e)),
104        }
105    }
106
107    fn describe(&self) -> String {
108        format!("exec on sky cluster '{}'", self.cluster_name)
109    }
110}
111
/// Tear down a SkyPilot cluster.
///
/// Executing this task runs `sky down <cluster_name> -y`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SkyDownTask {
    /// Name of the cluster to tear down.
    pub cluster_name: String,
}
116
117impl DagTask for SkyDownTask {
118    fn execute(&self, _ctx: &DagContext) -> TaskOutcome {
119        let output = Command::new("sky")
120            .args(["down", &self.cluster_name, "-y"])
121            .output();
122
123        match output {
124            Ok(o) if o.status.success() => TaskOutcome::Success,
125            Ok(o) => {
126                let stderr = String::from_utf8_lossy(&o.stderr);
127                TaskOutcome::Failed(format!("sky down failed: {}", stderr.trim()))
128            }
129            Err(e) => TaskOutcome::Failed(format!("sky down failed: {}", e)),
130        }
131    }
132
133    fn describe(&self) -> String {
134        format!("tear down sky cluster '{}'", self.cluster_name)
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the exact `describe()` text of every task type, so an accidental
    /// format change is caught (a substring check would let most changes through).
    #[test]
    fn test_describe_methods() {
        let build = NixBuildSkyEnvTask::new("train", ".");
        assert_eq!(build.describe(), "build skypilot environment 'train'");
        assert_eq!(build.flake_attr, ".#skyEnvs.train");

        let launch = SkyLaunchTask {
            cluster_name: "my-cluster".to_string(),
            task_yaml: "task.yaml".to_string(),
            cloud: Some("gcp".to_string()),
            region: Some("us-central1".to_string()),
        };
        assert_eq!(launch.describe(), "launch sky cluster 'my-cluster'");

        let exec = SkyExecTask {
            cluster_name: "my-cluster".to_string(),
            command: "python train.py".to_string(),
        };
        assert_eq!(exec.describe(), "exec on sky cluster 'my-cluster'");

        let down = SkyDownTask {
            cluster_name: "my-cluster".to_string(),
        };
        assert_eq!(down.describe(), "tear down sky cluster 'my-cluster'");
    }

    /// The flake attribute is `<flake_uri>#skyEnvs.<env_name>`, with the URI used verbatim.
    #[test]
    fn test_flake_attr_generation() {
        let build = NixBuildSkyEnvTask::new("train-gpt", "github:user/repo");
        assert_eq!(build.flake_attr, "github:user/repo#skyEnvs.train-gpt");
        assert_eq!(build.env_name, "train-gpt");
    }
}
173}