[skip ci]

change config loading mechanism to allow loading multiple configs
make it so predefined config can be loaded independently
make it so default network/model are options and must be resolved by the time they are used
This commit is contained in:
Christien Rioux 2025-01-06 18:29:03 -05:00
parent d93a5602d3
commit 9e7c41b635
7 changed files with 262 additions and 327 deletions

View File

@ -6,6 +6,7 @@ use parking_lot::*;
use std::path::PathBuf; use std::path::PathBuf;
use stop_token::StopSource; use stop_token::StopSource;
use veilid_tools::*; use veilid_tools::*;
use virtual_network::*;
const VERSION: &str = env!("CARGO_PKG_VERSION"); const VERSION: &str = env!("CARGO_PKG_VERSION");
@ -40,9 +41,12 @@ struct CmdlineArgs {
/// Turn off WS listener /// Turn off WS listener
#[arg(long)] #[arg(long)]
no_ws: bool, no_ws: bool,
/// Specify a configuration file to use /// Specify an initial list of configuration files to use
#[arg(short = 'c', long, value_name = "FILE")] #[arg(short = 'c', long, value_name = "FILE")]
config_file: Option<PathBuf>, config_file: Vec<PathBuf>,
/// Specify to load configuration without a predefined config first
#[arg(long)]
no_predefined_config: bool,
/// Instead of running the virtual router, print the configuration it would use to the console /// Instead of running the virtual router, print the configuration it would use to the console
#[arg(long)] #[arg(long)]
dump_config: bool, dump_config: bool,
@ -64,17 +68,18 @@ fn main() -> Result<(), String> {
let args = CmdlineArgs::parse(); let args = CmdlineArgs::parse();
let config = config::Config::new(args.config_file) let initial_config = config::Config::new(&args.config_file, args.no_predefined_config)
.map_err(|e| format!("Error loading config: {}", e))?; .map_err(|e| format!("Error loading config: {}", e))?;
if args.dump_config { if args.dump_config {
let cfg_yaml = serde_yaml::to_string(&config) let cfg_yaml = serde_yaml::to_string(&initial_config)
.map_err(|e| format!("Error serializing config: {}", e))?; .map_err(|e| format!("Error serializing config: {}", e))?;
println!("{}", cfg_yaml); println!("{}", cfg_yaml);
return Ok(()); return Ok(());
} }
let router_server = virtual_network::RouterServer::new(config); let router_server = virtual_network::RouterServer::new(initial_config);
let _ss_tcp = if !args.no_tcp { let _ss_tcp = if !args.no_tcp {
Some( Some(
router_server router_server

View File

@ -253,8 +253,6 @@ pub use timeout_or::*;
pub use timestamp::*; pub use timestamp::*;
#[doc(inline)] #[doc(inline)]
pub use tools::*; pub use tools::*;
#[cfg(feature = "virtual-network")]
pub use virtual_network::*;
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))] #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub use wasm::*; pub use wasm::*;

View File

@ -2,7 +2,7 @@ use super::*;
use serde::*; use serde::*;
use std::path::Path; use std::path::Path;
use validator::{Validate, ValidateArgs, ValidationError, ValidationErrors}; use validator::{Validate, ValidationError, ValidationErrors};
const PREDEFINED_CONFIG: &str = include_str!("predefined_config.yml"); const PREDEFINED_CONFIG: &str = include_str!("predefined_config.yml");
const DEFAULT_CONFIG: &str = include_str!("default_config.yml"); const DEFAULT_CONFIG: &str = include_str!("default_config.yml");
@ -13,15 +13,13 @@ pub enum ConfigError {
ParseError(::config::ConfigError), ParseError(::config::ConfigError),
#[error("validate error")] #[error("validate error")]
ValidateError(validator::ValidationErrors), ValidateError(validator::ValidationErrors),
#[error("no configuration files specified")]
NoConfigFiles,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(context = "ValidateContext<'v_a>")]
pub struct Profile { pub struct Profile {
#[validate( #[validate(length(min = 1), nested)]
length(min = 1),
custom(function = "validate_instances_exist", use_context)
)]
pub instances: Vec<Instance>, pub instances: Vec<Instance>,
} }
@ -32,14 +30,20 @@ pub enum Instance {
Template { template: WeightedList<String> }, Template { template: WeightedList<String> },
} }
impl Validate for Instance {
fn validate(&self) -> Result<(), ValidationErrors> {
match self {
Instance::Machine { machine } => machine.validate()?,
Instance::Template { template } => template.validate()?,
}
Ok(())
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(
context = "ValidateContext<'v_a>",
schema(function = "validate_machine", use_context)
)]
pub struct Machine { pub struct Machine {
#[serde(flatten)] #[serde(flatten)]
#[validate(custom(function = "validate_machine_location", use_context))] #[validate(nested)]
pub location: MachineLocation, pub location: MachineLocation,
#[serde(default)] #[serde(default)]
pub disable_capabilities: Vec<String>, pub disable_capabilities: Vec<String>,
@ -47,14 +51,6 @@ pub struct Machine {
pub bootstrap: bool, pub bootstrap: bool,
} }
fn validate_machine(machine: &Machine, _context: &ValidateContext) -> Result<(), ValidationError> {
if machine.disable_capabilities.contains(&("".to_string())) {
return Err(ValidationError::new("badcap")
.with_message("machine has empty disabled capability".into()));
}
Ok(())
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum MachineLocation { pub enum MachineLocation {
@ -70,53 +66,50 @@ pub enum MachineLocation {
}, },
} }
fn validate_machine_location( impl Validate for MachineLocation {
value: &MachineLocation, fn validate(&self) -> Result<(), ValidationErrors> {
context: &ValidateContext, let mut errors = ValidationErrors::new();
) -> Result<(), ValidationError> { match self {
match value { MachineLocation::Network {
MachineLocation::Network { network: _,
network, address4,
address4, address6,
address6, } => {
} => { if address4.is_none() && address6.is_none() {
if address4.is_none() && address6.is_none() { errors.add(
return Err(ValidationError::new("badaddr") "MachineLocation",
.with_message("machine must have at least one address".into())); ValidationError::new("badaddr")
.with_message("machine must have at least one address".into()),
);
}
} }
validate_network_exists(network, context)?; MachineLocation::Blueprint { blueprint: _ } => {}
} }
MachineLocation::Blueprint { blueprint } => {
validate_blueprint_exists(blueprint, context)?; if !errors.is_empty() {
Err(errors)
} else {
Ok(())
} }
} }
Ok(())
} }
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(
context = "ValidateContext<'v_a>",
schema(function = "validate_template", use_context)
)]
pub struct Template { pub struct Template {
#[serde(flatten)] #[serde(flatten)]
#[validate(custom(function = "validate_template_location", use_context))] #[validate(nested)]
pub location: TemplateLocation, pub location: TemplateLocation,
#[serde(flatten)] #[serde(flatten)]
#[validate(nested)] #[validate(nested)]
pub limits: TemplateLimits, pub limits: TemplateLimits,
#[serde(default)] #[serde(default)]
#[validate(custom(function = "validate_disable_capabilities"))]
pub disable_capabilities: Vec<String>, pub disable_capabilities: Vec<String>,
} }
fn validate_template( fn validate_disable_capabilities(disable_capabilities: &[String]) -> Result<(), ValidationError> {
template: &Template, if disable_capabilities.contains(&("".to_string())) {
_context: &ValidateContext, return Err(ValidationError::new("badcap").with_message("empty disabled capability".into()));
) -> Result<(), ValidationError> {
if template.disable_capabilities.contains(&("".to_string())) {
return Err(ValidationError::new("badcap")
.with_message("template has empty disabled capability".into()));
} }
Ok(()) Ok(())
} }
@ -129,10 +122,12 @@ pub struct TemplateLimits {
#[serde(default)] #[serde(default)]
pub machine_count: Option<WeightedList<u32>>, pub machine_count: Option<WeightedList<u32>>,
#[validate(nested)] #[validate(nested)]
#[serde(default)]
pub machines_per_network: Option<WeightedList<u32>>, pub machines_per_network: Option<WeightedList<u32>>,
} }
fn validate_template_limits(limits: &TemplateLimits) -> Result<(), ValidationError> { fn validate_template_limits(limits: &TemplateLimits) -> Result<(), ValidationError> {
let mut has_at_least_one_limit = false;
if let Some(machine_count) = &limits.machine_count { if let Some(machine_count) = &limits.machine_count {
machine_count.try_for_each(|x| { machine_count.try_for_each(|x| {
if *x == 0 { if *x == 0 {
@ -141,6 +136,7 @@ fn validate_template_limits(limits: &TemplateLimits) -> Result<(), ValidationErr
} }
Ok(()) Ok(())
})?; })?;
has_at_least_one_limit = true;
} }
if let Some(machines_per_network) = &limits.machines_per_network { if let Some(machines_per_network) = &limits.machines_per_network {
machines_per_network.try_for_each(|x| { machines_per_network.try_for_each(|x| {
@ -150,6 +146,12 @@ fn validate_template_limits(limits: &TemplateLimits) -> Result<(), ValidationErr
} }
Ok(()) Ok(())
})?; })?;
has_at_least_one_limit = true;
}
if !has_at_least_one_limit {
return Err(ValidationError::new("nolimit")
.with_message("template can not be unlimited per network".into()));
} }
Ok(()) Ok(())
@ -162,42 +164,32 @@ pub enum TemplateLocation {
Blueprint { blueprint: WeightedList<String> }, Blueprint { blueprint: WeightedList<String> },
} }
fn validate_template_location( impl Validate for TemplateLocation {
value: &TemplateLocation, fn validate(&self) -> Result<(), ValidationErrors> {
context: &ValidateContext, match self {
) -> Result<(), ValidationError> { TemplateLocation::Network { network } => network.validate()?,
match value { TemplateLocation::Blueprint { blueprint } => blueprint.validate()?,
TemplateLocation::Network { network } => {
network.try_for_each(|m| validate_network_exists(m, context))?;
}
TemplateLocation::Blueprint { blueprint } => {
blueprint.try_for_each(|t| validate_blueprint_exists(t, context))?;
} }
Ok(())
} }
Ok(())
} }
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate( #[validate(schema(function = "validate_network"))]
context = "ValidateContext<'v_a>",
schema(function = "validate_network", use_context)
)]
pub struct Network { pub struct Network {
#[serde(default)] #[serde(default)]
#[validate(custom(function = "validate_model_exists", use_context))]
pub model: Option<String>, pub model: Option<String>,
#[serde(default)] #[serde(default)]
#[validate(custom(function = "validate_network_ipv4", use_context))] #[validate(nested)]
pub ipv4: Option<NetworkIpv4>, pub ipv4: Option<NetworkIpv4>,
#[serde(default)] #[serde(default)]
#[validate(custom(function = "validate_network_ipv6", use_context))] #[validate(nested)]
pub ipv6: Option<NetworkIpv6>, pub ipv6: Option<NetworkIpv6>,
} }
fn validate_network(network: &Network, _context: &ValidateContext) -> Result<(), ValidationError> { fn validate_network(network: &Network) -> Result<(), ValidationError> {
if network.ipv4.is_none() && network.ipv6.is_none() { if network.ipv4.is_none() && network.ipv6.is_none() {
return Err(ValidationError::new("badaddr") return Err(ValidationError::new("badaddr")
.with_message("network must support at least one address type".into())); .with_message("network must support at least one address type".into()));
@ -205,83 +197,51 @@ fn validate_network(network: &Network, _context: &ValidateContext) -> Result<(),
Ok(()) Ok(())
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct NetworkIpv4 { pub struct NetworkIpv4 {
#[validate(length(min = 1))]
pub allocation: String, pub allocation: String,
#[serde(default)] #[serde(default)]
#[validate(nested)]
pub gateway: Option<NetworkGateway>, pub gateway: Option<NetworkGateway>,
} }
fn validate_network_ipv4( #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
network_ipv4: &NetworkIpv4,
context: &ValidateContext,
) -> Result<(), ValidationError> {
validate_allocation_exists(&network_ipv4.allocation, context)?;
if let Some(gateway) = &network_ipv4.gateway {
validate_network_gateway(gateway, context)?;
}
Ok(())
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkIpv6 { pub struct NetworkIpv6 {
#[validate(length(min = 1))]
pub allocation: String, pub allocation: String,
#[serde(default)] #[serde(default)]
#[validate(nested)]
pub gateway: Option<NetworkGateway>, pub gateway: Option<NetworkGateway>,
} }
fn validate_network_ipv6(
network_ipv6: &NetworkIpv6,
context: &ValidateContext,
) -> Result<(), ValidationError> {
validate_allocation_exists(&network_ipv6.allocation, context)?;
if let Some(gateway) = &network_ipv6.gateway {
validate_network_gateway(gateway, context)?;
}
Ok(())
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct NetworkGateway { pub struct NetworkGateway {
pub translation: Translation, pub translation: Translation,
pub upnp: bool, pub upnp: bool,
#[validate(length(min = 1))]
pub network: Option<String>, pub network: Option<String>,
} }
fn validate_network_gateway(
gateway: &NetworkGateway,
context: &ValidateContext,
) -> Result<(), ValidationError> {
if let Some(network) = &gateway.network {
validate_network_exists(network, context)?;
}
Ok(())
}
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate( #[validate(schema(function = "validate_blueprint"))]
context = "ValidateContext<'v_a>",
schema(function = "validate_blueprint", use_context)
)]
pub struct Blueprint { pub struct Blueprint {
#[serde(default)] #[serde(default)]
#[validate(custom(function = "validate_models_exist", use_context))] #[validate(nested)]
pub model: Option<WeightedList<String>>, pub model: Option<WeightedList<String>>,
#[validate(nested)] #[validate(nested)]
pub limits: BlueprintLimits, pub limits: BlueprintLimits,
#[serde(default)] #[serde(default)]
#[validate(custom(function = "validate_blueprint_ipv4", use_context))] #[validate(nested)]
pub ipv4: Option<BlueprintIpv4>, pub ipv4: Option<BlueprintIpv4>,
#[serde(default)] #[serde(default)]
#[validate(custom(function = "validate_blueprint_ipv6", use_context))] #[validate(nested)]
pub ipv6: Option<BlueprintIpv6>, pub ipv6: Option<BlueprintIpv6>,
} }
fn validate_blueprint( fn validate_blueprint(blueprint: &Blueprint) -> Result<(), ValidationError> {
blueprint: &Blueprint,
_context: &ValidateContext,
) -> Result<(), ValidationError> {
if blueprint.ipv4.is_none() && blueprint.ipv6.is_none() { if blueprint.ipv4.is_none() && blueprint.ipv6.is_none() {
return Err(ValidationError::new("badaddr") return Err(ValidationError::new("badaddr")
.with_message("blueprint must support at least one address type".into())); .with_message("blueprint must support at least one address type".into()));
@ -324,39 +284,35 @@ pub enum BlueprintLocation {
}, },
} }
fn validate_blueprint_location( impl Validate for BlueprintLocation {
value: &BlueprintLocation, fn validate(&self) -> Result<(), ValidationErrors> {
context: &ValidateContext, match self {
) -> Result<(), ValidationError> { BlueprintLocation::Allocation { allocation } => allocation.validate()?,
match value { BlueprintLocation::Network { network } => {
BlueprintLocation::Allocation { allocation } => { if let Some(network) = network {
allocation.try_for_each(|a| validate_allocation_exists(a, context))?; network.validate()?;
} }
BlueprintLocation::Network { network } => {
if let Some(network) = network {
network.try_for_each(|n| validate_network_exists(n, context))?;
} }
} }
}
Ok(()) Ok(())
}
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_blueprint_ipv4"))]
pub struct BlueprintIpv4 { pub struct BlueprintIpv4 {
#[serde(flatten)] #[serde(flatten)]
#[validate(nested)]
pub location: BlueprintLocation, pub location: BlueprintLocation,
#[validate(nested)]
pub prefix: WeightedList<u8>, pub prefix: WeightedList<u8>,
#[serde(default)] #[serde(default)]
#[validate(nested)]
pub gateway: Option<BlueprintGateway>, pub gateway: Option<BlueprintGateway>,
} }
fn validate_blueprint_ipv4( fn validate_blueprint_ipv4(blueprint_ipv4: &BlueprintIpv4) -> Result<(), ValidationError> {
blueprint_ipv4: &BlueprintIpv4,
context: &ValidateContext,
) -> Result<(), ValidationError> {
validate_blueprint_location(&blueprint_ipv4.location, context)?;
blueprint_ipv4.prefix.validate_once()?;
blueprint_ipv4.prefix.try_for_each(|x| { blueprint_ipv4.prefix.try_for_each(|x| {
if *x > 32 { if *x > 32 {
return Err(ValidationError::new("badprefix") return Err(ValidationError::new("badprefix")
@ -365,27 +321,23 @@ fn validate_blueprint_ipv4(
Ok(()) Ok(())
})?; })?;
if let Some(gateway) = &blueprint_ipv4.gateway {
validate_blueprint_gateway(gateway, context)?;
}
Ok(()) Ok(())
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_blueprint_ipv6"))]
pub struct BlueprintIpv6 { pub struct BlueprintIpv6 {
#[serde(flatten)] #[serde(flatten)]
#[validate(nested)]
pub allocation: BlueprintLocation, pub allocation: BlueprintLocation,
#[validate(nested)]
pub prefix: WeightedList<u8>, pub prefix: WeightedList<u8>,
#[serde(default)] #[serde(default)]
#[validate(nested)]
pub gateway: Option<BlueprintGateway>, pub gateway: Option<BlueprintGateway>,
} }
fn validate_blueprint_ipv6( fn validate_blueprint_ipv6(blueprint_ipv6: &BlueprintIpv6) -> Result<(), ValidationError> {
blueprint_ipv6: &BlueprintIpv6,
context: &ValidateContext,
) -> Result<(), ValidationError> {
validate_blueprint_location(&blueprint_ipv6.allocation, context)?;
blueprint_ipv6.prefix.validate_once()?;
blueprint_ipv6.prefix.try_for_each(|x| { blueprint_ipv6.prefix.try_for_each(|x| {
if *x > 128 { if *x > 128 {
return Err(ValidationError::new("badprefix") return Err(ValidationError::new("badprefix")
@ -393,37 +345,29 @@ fn validate_blueprint_ipv6(
} }
Ok(()) Ok(())
})?; })?;
if let Some(gateway) = &blueprint_ipv6.gateway {
validate_blueprint_gateway(gateway, context)?;
}
Ok(()) Ok(())
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct BlueprintGateway { pub struct BlueprintGateway {
#[validate(nested)]
pub translation: WeightedList<Translation>, pub translation: WeightedList<Translation>,
#[validate(range(min = 0.0, max = 1.0))]
pub upnp: Probability, pub upnp: Probability,
#[serde(flatten)] #[serde(flatten)]
pub location: TemplateLocation, pub location: TemplateLocation,
} }
fn validate_blueprint_gateway(
gateway: &BlueprintGateway,
context: &ValidateContext,
) -> Result<(), ValidationError> {
gateway.translation.validate_once()?;
validate_template_location(&gateway.location, context)?;
Ok(())
}
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_subnets"))] #[validate(schema(function = "validate_subnets"))]
pub struct Subnets { pub struct Subnets {
#[serde(default)] #[serde(default)]
#[validate(nested)]
pub subnet4: Option<WeightedList<Ipv4Net>>, pub subnet4: Option<WeightedList<Ipv4Net>>,
#[serde(default)] #[serde(default)]
#[validate(nested)]
pub subnet6: Option<WeightedList<Ipv6Net>>, pub subnet6: Option<WeightedList<Ipv6Net>>,
} }
@ -480,7 +424,7 @@ fn validate_distribution(distribution: &Distribution) -> Result<(), ValidationEr
Ok(()) Ok(())
} }
#[derive(Debug, Copy, Clone, Serialize, Deserialize)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum Translation { pub enum Translation {
None, None,
@ -496,7 +440,6 @@ impl Default for Translation {
} }
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(context = "ValidateContext<'v_a>")]
pub struct Model { pub struct Model {
#[validate(nested)] #[validate(nested)]
pub latency: Distribution, pub latency: Distribution,
@ -509,33 +452,21 @@ pub struct Model {
} }
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(context = "ValidateContext<'v_a>")]
pub struct Allocation { pub struct Allocation {
#[serde(flatten)] #[serde(flatten)]
#[validate(nested)] #[validate(nested)]
pub subnets: Subnets, pub subnets: Subnets,
} }
pub struct ValidateContext<'a> { #[derive(Debug, Clone, Serialize, Deserialize, Default)]
config: &'a Config,
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(context = "ValidateContext<'v_a>")]
pub struct Config { pub struct Config {
pub seed: Option<u64>,
#[validate(
length(min = 1),
custom(function = "validate_network_exists", use_context)
)]
pub default_network: String,
#[validate(
length(min = 1),
custom(function = "validate_model_exists", use_context)
)]
pub default_model: String,
#[serde(default)] #[serde(default)]
#[validate(length(min = 1))] pub seed: Option<u64>,
#[serde(default)]
pub default_network: Option<String>,
#[serde(default)]
pub default_model: Option<String>,
#[serde(default)]
pub profiles: HashMap<String, Profile>, pub profiles: HashMap<String, Profile>,
#[serde(default)] #[serde(default)]
pub machines: HashMap<String, Machine>, pub machines: HashMap<String, Machine>,
@ -551,140 +482,126 @@ pub struct Config {
pub models: HashMap<String, Model>, pub models: HashMap<String, Model>,
} }
impl Config { impl Validate for Config {
pub fn new<P: AsRef<Path>>(config_file: Option<P>) -> Result<Self, ConfigError> { fn validate(&self) -> Result<(), ValidationErrors> {
let cfg = load_config(config_file).map_err(ConfigError::ParseError)?;
// Generate config
let out: Self = cfg.try_deserialize().map_err(ConfigError::ParseError)?;
// Validate config // Validate config
let context = ValidateContext { config: &out };
let mut errors = ValidationErrors::new(); let mut errors = ValidationErrors::new();
if let Err(e) = out.validate_with_args(&context) { if let Some(default_network) = self.default_network.as_ref() {
errors = e; if default_network.is_empty() {
errors.add(
"default_network",
ValidationError::new("badlen").with_message(
"Config must have non-empty default network if specified".into(),
),
);
}
} }
errors.merge_self("profiles", validate_all_with_args(&out.profiles, &context)); if let Some(default_model) = self.default_model.as_ref() {
errors.merge_self("machines", validate_all_with_args(&out.machines, &context)); if default_model.is_empty() {
errors.merge_self( errors.add(
"templates", "default_model",
validate_all_with_args(&out.templates, &context), ValidationError::new("badlen").with_message(
); "Config must have non-empty default model if specified".into(),
errors.merge_self("networks", validate_all_with_args(&out.networks, &context)); ),
errors.merge_self( );
"blueprints", }
validate_all_with_args(&out.blueprints, &context), }
);
errors.merge_self( errors.merge_self("profiles", validate_hash_map(&self.profiles));
"allocation", errors.merge_self("machines", validate_hash_map(&self.machines));
validate_all_with_args(&out.allocations, &context), errors.merge_self("templates", validate_hash_map(&self.templates));
); errors.merge_self("networks", validate_hash_map(&self.networks));
errors.merge_self("models", validate_all_with_args(&out.models, &context)); errors.merge_self("blueprints", validate_hash_map(&self.blueprints));
errors.merge_self("allocation", validate_hash_map(&self.allocations));
errors.merge_self("models", validate_hash_map(&self.models));
if !errors.is_empty() { if !errors.is_empty() {
return Err(ConfigError::ValidateError(errors)); return Err(errors);
}
Ok(())
}
}
impl Config {
pub fn new<P: AsRef<Path>>(
config_files: &[P],
no_predefined_config: bool,
) -> Result<Self, ConfigError> {
let mut out = Self::default();
if !no_predefined_config {
out = load_predefined_config()
.map_err(ConfigError::ParseError)?
.try_deserialize()
.map_err(ConfigError::ParseError)?;
out.validate().map_err(ConfigError::ValidateError)?;
// Load default config file
if config_files.is_empty() {
let cfg: Self = load_default_config()
.map_err(ConfigError::ParseError)?
.try_deserialize()
.map_err(ConfigError::ParseError)?;
cfg.validate().map_err(ConfigError::ValidateError)?;
out = out.combine(cfg)?;
}
} else {
// There must be config files specified to use this option
if config_files.is_empty() {
return Err(ConfigError::NoConfigFiles);
}
}
// Load specified config files
for config_file in config_files {
let cfg: Self = load_config_file(config_file)
.map_err(ConfigError::ParseError)?
.try_deserialize()
.map_err(ConfigError::ParseError)?;
cfg.validate().map_err(ConfigError::ValidateError)?;
out = out.combine(cfg)?;
} }
Ok(out) Ok(out)
} }
}
fn validate_instances_exist( pub fn combine(self, other: Self) -> Result<Self, ConfigError> {
value: &Vec<Instance>, let out = Config {
context: &ValidateContext, seed: other.seed.or(self.seed),
) -> Result<(), ValidationError> { default_network: other.default_network.or(self.default_network),
for v in value { default_model: other.default_model.or(self.default_model),
match v { profiles: self.profiles.into_iter().chain(other.profiles).collect(),
Instance::Machine { machine } => validate_machines_exist(machine, context)?, machines: self.machines.into_iter().chain(other.machines).collect(),
Instance::Template { template } => validate_templates_exist(template, context)?, templates: self.templates.into_iter().chain(other.templates).collect(),
} networks: self.networks.into_iter().chain(other.networks).collect(),
blueprints: self
.blueprints
.into_iter()
.chain(other.blueprints)
.collect(),
allocations: self
.allocations
.into_iter()
.chain(other.allocations)
.collect(),
models: self.models.into_iter().chain(other.models).collect(),
};
// Validate config
out.validate().map_err(ConfigError::ValidateError)?;
Ok(out)
} }
Ok(())
} }
fn validate_network_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> { fn validate_hash_map<T: Validate>(value: &HashMap<String, T>) -> Result<(), ValidationErrors> {
if !context.config.networks.contains_key(value) {
return Err(ValidationError::new("noexist").with_message("network does not exist".into()));
}
Ok(())
}
fn validate_blueprint_exists(
value: &str,
context: &ValidateContext,
) -> Result<(), ValidationError> {
if !context.config.blueprints.contains_key(value) {
return Err(ValidationError::new("noexist").with_message("blueprint does not exist".into()));
}
Ok(())
}
fn validate_allocation_exists(
value: &str,
context: &ValidateContext,
) -> Result<(), ValidationError> {
if !context.config.allocations.contains_key(value) {
return Err(
ValidationError::new("noexist").with_message("allocation does not exist".into())
);
}
Ok(())
}
fn validate_model_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> {
if !context.config.networks.contains_key(value) {
return Err(ValidationError::new("noexist").with_message("model does not exist".into()));
}
Ok(())
}
fn validate_models_exist(
value: &WeightedList<String>,
context: &ValidateContext,
) -> Result<(), ValidationError> {
value.try_for_each(|x| validate_model_exists(x, context))
}
fn validate_machine_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> {
if !context.config.machines.contains_key(value) {
return Err(ValidationError::new("noexist").with_message("machine does not exist".into()));
}
Ok(())
}
fn validate_machines_exist(
value: &WeightedList<String>,
context: &ValidateContext,
) -> Result<(), ValidationError> {
value.try_for_each(|x| validate_machine_exists(x, context))
}
fn validate_template_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> {
if !context.config.templates.contains_key(value) {
return Err(ValidationError::new("noexist").with_message("template does not exist".into()));
}
Ok(())
}
fn validate_templates_exist(
value: &WeightedList<String>,
context: &ValidateContext,
) -> Result<(), ValidationError> {
value.try_for_each(|x| validate_template_exists(x, context))
}
fn validate_all_with_args<'v_a, T: ValidateArgs<'v_a, Args = &'v_a ValidateContext<'v_a>>>(
value: &HashMap<String, T>,
context: &'v_a ValidateContext,
) -> Result<(), ValidationErrors> {
let mut errors = ValidationErrors::new(); let mut errors = ValidationErrors::new();
for (n, x) in value.values().enumerate() { for (n, x) in value.values().enumerate() {
errors.merge_self( errors.merge_self(format!("[{n}]").to_static_str(), x.validate());
format!("[{n}]").to_static_str(),
x.validate_with_args(context),
);
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors); return Err(errors);
@ -692,6 +609,15 @@ fn validate_all_with_args<'v_a, T: ValidateArgs<'v_a, Args = &'v_a ValidateConte
Ok(()) Ok(())
} }
fn load_predefined_config() -> Result<::config::Config, ::config::ConfigError> {
::config::Config::builder()
.add_source(::config::File::from_str(
PREDEFINED_CONFIG,
::config::FileFormat::Yaml,
))
.build()
}
fn load_default_config() -> Result<::config::Config, ::config::ConfigError> { fn load_default_config() -> Result<::config::Config, ::config::ConfigError> {
::config::Config::builder() ::config::Config::builder()
.add_source(::config::File::from_str( .add_source(::config::File::from_str(
@ -705,12 +631,9 @@ fn load_default_config() -> Result<::config::Config, ::config::ConfigError> {
.build() .build()
} }
fn load_config<P: AsRef<Path>>( fn load_config_file<P: AsRef<Path>>(
opt_config_file: Option<P>, config_file: P,
) -> Result<::config::Config, ::config::ConfigError> { ) -> Result<::config::Config, ::config::ConfigError> {
let Some(config_file) = opt_config_file else {
return load_default_config();
};
let config_path = config_file.as_ref(); let config_path = config_file.as_ref();
let Some(config_file_str) = config_path.to_str() else { let Some(config_file_str) = config_path.to_str() else {
return Err(::config::ConfigError::Message( return Err(::config::ConfigError::Message(
@ -718,15 +641,10 @@ fn load_config<P: AsRef<Path>>(
)); ));
}; };
let config = ::config::Config::builder() let config = ::config::Config::builder()
.add_source(::config::File::from_str(
PREDEFINED_CONFIG,
::config::FileFormat::Yaml,
))
.add_source(::config::File::new( .add_source(::config::File::new(
config_file_str, config_file_str,
::config::FileFormat::Yaml, ::config::FileFormat::Yaml,
)) ))
.build()?; .build()?;
Ok(config) Ok(config)
} }

View File

@ -79,7 +79,6 @@ templates:
bootrelay: bootrelay:
network: "boot" network: "boot"
machine_count: 4 machine_count: 4
machines_per_network: 4
# Servers on subnets within the 'internet' network # Servers on subnets within the 'internet' network
relayserver: relayserver:
blueprint: "direct" blueprint: "direct"

View File

@ -34,6 +34,8 @@ pub enum MachineRegistryError {
TemplateNotFound, TemplateNotFound,
BlueprintNotFound, BlueprintNotFound,
ModelNotFound, ModelNotFound,
NoDefaultModel,
NoDefaultNetwork,
NoAllocation, NoAllocation,
ResourceInUse, ResourceInUse,
} }

View File

@ -189,7 +189,12 @@ impl BlueprintState {
) -> MachineRegistryResult<()> { ) -> MachineRegistryResult<()> {
let model_name = match self.fields.model.as_ref() { let model_name = match self.fields.model.as_ref() {
Some(models) => (**machine_registry_inner.srng().weighted_choice_ref(models)).clone(), Some(models) => (**machine_registry_inner.srng().weighted_choice_ref(models)).clone(),
None => machine_registry_inner.config().default_model.clone(), None => machine_registry_inner
.config()
.default_model
.as_ref()
.ok_or(MachineRegistryError::NoDefaultModel)?
.clone(),
}; };
let Some(model) = machine_registry_inner.config().models.get(&model_name) else { let Some(model) = machine_registry_inner.config().models.get(&model_name) else {
return Err(MachineRegistryError::ModelNotFound); return Err(MachineRegistryError::ModelNotFound);

View File

@ -30,19 +30,6 @@ impl<T: fmt::Debug + Clone> WeightedList<T> {
self.len() == 0 self.len() == 0
} }
pub fn validate_once(&self) -> Result<(), ValidationError> {
match self {
Self::List(v) => {
if v.is_empty() {
return Err(ValidationError::new("len")
.with_message("weighted list must not be empty".into()));
}
}
Self::Single(_addr) => {}
}
Ok(())
}
pub fn try_for_each<E, F: FnMut(&T) -> Result<(), E>>(&self, mut f: F) -> Result<(), E> { pub fn try_for_each<E, F: FnMut(&T) -> Result<(), E>>(&self, mut f: F) -> Result<(), E> {
match self { match self {
WeightedList::Single(v) => f(v), WeightedList::Single(v) => f(v),
@ -215,9 +202,21 @@ impl<'a, T: fmt::Debug + Clone> Iterator for WeightedListIter<'a, T> {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/// Validate /// Validate
impl<T: fmt::Debug + Clone> Validate for WeightedList<T> { impl<T: core::hash::Hash + Eq + fmt::Debug + Clone> Validate for WeightedList<T> {
fn validate(&self) -> Result<(), ValidationErrors> { fn validate(&self) -> Result<(), ValidationErrors> {
let mut errors = ValidationErrors::new(); let mut errors = ValidationErrors::new();
// Ensure weighted list does not have duplicates
let items = self.iter().collect::<HashSet<_>>();
if items.len() != self.len() {
errors.add(
"List",
ValidationError::new("weightdup")
.with_message("weighted list must not have duplicate items".into()),
);
}
// Make sure list is not empty
match self { match self {
Self::List(v) => { Self::List(v) => {
if v.is_empty() { if v.is_empty() {
@ -240,6 +239,15 @@ impl<T: fmt::Debug + Clone> Validate for WeightedList<T> {
} }
} }
// impl<T: core::hash::Hash + Eq + fmt::Debug + Clone> WeightedList<T> {
// pub fn validate_once(&self) -> Result<(), ValidationError> {
// self.validate().map_err(|errs| {
// ValidationError::new("multiple")
// .with_message(format!("multiple validation errors: {}", errs).into())
// })
// }
// }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/// Weighted /// Weighted