1pub mod error;
10pub mod tasks;
11
12pub use error::{RayError, Result};
13
14use std::time::Duration;
15
16use consortium::dag::{DagBuilder, DagContext, DagReport, ErrorPolicy};
17use consortium_nix::FleetConfig;
18
19pub fn submit_job(
21 config: &FleetConfig,
22 job_name: &str,
23 entrypoint: &str,
24 wait: bool,
25) -> Result<DagReport> {
26 let ray_config = config.ray_config.as_ref().ok_or(RayError::NoConfig)?;
27
28 let ctx = DagContext::new();
29 ctx.set_state("fleet_config", config.clone());
30
31 let mut dag = DagBuilder::new();
32
33 let build_id = format!("build-ray-env:{}", job_name);
35 dag.add_task(
36 &build_id,
37 tasks::NixBuildRayEnvTask::new(job_name, &config.flake_uri),
38 );
39
40 let submit_id = format!("ray-submit:{}", job_name);
42 dag.add_task(
43 &submit_id,
44 tasks::RaySubmitTask {
45 job_name: job_name.to_string(),
46 entrypoint: entrypoint.to_string(),
47 head_address: ray_config.head_address.clone(),
48 dashboard_port: ray_config.dashboard_port,
49 working_dir: None,
50 },
51 );
52 dag.add_dep(&submit_id, &build_id);
53
54 if wait {
56 let wait_id = format!("ray-wait:{}", job_name);
57 dag.add_task(
58 &wait_id,
59 tasks::RayWaitTask {
60 job_name: job_name.to_string(),
61 head_address: ray_config.head_address.clone(),
62 dashboard_port: ray_config.dashboard_port,
63 poll_interval: Duration::from_secs(10),
64 timeout: None,
65 },
66 );
67 dag.add_dep(&wait_id, &submit_id);
68 }
69
70 dag.error_policy(ErrorPolicy::FailFast);
71 dag.context(ctx);
72
73 let report = dag
74 .build()
75 .map_err(|e| RayError::Dag(e.to_string()))?
76 .run()
77 .map_err(|e| RayError::Dag(e.to_string()))?;
78
79 Ok(report)
80}