mirror of
https://gitlab.com/veilid/veilid.git
synced 2025-01-11 23:39:36 -05:00
[skip ci]
change config loading mechanism to allow loading multiple configs make it so predefined config can be loaded independently make it so default network/model are options and must be resolved by the time they are used
This commit is contained in:
parent
d93a5602d3
commit
9e7c41b635
@ -6,6 +6,7 @@ use parking_lot::*;
|
||||
use std::path::PathBuf;
|
||||
use stop_token::StopSource;
|
||||
use veilid_tools::*;
|
||||
use virtual_network::*;
|
||||
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
@ -40,9 +41,12 @@ struct CmdlineArgs {
|
||||
/// Turn off WS listener
|
||||
#[arg(long)]
|
||||
no_ws: bool,
|
||||
/// Specify a configuration file to use
|
||||
/// Specify an initial list of configuration files to use
|
||||
#[arg(short = 'c', long, value_name = "FILE")]
|
||||
config_file: Option<PathBuf>,
|
||||
config_file: Vec<PathBuf>,
|
||||
/// Specify to load configuration without a predefined config first
|
||||
#[arg(long)]
|
||||
no_predefined_config: bool,
|
||||
/// Instead of running the virtual router, print the configuration it would use to the console
|
||||
#[arg(long)]
|
||||
dump_config: bool,
|
||||
@ -64,17 +68,18 @@ fn main() -> Result<(), String> {
|
||||
|
||||
let args = CmdlineArgs::parse();
|
||||
|
||||
let config = config::Config::new(args.config_file)
|
||||
let initial_config = config::Config::new(&args.config_file, args.no_predefined_config)
|
||||
.map_err(|e| format!("Error loading config: {}", e))?;
|
||||
|
||||
if args.dump_config {
|
||||
let cfg_yaml = serde_yaml::to_string(&config)
|
||||
let cfg_yaml = serde_yaml::to_string(&initial_config)
|
||||
.map_err(|e| format!("Error serializing config: {}", e))?;
|
||||
println!("{}", cfg_yaml);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let router_server = virtual_network::RouterServer::new(config);
|
||||
let router_server = virtual_network::RouterServer::new(initial_config);
|
||||
|
||||
let _ss_tcp = if !args.no_tcp {
|
||||
Some(
|
||||
router_server
|
||||
|
@ -253,8 +253,6 @@ pub use timeout_or::*;
|
||||
pub use timestamp::*;
|
||||
#[doc(inline)]
|
||||
pub use tools::*;
|
||||
#[cfg(feature = "virtual-network")]
|
||||
pub use virtual_network::*;
|
||||
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
|
||||
pub use wasm::*;
|
||||
|
||||
|
@ -2,7 +2,7 @@ use super::*;
|
||||
use serde::*;
|
||||
use std::path::Path;
|
||||
|
||||
use validator::{Validate, ValidateArgs, ValidationError, ValidationErrors};
|
||||
use validator::{Validate, ValidationError, ValidationErrors};
|
||||
|
||||
const PREDEFINED_CONFIG: &str = include_str!("predefined_config.yml");
|
||||
const DEFAULT_CONFIG: &str = include_str!("default_config.yml");
|
||||
@ -13,15 +13,13 @@ pub enum ConfigError {
|
||||
ParseError(::config::ConfigError),
|
||||
#[error("validate error")]
|
||||
ValidateError(validator::ValidationErrors),
|
||||
#[error("no configuration files specified")]
|
||||
NoConfigFiles,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(context = "ValidateContext<'v_a>")]
|
||||
pub struct Profile {
|
||||
#[validate(
|
||||
length(min = 1),
|
||||
custom(function = "validate_instances_exist", use_context)
|
||||
)]
|
||||
#[validate(length(min = 1), nested)]
|
||||
pub instances: Vec<Instance>,
|
||||
}
|
||||
|
||||
@ -32,14 +30,20 @@ pub enum Instance {
|
||||
Template { template: WeightedList<String> },
|
||||
}
|
||||
|
||||
impl Validate for Instance {
|
||||
fn validate(&self) -> Result<(), ValidationErrors> {
|
||||
match self {
|
||||
Instance::Machine { machine } => machine.validate()?,
|
||||
Instance::Template { template } => template.validate()?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(
|
||||
context = "ValidateContext<'v_a>",
|
||||
schema(function = "validate_machine", use_context)
|
||||
)]
|
||||
pub struct Machine {
|
||||
#[serde(flatten)]
|
||||
#[validate(custom(function = "validate_machine_location", use_context))]
|
||||
#[validate(nested)]
|
||||
pub location: MachineLocation,
|
||||
#[serde(default)]
|
||||
pub disable_capabilities: Vec<String>,
|
||||
@ -47,14 +51,6 @@ pub struct Machine {
|
||||
pub bootstrap: bool,
|
||||
}
|
||||
|
||||
fn validate_machine(machine: &Machine, _context: &ValidateContext) -> Result<(), ValidationError> {
|
||||
if machine.disable_capabilities.contains(&("".to_string())) {
|
||||
return Err(ValidationError::new("badcap")
|
||||
.with_message("machine has empty disabled capability".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum MachineLocation {
|
||||
@ -70,53 +66,50 @@ pub enum MachineLocation {
|
||||
},
|
||||
}
|
||||
|
||||
fn validate_machine_location(
|
||||
value: &MachineLocation,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
match value {
|
||||
MachineLocation::Network {
|
||||
network,
|
||||
address4,
|
||||
address6,
|
||||
} => {
|
||||
if address4.is_none() && address6.is_none() {
|
||||
return Err(ValidationError::new("badaddr")
|
||||
.with_message("machine must have at least one address".into()));
|
||||
impl Validate for MachineLocation {
|
||||
fn validate(&self) -> Result<(), ValidationErrors> {
|
||||
let mut errors = ValidationErrors::new();
|
||||
match self {
|
||||
MachineLocation::Network {
|
||||
network: _,
|
||||
address4,
|
||||
address6,
|
||||
} => {
|
||||
if address4.is_none() && address6.is_none() {
|
||||
errors.add(
|
||||
"MachineLocation",
|
||||
ValidationError::new("badaddr")
|
||||
.with_message("machine must have at least one address".into()),
|
||||
);
|
||||
}
|
||||
}
|
||||
validate_network_exists(network, context)?;
|
||||
MachineLocation::Blueprint { blueprint: _ } => {}
|
||||
}
|
||||
MachineLocation::Blueprint { blueprint } => {
|
||||
validate_blueprint_exists(blueprint, context)?;
|
||||
|
||||
if !errors.is_empty() {
|
||||
Err(errors)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(
|
||||
context = "ValidateContext<'v_a>",
|
||||
schema(function = "validate_template", use_context)
|
||||
)]
|
||||
pub struct Template {
|
||||
#[serde(flatten)]
|
||||
#[validate(custom(function = "validate_template_location", use_context))]
|
||||
#[validate(nested)]
|
||||
pub location: TemplateLocation,
|
||||
#[serde(flatten)]
|
||||
#[validate(nested)]
|
||||
pub limits: TemplateLimits,
|
||||
#[serde(default)]
|
||||
#[validate(custom(function = "validate_disable_capabilities"))]
|
||||
pub disable_capabilities: Vec<String>,
|
||||
}
|
||||
|
||||
fn validate_template(
|
||||
template: &Template,
|
||||
_context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
if template.disable_capabilities.contains(&("".to_string())) {
|
||||
return Err(ValidationError::new("badcap")
|
||||
.with_message("template has empty disabled capability".into()));
|
||||
fn validate_disable_capabilities(disable_capabilities: &[String]) -> Result<(), ValidationError> {
|
||||
if disable_capabilities.contains(&("".to_string())) {
|
||||
return Err(ValidationError::new("badcap").with_message("empty disabled capability".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -129,10 +122,12 @@ pub struct TemplateLimits {
|
||||
#[serde(default)]
|
||||
pub machine_count: Option<WeightedList<u32>>,
|
||||
#[validate(nested)]
|
||||
#[serde(default)]
|
||||
pub machines_per_network: Option<WeightedList<u32>>,
|
||||
}
|
||||
|
||||
fn validate_template_limits(limits: &TemplateLimits) -> Result<(), ValidationError> {
|
||||
let mut has_at_least_one_limit = false;
|
||||
if let Some(machine_count) = &limits.machine_count {
|
||||
machine_count.try_for_each(|x| {
|
||||
if *x == 0 {
|
||||
@ -141,6 +136,7 @@ fn validate_template_limits(limits: &TemplateLimits) -> Result<(), ValidationErr
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
has_at_least_one_limit = true;
|
||||
}
|
||||
if let Some(machines_per_network) = &limits.machines_per_network {
|
||||
machines_per_network.try_for_each(|x| {
|
||||
@ -150,6 +146,12 @@ fn validate_template_limits(limits: &TemplateLimits) -> Result<(), ValidationErr
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
has_at_least_one_limit = true;
|
||||
}
|
||||
|
||||
if !has_at_least_one_limit {
|
||||
return Err(ValidationError::new("nolimit")
|
||||
.with_message("template can not be unlimited per network".into()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -162,42 +164,32 @@ pub enum TemplateLocation {
|
||||
Blueprint { blueprint: WeightedList<String> },
|
||||
}
|
||||
|
||||
fn validate_template_location(
|
||||
value: &TemplateLocation,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
match value {
|
||||
TemplateLocation::Network { network } => {
|
||||
network.try_for_each(|m| validate_network_exists(m, context))?;
|
||||
}
|
||||
TemplateLocation::Blueprint { blueprint } => {
|
||||
blueprint.try_for_each(|t| validate_blueprint_exists(t, context))?;
|
||||
impl Validate for TemplateLocation {
|
||||
fn validate(&self) -> Result<(), ValidationErrors> {
|
||||
match self {
|
||||
TemplateLocation::Network { network } => network.validate()?,
|
||||
TemplateLocation::Blueprint { blueprint } => blueprint.validate()?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(
|
||||
context = "ValidateContext<'v_a>",
|
||||
schema(function = "validate_network", use_context)
|
||||
)]
|
||||
#[validate(schema(function = "validate_network"))]
|
||||
pub struct Network {
|
||||
#[serde(default)]
|
||||
#[validate(custom(function = "validate_model_exists", use_context))]
|
||||
pub model: Option<String>,
|
||||
#[serde(default)]
|
||||
#[validate(custom(function = "validate_network_ipv4", use_context))]
|
||||
#[validate(nested)]
|
||||
pub ipv4: Option<NetworkIpv4>,
|
||||
#[serde(default)]
|
||||
#[validate(custom(function = "validate_network_ipv6", use_context))]
|
||||
#[validate(nested)]
|
||||
pub ipv6: Option<NetworkIpv6>,
|
||||
}
|
||||
|
||||
fn validate_network(network: &Network, _context: &ValidateContext) -> Result<(), ValidationError> {
|
||||
fn validate_network(network: &Network) -> Result<(), ValidationError> {
|
||||
if network.ipv4.is_none() && network.ipv6.is_none() {
|
||||
return Err(ValidationError::new("badaddr")
|
||||
.with_message("network must support at least one address type".into()));
|
||||
@ -205,83 +197,51 @@ fn validate_network(network: &Network, _context: &ValidateContext) -> Result<(),
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
pub struct NetworkIpv4 {
|
||||
#[validate(length(min = 1))]
|
||||
pub allocation: String,
|
||||
#[serde(default)]
|
||||
#[validate(nested)]
|
||||
pub gateway: Option<NetworkGateway>,
|
||||
}
|
||||
|
||||
fn validate_network_ipv4(
|
||||
network_ipv4: &NetworkIpv4,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
validate_allocation_exists(&network_ipv4.allocation, context)?;
|
||||
if let Some(gateway) = &network_ipv4.gateway {
|
||||
validate_network_gateway(gateway, context)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
pub struct NetworkIpv6 {
|
||||
#[validate(length(min = 1))]
|
||||
pub allocation: String,
|
||||
#[serde(default)]
|
||||
#[validate(nested)]
|
||||
pub gateway: Option<NetworkGateway>,
|
||||
}
|
||||
|
||||
fn validate_network_ipv6(
|
||||
network_ipv6: &NetworkIpv6,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
validate_allocation_exists(&network_ipv6.allocation, context)?;
|
||||
if let Some(gateway) = &network_ipv6.gateway {
|
||||
validate_network_gateway(gateway, context)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
pub struct NetworkGateway {
|
||||
pub translation: Translation,
|
||||
pub upnp: bool,
|
||||
#[validate(length(min = 1))]
|
||||
pub network: Option<String>,
|
||||
}
|
||||
|
||||
fn validate_network_gateway(
|
||||
gateway: &NetworkGateway,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
if let Some(network) = &gateway.network {
|
||||
validate_network_exists(network, context)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
////////////////////////////////////////////////////////////////
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(
|
||||
context = "ValidateContext<'v_a>",
|
||||
schema(function = "validate_blueprint", use_context)
|
||||
)]
|
||||
#[validate(schema(function = "validate_blueprint"))]
|
||||
pub struct Blueprint {
|
||||
#[serde(default)]
|
||||
#[validate(custom(function = "validate_models_exist", use_context))]
|
||||
#[validate(nested)]
|
||||
pub model: Option<WeightedList<String>>,
|
||||
#[validate(nested)]
|
||||
pub limits: BlueprintLimits,
|
||||
#[serde(default)]
|
||||
#[validate(custom(function = "validate_blueprint_ipv4", use_context))]
|
||||
#[validate(nested)]
|
||||
pub ipv4: Option<BlueprintIpv4>,
|
||||
#[serde(default)]
|
||||
#[validate(custom(function = "validate_blueprint_ipv6", use_context))]
|
||||
#[validate(nested)]
|
||||
pub ipv6: Option<BlueprintIpv6>,
|
||||
}
|
||||
|
||||
fn validate_blueprint(
|
||||
blueprint: &Blueprint,
|
||||
_context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
fn validate_blueprint(blueprint: &Blueprint) -> Result<(), ValidationError> {
|
||||
if blueprint.ipv4.is_none() && blueprint.ipv6.is_none() {
|
||||
return Err(ValidationError::new("badaddr")
|
||||
.with_message("blueprint must support at least one address type".into()));
|
||||
@ -324,39 +284,35 @@ pub enum BlueprintLocation {
|
||||
},
|
||||
}
|
||||
|
||||
fn validate_blueprint_location(
|
||||
value: &BlueprintLocation,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
match value {
|
||||
BlueprintLocation::Allocation { allocation } => {
|
||||
allocation.try_for_each(|a| validate_allocation_exists(a, context))?;
|
||||
}
|
||||
BlueprintLocation::Network { network } => {
|
||||
if let Some(network) = network {
|
||||
network.try_for_each(|n| validate_network_exists(n, context))?;
|
||||
impl Validate for BlueprintLocation {
|
||||
fn validate(&self) -> Result<(), ValidationErrors> {
|
||||
match self {
|
||||
BlueprintLocation::Allocation { allocation } => allocation.validate()?,
|
||||
BlueprintLocation::Network { network } => {
|
||||
if let Some(network) = network {
|
||||
network.validate()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(schema(function = "validate_blueprint_ipv4"))]
|
||||
pub struct BlueprintIpv4 {
|
||||
#[serde(flatten)]
|
||||
#[validate(nested)]
|
||||
pub location: BlueprintLocation,
|
||||
#[validate(nested)]
|
||||
pub prefix: WeightedList<u8>,
|
||||
#[serde(default)]
|
||||
#[validate(nested)]
|
||||
pub gateway: Option<BlueprintGateway>,
|
||||
}
|
||||
|
||||
fn validate_blueprint_ipv4(
|
||||
blueprint_ipv4: &BlueprintIpv4,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
validate_blueprint_location(&blueprint_ipv4.location, context)?;
|
||||
blueprint_ipv4.prefix.validate_once()?;
|
||||
fn validate_blueprint_ipv4(blueprint_ipv4: &BlueprintIpv4) -> Result<(), ValidationError> {
|
||||
blueprint_ipv4.prefix.try_for_each(|x| {
|
||||
if *x > 32 {
|
||||
return Err(ValidationError::new("badprefix")
|
||||
@ -365,27 +321,23 @@ fn validate_blueprint_ipv4(
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
if let Some(gateway) = &blueprint_ipv4.gateway {
|
||||
validate_blueprint_gateway(gateway, context)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(schema(function = "validate_blueprint_ipv6"))]
|
||||
pub struct BlueprintIpv6 {
|
||||
#[serde(flatten)]
|
||||
#[validate(nested)]
|
||||
pub allocation: BlueprintLocation,
|
||||
#[validate(nested)]
|
||||
pub prefix: WeightedList<u8>,
|
||||
#[serde(default)]
|
||||
#[validate(nested)]
|
||||
pub gateway: Option<BlueprintGateway>,
|
||||
}
|
||||
|
||||
fn validate_blueprint_ipv6(
|
||||
blueprint_ipv6: &BlueprintIpv6,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
validate_blueprint_location(&blueprint_ipv6.allocation, context)?;
|
||||
blueprint_ipv6.prefix.validate_once()?;
|
||||
fn validate_blueprint_ipv6(blueprint_ipv6: &BlueprintIpv6) -> Result<(), ValidationError> {
|
||||
blueprint_ipv6.prefix.try_for_each(|x| {
|
||||
if *x > 128 {
|
||||
return Err(ValidationError::new("badprefix")
|
||||
@ -393,37 +345,29 @@ fn validate_blueprint_ipv6(
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
if let Some(gateway) = &blueprint_ipv6.gateway {
|
||||
validate_blueprint_gateway(gateway, context)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
pub struct BlueprintGateway {
|
||||
#[validate(nested)]
|
||||
pub translation: WeightedList<Translation>,
|
||||
#[validate(range(min = 0.0, max = 1.0))]
|
||||
pub upnp: Probability,
|
||||
#[serde(flatten)]
|
||||
pub location: TemplateLocation,
|
||||
}
|
||||
|
||||
fn validate_blueprint_gateway(
|
||||
gateway: &BlueprintGateway,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
gateway.translation.validate_once()?;
|
||||
validate_template_location(&gateway.location, context)?;
|
||||
Ok(())
|
||||
}
|
||||
////////////////////////////////////////////////////////////////
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(schema(function = "validate_subnets"))]
|
||||
pub struct Subnets {
|
||||
#[serde(default)]
|
||||
#[validate(nested)]
|
||||
pub subnet4: Option<WeightedList<Ipv4Net>>,
|
||||
#[serde(default)]
|
||||
#[validate(nested)]
|
||||
pub subnet6: Option<WeightedList<Ipv6Net>>,
|
||||
}
|
||||
|
||||
@ -480,7 +424,7 @@ fn validate_distribution(distribution: &Distribution) -> Result<(), ValidationEr
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Translation {
|
||||
None,
|
||||
@ -496,7 +440,6 @@ impl Default for Translation {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(context = "ValidateContext<'v_a>")]
|
||||
pub struct Model {
|
||||
#[validate(nested)]
|
||||
pub latency: Distribution,
|
||||
@ -509,33 +452,21 @@ pub struct Model {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(context = "ValidateContext<'v_a>")]
|
||||
pub struct Allocation {
|
||||
#[serde(flatten)]
|
||||
#[validate(nested)]
|
||||
pub subnets: Subnets,
|
||||
}
|
||||
|
||||
pub struct ValidateContext<'a> {
|
||||
config: &'a Config,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[validate(context = "ValidateContext<'v_a>")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Config {
|
||||
pub seed: Option<u64>,
|
||||
#[validate(
|
||||
length(min = 1),
|
||||
custom(function = "validate_network_exists", use_context)
|
||||
)]
|
||||
pub default_network: String,
|
||||
#[validate(
|
||||
length(min = 1),
|
||||
custom(function = "validate_model_exists", use_context)
|
||||
)]
|
||||
pub default_model: String,
|
||||
#[serde(default)]
|
||||
#[validate(length(min = 1))]
|
||||
pub seed: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub default_network: Option<String>,
|
||||
#[serde(default)]
|
||||
pub default_model: Option<String>,
|
||||
#[serde(default)]
|
||||
pub profiles: HashMap<String, Profile>,
|
||||
#[serde(default)]
|
||||
pub machines: HashMap<String, Machine>,
|
||||
@ -551,140 +482,126 @@ pub struct Config {
|
||||
pub models: HashMap<String, Model>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new<P: AsRef<Path>>(config_file: Option<P>) -> Result<Self, ConfigError> {
|
||||
let cfg = load_config(config_file).map_err(ConfigError::ParseError)?;
|
||||
|
||||
// Generate config
|
||||
let out: Self = cfg.try_deserialize().map_err(ConfigError::ParseError)?;
|
||||
|
||||
impl Validate for Config {
|
||||
fn validate(&self) -> Result<(), ValidationErrors> {
|
||||
// Validate config
|
||||
let context = ValidateContext { config: &out };
|
||||
let mut errors = ValidationErrors::new();
|
||||
|
||||
if let Err(e) = out.validate_with_args(&context) {
|
||||
errors = e;
|
||||
if let Some(default_network) = self.default_network.as_ref() {
|
||||
if default_network.is_empty() {
|
||||
errors.add(
|
||||
"default_network",
|
||||
ValidationError::new("badlen").with_message(
|
||||
"Config must have non-empty default network if specified".into(),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
errors.merge_self("profiles", validate_all_with_args(&out.profiles, &context));
|
||||
errors.merge_self("machines", validate_all_with_args(&out.machines, &context));
|
||||
errors.merge_self(
|
||||
"templates",
|
||||
validate_all_with_args(&out.templates, &context),
|
||||
);
|
||||
errors.merge_self("networks", validate_all_with_args(&out.networks, &context));
|
||||
errors.merge_self(
|
||||
"blueprints",
|
||||
validate_all_with_args(&out.blueprints, &context),
|
||||
);
|
||||
errors.merge_self(
|
||||
"allocation",
|
||||
validate_all_with_args(&out.allocations, &context),
|
||||
);
|
||||
errors.merge_self("models", validate_all_with_args(&out.models, &context));
|
||||
if let Some(default_model) = self.default_model.as_ref() {
|
||||
if default_model.is_empty() {
|
||||
errors.add(
|
||||
"default_model",
|
||||
ValidationError::new("badlen").with_message(
|
||||
"Config must have non-empty default model if specified".into(),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
errors.merge_self("profiles", validate_hash_map(&self.profiles));
|
||||
errors.merge_self("machines", validate_hash_map(&self.machines));
|
||||
errors.merge_self("templates", validate_hash_map(&self.templates));
|
||||
errors.merge_self("networks", validate_hash_map(&self.networks));
|
||||
errors.merge_self("blueprints", validate_hash_map(&self.blueprints));
|
||||
errors.merge_self("allocation", validate_hash_map(&self.allocations));
|
||||
errors.merge_self("models", validate_hash_map(&self.models));
|
||||
|
||||
if !errors.is_empty() {
|
||||
return Err(ConfigError::ValidateError(errors));
|
||||
return Err(errors);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new<P: AsRef<Path>>(
|
||||
config_files: &[P],
|
||||
no_predefined_config: bool,
|
||||
) -> Result<Self, ConfigError> {
|
||||
let mut out = Self::default();
|
||||
|
||||
if !no_predefined_config {
|
||||
out = load_predefined_config()
|
||||
.map_err(ConfigError::ParseError)?
|
||||
.try_deserialize()
|
||||
.map_err(ConfigError::ParseError)?;
|
||||
out.validate().map_err(ConfigError::ValidateError)?;
|
||||
|
||||
// Load default config file
|
||||
if config_files.is_empty() {
|
||||
let cfg: Self = load_default_config()
|
||||
.map_err(ConfigError::ParseError)?
|
||||
.try_deserialize()
|
||||
.map_err(ConfigError::ParseError)?;
|
||||
cfg.validate().map_err(ConfigError::ValidateError)?;
|
||||
|
||||
out = out.combine(cfg)?;
|
||||
}
|
||||
} else {
|
||||
// There must be config files specified to use this option
|
||||
if config_files.is_empty() {
|
||||
return Err(ConfigError::NoConfigFiles);
|
||||
}
|
||||
}
|
||||
|
||||
// Load specified config files
|
||||
for config_file in config_files {
|
||||
let cfg: Self = load_config_file(config_file)
|
||||
.map_err(ConfigError::ParseError)?
|
||||
.try_deserialize()
|
||||
.map_err(ConfigError::ParseError)?;
|
||||
cfg.validate().map_err(ConfigError::ValidateError)?;
|
||||
|
||||
out = out.combine(cfg)?;
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_instances_exist(
|
||||
value: &Vec<Instance>,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
for v in value {
|
||||
match v {
|
||||
Instance::Machine { machine } => validate_machines_exist(machine, context)?,
|
||||
Instance::Template { template } => validate_templates_exist(template, context)?,
|
||||
}
|
||||
pub fn combine(self, other: Self) -> Result<Self, ConfigError> {
|
||||
let out = Config {
|
||||
seed: other.seed.or(self.seed),
|
||||
default_network: other.default_network.or(self.default_network),
|
||||
default_model: other.default_model.or(self.default_model),
|
||||
profiles: self.profiles.into_iter().chain(other.profiles).collect(),
|
||||
machines: self.machines.into_iter().chain(other.machines).collect(),
|
||||
templates: self.templates.into_iter().chain(other.templates).collect(),
|
||||
networks: self.networks.into_iter().chain(other.networks).collect(),
|
||||
blueprints: self
|
||||
.blueprints
|
||||
.into_iter()
|
||||
.chain(other.blueprints)
|
||||
.collect(),
|
||||
allocations: self
|
||||
.allocations
|
||||
.into_iter()
|
||||
.chain(other.allocations)
|
||||
.collect(),
|
||||
models: self.models.into_iter().chain(other.models).collect(),
|
||||
};
|
||||
|
||||
// Validate config
|
||||
out.validate().map_err(ConfigError::ValidateError)?;
|
||||
Ok(out)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_network_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> {
|
||||
if !context.config.networks.contains_key(value) {
|
||||
return Err(ValidationError::new("noexist").with_message("network does not exist".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_blueprint_exists(
|
||||
value: &str,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
if !context.config.blueprints.contains_key(value) {
|
||||
return Err(ValidationError::new("noexist").with_message("blueprint does not exist".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_allocation_exists(
|
||||
value: &str,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
if !context.config.allocations.contains_key(value) {
|
||||
return Err(
|
||||
ValidationError::new("noexist").with_message("allocation does not exist".into())
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_model_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> {
|
||||
if !context.config.networks.contains_key(value) {
|
||||
return Err(ValidationError::new("noexist").with_message("model does not exist".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_models_exist(
|
||||
value: &WeightedList<String>,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
value.try_for_each(|x| validate_model_exists(x, context))
|
||||
}
|
||||
|
||||
fn validate_machine_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> {
|
||||
if !context.config.machines.contains_key(value) {
|
||||
return Err(ValidationError::new("noexist").with_message("machine does not exist".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_machines_exist(
|
||||
value: &WeightedList<String>,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
value.try_for_each(|x| validate_machine_exists(x, context))
|
||||
}
|
||||
|
||||
fn validate_template_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> {
|
||||
if !context.config.templates.contains_key(value) {
|
||||
return Err(ValidationError::new("noexist").with_message("template does not exist".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_templates_exist(
|
||||
value: &WeightedList<String>,
|
||||
context: &ValidateContext,
|
||||
) -> Result<(), ValidationError> {
|
||||
value.try_for_each(|x| validate_template_exists(x, context))
|
||||
}
|
||||
|
||||
fn validate_all_with_args<'v_a, T: ValidateArgs<'v_a, Args = &'v_a ValidateContext<'v_a>>>(
|
||||
value: &HashMap<String, T>,
|
||||
context: &'v_a ValidateContext,
|
||||
) -> Result<(), ValidationErrors> {
|
||||
fn validate_hash_map<T: Validate>(value: &HashMap<String, T>) -> Result<(), ValidationErrors> {
|
||||
let mut errors = ValidationErrors::new();
|
||||
for (n, x) in value.values().enumerate() {
|
||||
errors.merge_self(
|
||||
format!("[{n}]").to_static_str(),
|
||||
x.validate_with_args(context),
|
||||
);
|
||||
errors.merge_self(format!("[{n}]").to_static_str(), x.validate());
|
||||
}
|
||||
if !errors.is_empty() {
|
||||
return Err(errors);
|
||||
@ -692,6 +609,15 @@ fn validate_all_with_args<'v_a, T: ValidateArgs<'v_a, Args = &'v_a ValidateConte
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_predefined_config() -> Result<::config::Config, ::config::ConfigError> {
|
||||
::config::Config::builder()
|
||||
.add_source(::config::File::from_str(
|
||||
PREDEFINED_CONFIG,
|
||||
::config::FileFormat::Yaml,
|
||||
))
|
||||
.build()
|
||||
}
|
||||
|
||||
fn load_default_config() -> Result<::config::Config, ::config::ConfigError> {
|
||||
::config::Config::builder()
|
||||
.add_source(::config::File::from_str(
|
||||
@ -705,12 +631,9 @@ fn load_default_config() -> Result<::config::Config, ::config::ConfigError> {
|
||||
.build()
|
||||
}
|
||||
|
||||
fn load_config<P: AsRef<Path>>(
|
||||
opt_config_file: Option<P>,
|
||||
fn load_config_file<P: AsRef<Path>>(
|
||||
config_file: P,
|
||||
) -> Result<::config::Config, ::config::ConfigError> {
|
||||
let Some(config_file) = opt_config_file else {
|
||||
return load_default_config();
|
||||
};
|
||||
let config_path = config_file.as_ref();
|
||||
let Some(config_file_str) = config_path.to_str() else {
|
||||
return Err(::config::ConfigError::Message(
|
||||
@ -718,15 +641,10 @@ fn load_config<P: AsRef<Path>>(
|
||||
));
|
||||
};
|
||||
let config = ::config::Config::builder()
|
||||
.add_source(::config::File::from_str(
|
||||
PREDEFINED_CONFIG,
|
||||
::config::FileFormat::Yaml,
|
||||
))
|
||||
.add_source(::config::File::new(
|
||||
config_file_str,
|
||||
::config::FileFormat::Yaml,
|
||||
))
|
||||
.build()?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
@ -79,7 +79,6 @@ templates:
|
||||
bootrelay:
|
||||
network: "boot"
|
||||
machine_count: 4
|
||||
machines_per_network: 4
|
||||
# Servers on subnets within the 'internet' network
|
||||
relayserver:
|
||||
blueprint: "direct"
|
||||
|
@ -34,6 +34,8 @@ pub enum MachineRegistryError {
|
||||
TemplateNotFound,
|
||||
BlueprintNotFound,
|
||||
ModelNotFound,
|
||||
NoDefaultModel,
|
||||
NoDefaultNetwork,
|
||||
NoAllocation,
|
||||
ResourceInUse,
|
||||
}
|
||||
|
@ -189,7 +189,12 @@ impl BlueprintState {
|
||||
) -> MachineRegistryResult<()> {
|
||||
let model_name = match self.fields.model.as_ref() {
|
||||
Some(models) => (**machine_registry_inner.srng().weighted_choice_ref(models)).clone(),
|
||||
None => machine_registry_inner.config().default_model.clone(),
|
||||
None => machine_registry_inner
|
||||
.config()
|
||||
.default_model
|
||||
.as_ref()
|
||||
.ok_or(MachineRegistryError::NoDefaultModel)?
|
||||
.clone(),
|
||||
};
|
||||
let Some(model) = machine_registry_inner.config().models.get(&model_name) else {
|
||||
return Err(MachineRegistryError::ModelNotFound);
|
||||
|
@ -30,19 +30,6 @@ impl<T: fmt::Debug + Clone> WeightedList<T> {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
pub fn validate_once(&self) -> Result<(), ValidationError> {
|
||||
match self {
|
||||
Self::List(v) => {
|
||||
if v.is_empty() {
|
||||
return Err(ValidationError::new("len")
|
||||
.with_message("weighted list must not be empty".into()));
|
||||
}
|
||||
}
|
||||
Self::Single(_addr) => {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn try_for_each<E, F: FnMut(&T) -> Result<(), E>>(&self, mut f: F) -> Result<(), E> {
|
||||
match self {
|
||||
WeightedList::Single(v) => f(v),
|
||||
@ -215,9 +202,21 @@ impl<'a, T: fmt::Debug + Clone> Iterator for WeightedListIter<'a, T> {
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
/// Validate
|
||||
|
||||
impl<T: fmt::Debug + Clone> Validate for WeightedList<T> {
|
||||
impl<T: core::hash::Hash + Eq + fmt::Debug + Clone> Validate for WeightedList<T> {
|
||||
fn validate(&self) -> Result<(), ValidationErrors> {
|
||||
let mut errors = ValidationErrors::new();
|
||||
|
||||
// Ensure weighted list does not have duplicates
|
||||
let items = self.iter().collect::<HashSet<_>>();
|
||||
if items.len() != self.len() {
|
||||
errors.add(
|
||||
"List",
|
||||
ValidationError::new("weightdup")
|
||||
.with_message("weighted list must not have duplicate items".into()),
|
||||
);
|
||||
}
|
||||
|
||||
// Make sure list is not empty
|
||||
match self {
|
||||
Self::List(v) => {
|
||||
if v.is_empty() {
|
||||
@ -240,6 +239,15 @@ impl<T: fmt::Debug + Clone> Validate for WeightedList<T> {
|
||||
}
|
||||
}
|
||||
|
||||
// impl<T: core::hash::Hash + Eq + fmt::Debug + Clone> WeightedList<T> {
|
||||
// pub fn validate_once(&self) -> Result<(), ValidationError> {
|
||||
// self.validate().map_err(|errs| {
|
||||
// ValidationError::new("multiple")
|
||||
// .with_message(format!("multiple validation errors: {}", errs).into())
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
/// Weighted
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user