1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8
/// Kind of system profile a node is deployed with.
///
/// Serialized in kebab-case: `Nixos` <-> `"nixos"`, `NixDarwin` <-> `"nix-darwin"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum ProfileType {
    /// A NixOS system profile (`"nixos"`).
    Nixos,
    /// A nix-darwin system profile (`"nix-darwin"`).
    NixDarwin,
}
16
/// One deployable machine as described in the fleet JSON.
///
/// JSON keys are camelCase (`targetHost`, `targetUser`, `targetPort`, `profileType`,
/// `buildOnTarget`, `drvPath`, `toplevel`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DeploymentNode {
    /// Node name; `FleetConfig::nodes_by_names` selects nodes by this field.
    pub name: String,
    /// Address of the target machine (`targetHost`); required.
    pub target_host: String,
    /// User on the target machine (`targetUser`); required.
    pub target_user: String,
    /// Optional port (`targetPort`); `null` or an omitted key yields `None`.
    pub target_port: Option<u16>,
    /// Nix system string, e.g. `x86_64-linux`.
    pub system: String,
    /// Profile kind to deploy (`profileType`).
    pub profile_type: ProfileType,
    /// Whether the node builds on the target (`buildOnTarget`); required.
    pub build_on_target: bool,
    /// Labels matched by `FleetConfig::nodes_by_tags`; required (may be `[]`).
    pub tags: Vec<String>,
    /// Optional derivation path (`drvPath`); `None` when absent.
    #[serde(default)]
    pub drv_path: Option<String>,
    /// Optional toplevel path (`toplevel`); `None` when absent.
    #[serde(default)]
    pub toplevel: Option<String>,
}
44
/// A remote Nix builder; each one becomes a line of `FleetConfig::machines_file`.
///
/// JSON keys are camelCase (`maxJobs`, `speedFactor`, `sshKey`, ...).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Builder {
    /// Builder host name or address, used after `user@` in the machines-file URI.
    pub host: String,
    /// Login user, used before `@` in the machines-file URI.
    pub user: String,
    /// Job limit (`maxJobs`), written verbatim to the machines file.
    pub max_jobs: u32,
    /// Speed factor (`speedFactor`), written verbatim to the machines file.
    pub speed_factor: u32,
    /// Supported systems, e.g. `x86_64-linux`; joined with `,` in the machines file.
    pub systems: Vec<String>,
    /// Supported features, e.g. `kvm`; joined with `,` in the machines file.
    pub features: Vec<String>,
    /// Value for the machines-file key column (`sshKey`); `None` is rendered as `-`.
    pub ssh_key: Option<String>,
    /// URI scheme for the machines-file entry; defaults to `"ssh-ng"` when absent.
    #[serde(default = "default_protocol")]
    pub protocol: String,
}
67
/// Serde fallback for `Builder::protocol` when the `protocol` key is absent.
fn default_protocol() -> String {
    String::from("ssh-ng")
}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct FleetConfig {
75 pub nodes: HashMap<String, DeploymentNode>,
77 #[serde(default)]
79 pub builders: HashMap<String, Builder>,
80 #[serde(default = "default_flake_uri")]
82 pub flake_uri: String,
83 #[serde(default, rename = "ansibleConfig")]
85 pub ansible_config: Option<AnsibleFleetConfig>,
86 #[serde(default, rename = "slurmConfig")]
88 pub slurm_config: Option<SlurmFleetConfig>,
89 #[serde(default, rename = "rayConfig")]
91 pub ray_config: Option<RayFleetConfig>,
92 #[serde(default, rename = "skypilotConfig")]
94 pub skypilot_config: Option<SkypilotFleetConfig>,
95}
96
/// Ansible integration settings, read from the fleet's `ansibleConfig` object.
///
/// JSON keys are camelCase; only `controlNode` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AnsibleFleetConfig {
    /// Control node identifier (`controlNode`); presumably a key of `FleetConfig::nodes` — confirm.
    pub control_node: String,
    /// Optional Ansible version string (`ansibleVersion`); `None` when absent.
    #[serde(default)]
    pub ansible_version: Option<String>,
    /// Collection names; empty when absent.
    #[serde(default)]
    pub collections: Vec<String>,
    /// Optional playbook directory (`playbookDir`); `None` when absent.
    #[serde(default)]
    pub playbook_dir: Option<String>,
    /// Named host groups mapping to member lists (`hostGroups`); empty when absent.
    #[serde(default)]
    pub host_groups: HashMap<String, Vec<String>>,
}
116
/// Slurm integration settings, read from the fleet's `slurmConfig` object.
///
/// JSON keys are camelCase (`submitNode`, `submitUser`, `controlNode`, `partitions`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SlurmFleetConfig {
    /// Submit node identifier (`submitNode`); required.
    pub submit_node: String,
    /// User on the submit node (`submitUser`); required.
    pub submit_user: String,
    /// Control node identifier (`controlNode`); required.
    pub control_node: String,
    /// Partitions keyed by partition name; empty when absent.
    #[serde(default)]
    pub partitions: HashMap<String, SlurmPartition>,
}
131
/// A single Slurm partition entry inside `SlurmFleetConfig::partitions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SlurmPartition {
    /// Member nodes of the partition; required.
    pub nodes: Vec<String>,
    /// Whether this partition is flagged as the default; `false` when absent.
    #[serde(default)]
    pub default: bool,
    /// Optional time limit string (`maxTime`), kept unparsed; `None` when absent.
    #[serde(default)]
    pub max_time: Option<String>,
}
145
/// Ray integration settings, read from the fleet's `rayConfig` object.
///
/// JSON keys are camelCase (`headAddress`, `dashboardPort`, `kubernetes`, `workerGroups`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RayFleetConfig {
    /// Head node address (`headAddress`); required.
    pub head_address: String,
    /// Dashboard port (`dashboardPort`); defaults to 8265 via `default_ray_port`.
    #[serde(default = "default_ray_port")]
    pub dashboard_port: u16,
    /// Kubernetes flag; `false` when absent.
    #[serde(default)]
    pub kubernetes: bool,
    /// Worker groups keyed by group name (`workerGroups`); empty when absent.
    #[serde(default)]
    pub worker_groups: HashMap<String, RayWorkerGroup>,
}
162
/// Serde fallback for `RayFleetConfig::dashboard_port` when `dashboardPort` is absent.
///
/// 8265 matches Ray's documented default dashboard port.
fn default_ray_port() -> u16 {
    const RAY_DASHBOARD_PORT: u16 = 8_265;
    RAY_DASHBOARD_PORT
}
166
/// A Ray worker group entry inside `RayFleetConfig::worker_groups`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RayWorkerGroup {
    /// Member nodes of the group; required.
    pub nodes: Vec<String>,
    /// Optional CPU count; `None` when absent.
    #[serde(default)]
    pub cpus: Option<u32>,
    /// Optional GPU count; `None` when absent.
    #[serde(default)]
    pub gpus: Option<u32>,
    /// Optional memory amount in MB, per the field name (`memoryMb`); `None` when absent.
    #[serde(default)]
    pub memory_mb: Option<u32>,
}
183
/// SkyPilot integration settings, read from the fleet's `skypilotConfig` object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SkypilotFleetConfig {
    /// Cloud provider name; required.
    pub cloud: String,
    /// Optional region; `None` when absent.
    #[serde(default)]
    pub region: Option<String>,
    /// Optional instance type (`instanceType`); `None` when absent.
    #[serde(default)]
    pub instance_type: Option<String>,
}
197
/// Serde fallback for `FleetConfig::flake_uri`: the current directory, `"."`.
fn default_flake_uri() -> String {
    ".".into()
}
201
202impl FleetConfig {
203 pub fn from_file(path: &Path) -> Result<Self, ConfigError> {
205 let content =
206 std::fs::read_to_string(path).map_err(|e| ConfigError::Io(path.to_path_buf(), e))?;
207 serde_json::from_str(&content).map_err(|e| ConfigError::Parse(path.to_path_buf(), e))
208 }
209
210 pub fn from_json(json: &str) -> Result<Self, ConfigError> {
212 serde_json::from_str(json).map_err(|e| ConfigError::Parse(PathBuf::from("<string>"), e))
213 }
214
215 pub fn nodes_by_tags(&self, tags: &[String]) -> Vec<&DeploymentNode> {
217 self.nodes
218 .values()
219 .filter(|n| n.tags.iter().any(|t| tags.contains(t)))
220 .collect()
221 }
222
223 pub fn nodes_by_names(&self, names: &[String]) -> Vec<&DeploymentNode> {
225 self.nodes
226 .values()
227 .filter(|n| names.contains(&n.name))
228 .collect()
229 }
230
231 pub fn node_names(&self) -> Vec<String> {
233 let mut names: Vec<_> = self.nodes.keys().cloned().collect();
234 names.sort();
235 names
236 }
237
238 pub fn builder_names(&self) -> Vec<String> {
240 let mut names: Vec<_> = self.builders.keys().cloned().collect();
241 names.sort();
242 names
243 }
244
245 pub fn machines_file(&self) -> String {
247 self.builders
248 .values()
249 .map(|b| {
250 let key = b.ssh_key.as_deref().unwrap_or("-");
251 let features = b.features.join(",");
252 let systems = b.systems.join(",");
253 format!(
254 "{}://{}@{} {} {} {} {} {}",
255 b.protocol, b.user, b.host, systems, key, b.max_jobs, b.speed_factor, features
256 )
257 })
258 .collect::<Vec<_>>()
259 .join("\n")
260 }
261}
262
/// Deployment action, serialized in kebab-case (`"switch"`, `"boot"`, `"test"`,
/// `"dry-activate"`, `"build"`).
///
/// The variant names match `nixos-rebuild`'s actions of the same name. How each one
/// is executed is decided outside this module — confirm against the deploy code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum DeployAction {
    /// `"switch"`
    Switch,
    /// `"boot"`
    Boot,
    /// `"test"`
    Test,
    /// `"dry-activate"`
    DryActivate,
    /// `"build"`
    Build,
}
278
279impl std::fmt::Display for DeployAction {
280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 match self {
282 DeployAction::Switch => write!(f, "switch"),
283 DeployAction::Boot => write!(f, "boot"),
284 DeployAction::Test => write!(f, "test"),
285 DeployAction::DryActivate => write!(f, "dry-activate"),
286 DeployAction::Build => write!(f, "build"),
287 }
288 }
289}
290
291impl std::str::FromStr for DeployAction {
292 type Err = ConfigError;
293
294 fn from_str(s: &str) -> Result<Self, Self::Err> {
295 match s {
296 "switch" => Ok(DeployAction::Switch),
297 "boot" => Ok(DeployAction::Boot),
298 "test" => Ok(DeployAction::Test),
299 "dry-activate" => Ok(DeployAction::DryActivate),
300 "build" => Ok(DeployAction::Build),
301 _ => Err(ConfigError::InvalidAction(s.to_string())),
302 }
303 }
304}
305
/// A node selected for deployment, together with its per-node planning flags.
#[derive(Debug, Clone)]
pub struct DeploymentTarget {
    /// The node this target refers to.
    pub node: DeploymentNode,
    /// Toplevel store path for this target (a `/nix/store/...` path in the tests).
    pub toplevel_path: String,
    /// Presumably the node's currently active system path; `None` when not determined — confirm with the planner.
    pub current_system: Option<String>,
    /// Counted by `DeploymentPlan::build_count`.
    pub needs_build: bool,
    /// Counted by `DeploymentPlan::copy_count`.
    pub needs_copy: bool,
}
320
/// A set of deployment targets sharing one action.
#[derive(Debug, Clone)]
pub struct DeploymentPlan {
    /// Targets in the plan; starts empty in `DeploymentPlan::new`.
    pub targets: Vec<DeploymentTarget>,
    /// Action applied to every target.
    pub action: DeployAction,
    /// Presumably the maximum number of targets processed concurrently — confirm with the executor.
    pub max_parallel: usize,
}
331
332impl DeploymentPlan {
333 pub fn new(action: DeployAction, max_parallel: usize) -> Self {
335 Self {
336 targets: Vec::new(),
337 action,
338 max_parallel,
339 }
340 }
341
342 pub fn build_count(&self) -> usize {
344 self.targets.iter().filter(|t| t.needs_build).count()
345 }
346
347 pub fn copy_count(&self) -> usize {
349 self.targets.iter().filter(|t| t.needs_copy).count()
350 }
351
352 pub fn target_count(&self) -> usize {
354 self.targets.len()
355 }
356}
357
/// Errors raised while loading a fleet configuration or parsing a `DeployAction`.
// NOTE(review): the wrapped io/serde errors are not marked `#[source]`, so
// `Error::source()` returns `None` for every variant. Their text is only available
// through the Display message below.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// Reading the config file failed; holds the path and the I/O error.
    #[error("failed to read config file {0}: {1}")]
    Io(PathBuf, std::io::Error),
    /// JSON deserialization failed; holds the path (`<string>` for in-memory input) and the serde error.
    #[error("failed to parse config file {0}: {1}")]
    Parse(PathBuf, serde_json::Error),
    /// An unrecognized action name was given to `DeployAction::from_str`.
    #[error("invalid deploy action: {0}")]
    InvalidAction(String),
}
368
#[cfg(test)]
mod tests {
    // Unit tests for config parsing, node selection, machines-file rendering,
    // deploy-action conversions, and plan counters.
    use super::*;

    // Two nodes (`hp01` tagged `build-host`/`hpe`, `mm01` tagged `ray`) and a single
    // `hp01` builder, shared by the parsing and selection tests below.
    fn sample_config_json() -> &'static str {
        r#"{
            "nodes": {
                "hp01": {
                    "name": "hp01",
                    "targetHost": "192.168.1.121",
                    "targetUser": "root",
                    "targetPort": null,
                    "system": "x86_64-linux",
                    "profileType": "nixos",
                    "buildOnTarget": false,
                    "tags": ["build-host", "hpe"]
                },
                "mm01": {
                    "name": "mm01",
                    "targetHost": "192.168.1.111",
                    "targetUser": "root",
                    "targetPort": null,
                    "system": "x86_64-linux",
                    "profileType": "nixos",
                    "buildOnTarget": false,
                    "tags": ["ray"]
                }
            },
            "builders": {
                "hp01": {
                    "host": "192.168.1.121",
                    "user": "root",
                    "maxJobs": 16,
                    "speedFactor": 2,
                    "systems": ["x86_64-linux"],
                    "features": ["big-parallel", "kvm"],
                    "sshKey": null,
                    "protocol": "ssh-ng"
                }
            },
            "flakeUri": "."
        }"#
    }

    #[test]
    fn test_parse_fleet_config() {
        let config = FleetConfig::from_json(sample_config_json()).unwrap();
        assert_eq!(config.nodes.len(), 2);
        assert_eq!(config.builders.len(), 1);
        // The fixture's `flakeUri` is ".", which is also the serde default, so this
        // assertion cannot tell a parsed value from a defaulted one.
        assert_eq!(config.flake_uri, ".");
    }

    #[test]
    fn test_node_fields() {
        let config = FleetConfig::from_json(sample_config_json()).unwrap();
        let hp01 = &config.nodes["hp01"];
        assert_eq!(hp01.target_host, "192.168.1.121");
        assert_eq!(hp01.target_user, "root");
        assert_eq!(hp01.profile_type, ProfileType::Nixos);
        assert!(!hp01.build_on_target);
        assert_eq!(hp01.tags, vec!["build-host", "hpe"]);
    }

    #[test]
    fn test_nodes_by_tags() {
        let config = FleetConfig::from_json(sample_config_json()).unwrap();
        // Only `hp01` carries the `build-host` tag.
        let build_hosts = config.nodes_by_tags(&["build-host".to_string()]);
        assert_eq!(build_hosts.len(), 1);
        assert_eq!(build_hosts[0].name, "hp01");
    }

    #[test]
    fn test_node_names_sorted() {
        let config = FleetConfig::from_json(sample_config_json()).unwrap();
        let names = config.node_names();
        assert_eq!(names, vec!["hp01", "mm01"]);
    }

    #[test]
    fn test_machines_file() {
        let config = FleetConfig::from_json(sample_config_json()).unwrap();
        let machines = config.machines_file();
        // Substring checks only; the exact column layout is not asserted here.
        assert!(machines.contains("ssh-ng://root@192.168.1.121"));
        assert!(machines.contains("x86_64-linux"));
        assert!(machines.contains("16"));
        assert!(machines.contains("big-parallel,kvm"));
    }

    #[test]
    fn test_deploy_action_display() {
        assert_eq!(DeployAction::Switch.to_string(), "switch");
        assert_eq!(DeployAction::DryActivate.to_string(), "dry-activate");
    }

    #[test]
    fn test_deploy_action_parse() {
        assert_eq!(
            "switch".parse::<DeployAction>().unwrap(),
            DeployAction::Switch
        );
        assert_eq!(
            "dry-activate".parse::<DeployAction>().unwrap(),
            DeployAction::DryActivate
        );
        assert!("invalid".parse::<DeployAction>().is_err());
    }

    #[test]
    fn test_deployment_plan() {
        // A single target with both flags set should be counted once by each counter.
        let mut plan = DeploymentPlan::new(DeployAction::Switch, 4);
        plan.targets.push(DeploymentTarget {
            node: DeploymentNode {
                name: "hp01".to_string(),
                target_host: "192.168.1.121".to_string(),
                target_user: "root".to_string(),
                target_port: None,
                system: "x86_64-linux".to_string(),
                profile_type: ProfileType::Nixos,
                build_on_target: false,
                tags: vec![],
                drv_path: None,
                toplevel: None,
            },
            toplevel_path: "/nix/store/abc-nixos-system".to_string(),
            current_system: Some("/nix/store/old-nixos-system".to_string()),
            needs_build: true,
            needs_copy: true,
        });
        assert_eq!(plan.build_count(), 1);
        assert_eq!(plan.copy_count(), 1);
        assert_eq!(plan.target_count(), 1);
    }
}