consortium_skypilot/
tasks.rs1use std::process::Command;
4
5use consortium::dag::{DagContext, DagTask, TaskId, TaskOutcome};
6use consortium_nix::build;
7
/// DAG task that builds a SkyPilot environment from a Nix flake attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NixBuildSkyEnvTask {
    /// Environment name; used in the task description and in the output key
    /// `build-sky-env:<env_name>`.
    pub env_name: String,
    /// Flake attribute handed to the Nix build. `new` fills it in as
    /// `<flake_uri>#skyEnvs.<env_name>`.
    pub flake_attr: String,
}
13
14impl NixBuildSkyEnvTask {
15 pub fn new(env_name: &str, flake_uri: &str) -> Self {
16 Self {
17 env_name: env_name.to_string(),
18 flake_attr: format!("{}#skyEnvs.{}", flake_uri, env_name),
19 }
20 }
21}
22
23impl DagTask for NixBuildSkyEnvTask {
24 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
25 match build::build_flake_attr(&self.flake_attr, None) {
26 Ok(path) => {
27 ctx.set_output(TaskId(format!("build-sky-env:{}", self.env_name)), path);
28 TaskOutcome::Success
29 }
30 Err(e) => TaskOutcome::Failed(format!("build sky env: {}", e)),
31 }
32 }
33
34 fn describe(&self) -> String {
35 format!("build skypilot environment '{}'", self.env_name)
36 }
37}
38
/// DAG task that runs `sky launch` for a named cluster.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SkyLaunchTask {
    /// Cluster name, passed to `sky launch` via `-c`; also the value recorded
    /// as this task's output.
    pub cluster_name: String,
    /// Task YAML argument passed positionally to `sky launch`.
    pub task_yaml: String,
    /// When set, forwarded as `--cloud <value>`.
    pub cloud: Option<String>,
    /// When set, forwarded as `--region <value>`.
    pub region: Option<String>,
}
46
47impl DagTask for SkyLaunchTask {
48 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
49 let mut cmd = Command::new("sky");
50 cmd.args(["launch", "-c", &self.cluster_name, &self.task_yaml, "-y"]);
51
52 if let Some(ref cloud) = self.cloud {
53 cmd.args(["--cloud", cloud]);
54 }
55 if let Some(ref region) = self.region {
56 cmd.args(["--region", region]);
57 }
58
59 let output = match cmd.output() {
60 Ok(o) => o,
61 Err(e) => return TaskOutcome::Failed(format!("sky launch failed: {}", e)),
62 };
63
64 if output.status.success() {
65 ctx.set_output(
66 TaskId(format!("sky-launch:{}", self.cluster_name)),
67 self.cluster_name.clone(),
68 );
69 TaskOutcome::Success
70 } else {
71 let stderr = String::from_utf8_lossy(&output.stderr);
72 TaskOutcome::Failed(format!("sky launch failed: {}", stderr.trim()))
73 }
74 }
75
76 fn describe(&self) -> String {
77 format!("launch sky cluster '{}'", self.cluster_name)
78 }
79}
80
/// DAG task that runs `sky exec` against a named cluster.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SkyExecTask {
    /// Cluster name passed to `sky exec`; also part of the output key
    /// `sky-exec:<cluster_name>`.
    pub cluster_name: String,
    /// Command string passed as a single argument after `--`.
    pub command: String,
}
86
87impl DagTask for SkyExecTask {
88 fn execute(&self, ctx: &DagContext) -> TaskOutcome {
89 let output = Command::new("sky")
90 .args(["exec", &self.cluster_name, "--", &self.command])
91 .output();
92
93 match output {
94 Ok(o) if o.status.success() => {
95 let stdout = String::from_utf8_lossy(&o.stdout).to_string();
96 ctx.set_output(TaskId(format!("sky-exec:{}", self.cluster_name)), stdout);
97 TaskOutcome::Success
98 }
99 Ok(o) => {
100 let stderr = String::from_utf8_lossy(&o.stderr);
101 TaskOutcome::Failed(format!("sky exec failed: {}", stderr.trim()))
102 }
103 Err(e) => TaskOutcome::Failed(format!("sky exec failed: {}", e)),
104 }
105 }
106
107 fn describe(&self) -> String {
108 format!("exec on sky cluster '{}'", self.cluster_name)
109 }
110}
111
/// DAG task that runs `sky down` for a named cluster.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SkyDownTask {
    /// Cluster name passed to `sky down`.
    pub cluster_name: String,
}
116
117impl DagTask for SkyDownTask {
118 fn execute(&self, _ctx: &DagContext) -> TaskOutcome {
119 let output = Command::new("sky")
120 .args(["down", &self.cluster_name, "-y"])
121 .output();
122
123 match output {
124 Ok(o) if o.status.success() => TaskOutcome::Success,
125 Ok(o) => {
126 let stderr = String::from_utf8_lossy(&o.stderr);
127 TaskOutcome::Failed(format!("sky down failed: {}", stderr.trim()))
128 }
129 Err(e) => TaskOutcome::Failed(format!("sky down failed: {}", e)),
130 }
131 }
132
133 fn describe(&self) -> String {
134 format!("tear down sky cluster '{}'", self.cluster_name)
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Each task's description mentions the name it was configured with.
    #[test]
    fn test_describe_methods() {
        let cluster = "my-cluster";

        let build_task = NixBuildSkyEnvTask::new("train", ".");
        let build_description = build_task.describe();
        assert!(build_description.contains("train"));
        assert!(build_task.flake_attr.contains("skyEnvs.train"));

        let launch_task = SkyLaunchTask {
            cluster_name: String::from(cluster),
            task_yaml: String::from("task.yaml"),
            cloud: Some(String::from("gcp")),
            region: Some(String::from("us-central1")),
        };
        assert!(launch_task.describe().contains(cluster));

        let exec_task = SkyExecTask {
            cluster_name: String::from(cluster),
            command: String::from("python train.py"),
        };
        assert!(exec_task.describe().contains(cluster));

        let down_task = SkyDownTask {
            cluster_name: String::from(cluster),
        };
        assert!(down_task.describe().contains(cluster));
    }

    /// The flake attribute is `<flake_uri>#skyEnvs.<env_name>`.
    #[test]
    fn test_flake_attr_generation() {
        let task = NixBuildSkyEnvTask::new("train-gpt", "github:user/repo");
        let expected = "github:user/repo#skyEnvs.train-gpt";
        assert_eq!(task.flake_attr, expected);
    }
}