validation WIP

This commit is contained in:
Christien Rioux 2024-11-18 23:23:43 -05:00
parent 5c0e973e4a
commit 981674b870
9 changed files with 670 additions and 222 deletions

54
Cargo.lock generated
View File

@ -1491,6 +1491,7 @@ dependencies = [
"ident_case", "ident_case",
"proc-macro2", "proc-macro2",
"quote", "quote",
"strsim 0.11.1",
"syn 2.0.87", "syn 2.0.87",
] ]
@ -4362,6 +4363,28 @@ dependencies = [
"toml_edit 0.19.15", "toml_edit 0.19.15",
] ]
[[package]]
name = "proc-macro-error-attr2"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5"
dependencies = [
"proc-macro2",
"quote",
]
[[package]]
name = "proc-macro-error2"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802"
dependencies = [
"proc-macro-error-attr2",
"proc-macro2",
"quote",
"syn 2.0.87",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.89" version = "1.0.89"
@ -6267,6 +6290,36 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "validator"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0b4a29d8709210980a09379f27ee31549b73292c87ab9899beee1c0d3be6303"
dependencies = [
"idna 1.0.3",
"once_cell",
"regex",
"serde",
"serde_derive",
"serde_json",
"url",
"validator_derive",
]
[[package]]
name = "validator_derive"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bac855a2ce6f843beb229757e6e570a42e837bcb15e5f449dd48d5747d41bf77"
dependencies = [
"darling 0.20.10",
"once_cell",
"proc-macro-error2",
"proc-macro2",
"quote",
"syn 2.0.87",
]
[[package]] [[package]]
name = "valuable" name = "valuable"
version = "0.1.0" version = "0.1.0"
@ -6619,6 +6672,7 @@ dependencies = [
"tracing-oslog", "tracing-oslog",
"tracing-subscriber", "tracing-subscriber",
"tracing-wasm", "tracing-wasm",
"validator",
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"wasm-bindgen-test", "wasm-bindgen-test",

View File

@ -50,6 +50,7 @@ virtual-network-server = [
"dep:config", "dep:config",
"dep:ipnet", "dep:ipnet",
"dep:serde_yaml", "dep:serde_yaml",
"dep:validator",
"dep:ws_stream_tungstenite", "dep:ws_stream_tungstenite",
] ]
@ -111,6 +112,7 @@ config = { version = "^0", default-features = false, features = [
], optional = true } ], optional = true }
ipnet = { version = "2", features = ["serde"], optional = true } ipnet = { version = "2", features = ["serde"], optional = true }
serde_yaml = { package = "serde_yaml_ng", version = "^0.10.0", optional = true } serde_yaml = { package = "serde_yaml_ng", version = "^0.10.0", optional = true }
validator = { version = "0.19.0", features = ["derive"], optional = true }
# Dependencies for WASM builds only # Dependencies for WASM builds only
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies] [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]

View File

@ -202,17 +202,19 @@ pub enum ServerProcessorReplyValue {
} }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum ServerProcessorReplyResult { pub enum ServerProcessorReplyStatus {
Value(ServerProcessorReplyValue), Value(ServerProcessorReplyValue),
InvalidMachineId, InvalidMachineId,
InvalidSocketId, InvalidSocketId,
MissingProfile,
ProfileComplete,
IoError(#[serde(with = "serde_io_error::SerdeIoErrorKindDef")] io::ErrorKind), IoError(#[serde(with = "serde_io_error::SerdeIoErrorKindDef")] io::ErrorKind),
} }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ServerProcessorReply { pub struct ServerProcessorReply {
pub message_id: MessageId, pub message_id: MessageId,
pub status: ServerProcessorReplyResult, pub status: ServerProcessorReplyStatus,
} }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]

View File

@ -26,7 +26,7 @@ impl fmt::Debug for RouterClientInner {
struct RouterClientUnlockedInner { struct RouterClientUnlockedInner {
sender: flume::Sender<ServerProcessorCommand>, sender: flume::Sender<ServerProcessorCommand>,
next_message_id: AtomicU64, next_message_id: AtomicU64,
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>, router_op_waiter: RouterOpWaiter<ServerProcessorReplyStatus, ()>,
} }
impl fmt::Debug for RouterClientUnlockedInner { impl fmt::Debug for RouterClientUnlockedInner {
@ -480,7 +480,7 @@ impl RouterClient {
fn new( fn new(
sender: flume::Sender<ServerProcessorCommand>, sender: flume::Sender<ServerProcessorCommand>,
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>, router_op_waiter: RouterOpWaiter<ServerProcessorReplyStatus, ()>,
jh_handler: MustJoinHandle<()>, jh_handler: MustJoinHandle<()>,
stop_source: StopSource, stop_source: StopSource,
) -> RouterClient { ) -> RouterClient {
@ -557,16 +557,20 @@ impl RouterClient {
.map_err(|_| VirtualNetworkError::WaitError)?; .map_err(|_| VirtualNetworkError::WaitError)?;
match status { match status {
ServerProcessorReplyResult::Value(server_processor_response) => { ServerProcessorReplyStatus::Value(server_processor_response) => {
Ok(server_processor_response) Ok(server_processor_response)
} }
ServerProcessorReplyResult::InvalidMachineId => { ServerProcessorReplyStatus::InvalidMachineId => {
Err(VirtualNetworkError::InvalidMachineId) Err(VirtualNetworkError::InvalidMachineId)
} }
ServerProcessorReplyResult::InvalidSocketId => { ServerProcessorReplyStatus::InvalidSocketId => {
Err(VirtualNetworkError::InvalidSocketId) Err(VirtualNetworkError::InvalidSocketId)
} }
ServerProcessorReplyResult::IoError(k) => Err(VirtualNetworkError::IoError(k)), ServerProcessorReplyStatus::MissingProfile => Err(VirtualNetworkError::MissingProfile),
ServerProcessorReplyStatus::ProfileComplete => {
Err(VirtualNetworkError::ProfileComplete)
}
ServerProcessorReplyStatus::IoError(k) => Err(VirtualNetworkError::IoError(k)),
} }
} }
@ -574,7 +578,7 @@ impl RouterClient {
reader: R, reader: R,
writer: W, writer: W,
receiver: flume::Receiver<ServerProcessorCommand>, receiver: flume::Receiver<ServerProcessorCommand>,
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>, router_op_waiter: RouterOpWaiter<ServerProcessorReplyStatus, ()>,
stop_token: StopToken, stop_token: StopToken,
) where ) where
R: AsyncReadExt + Unpin + Send, R: AsyncReadExt + Unpin + Send,
@ -619,7 +623,7 @@ impl RouterClient {
async fn run_local_processor( async fn run_local_processor(
receiver: flume::Receiver<ServerProcessorEvent>, receiver: flume::Receiver<ServerProcessorEvent>,
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>, router_op_waiter: RouterOpWaiter<ServerProcessorReplyStatus, ()>,
stop_token: StopToken, stop_token: StopToken,
) { ) {
let mut unord = FuturesUnordered::new(); let mut unord = FuturesUnordered::new();
@ -640,7 +644,7 @@ impl RouterClient {
async fn process_event( async fn process_event(
evt: ServerProcessorEvent, evt: ServerProcessorEvent,
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>, router_op_waiter: RouterOpWaiter<ServerProcessorReplyStatus, ()>,
) -> io::Result<()> { ) -> io::Result<()> {
match evt { match evt {
ServerProcessorEvent::Reply(reply) => { ServerProcessorEvent::Reply(reply) => {

View File

@ -3,11 +3,17 @@ use ipnet::*;
use serde::*; use serde::*;
use std::path::Path; use std::path::Path;
pub use ::config::ConfigError; use validator::{Validate, ValidateArgs, ValidationError, ValidationErrors};
const PREDEFINED_CONFIG: &str = include_str!("predefined_config.yml"); const PREDEFINED_CONFIG: &str = include_str!("predefined_config.yml");
const DEFAULT_CONFIG: &str = include_str!("default_config.yml"); const DEFAULT_CONFIG: &str = include_str!("default_config.yml");
#[derive(Debug)]
pub enum ConfigError {
ParseError(::config::ConfigError),
ValidateError(validator::ValidationErrors),
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum WeightedList<T: fmt::Debug + Clone> { pub enum WeightedList<T: fmt::Debug + Clone> {
@ -19,6 +25,26 @@ impl<T: fmt::Debug + Clone> Default for WeightedList<T> {
Self::List(Vec::new()) Self::List(Vec::new())
} }
} }
impl<T: fmt::Debug + Clone> Validate for WeightedList<T> {
fn validate(&self) -> Result<(), ValidationErrors> {
let mut errors = ValidationErrors::new();
if let Self::List(v) = self {
if v.is_empty() {
errors.add(
"List",
ValidationError::new("len")
.with_message("weighted list must not be empty".into()),
)
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
pub type Probability = f32; pub type Probability = f32;
@ -29,9 +55,35 @@ pub enum Weighted<T: fmt::Debug + Clone> {
Unweighted(T), Unweighted(T),
} }
#[derive(Debug, Clone, Serialize, Deserialize)] impl<T: fmt::Debug + Clone> Validate for Weighted<T> {
fn validate(&self) -> Result<(), ValidationErrors> {
let mut errors = ValidationErrors::new();
if let Self::Weighted { item: _, weight } = self {
if *weight <= 0.0 {
errors.add(
"Weighted",
ValidationError::new("len")
.with_message("weight must be a positive value".into()),
)
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(context = "ValidateContext<'v_a>")]
pub struct Profile { pub struct Profile {
instances: Vec<Instance>, #[validate(
length(min = 1),
custom(function = "validate_instances_exist", use_context)
)]
pub instances: Vec<Instance>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -41,33 +93,37 @@ pub enum Instance {
Template { template: WeightedList<String> }, Template { template: WeightedList<String> },
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(context = "ValidateContext<'v_a>")]
pub struct Machine { pub struct Machine {
#[serde(flatten)] #[serde(flatten)]
location: Location, #[validate(custom(function = "validate_location_exists", use_context))]
pub location: Location,
#[serde(default)] #[serde(default)]
address4: Option<Ipv4Addr>, pub address4: Option<Ipv4Addr>,
#[serde(default)] #[serde(default)]
address6: Option<Ipv6Addr>, pub address6: Option<Ipv6Addr>,
#[serde(default)] #[serde(default)]
disable_capabilities: Vec<String>, pub disable_capabilities: Vec<String>,
#[serde(default)] #[serde(default)]
bootstrap: bool, pub bootstrap: bool,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(context = "ValidateContext<'v_a>")]
pub struct Template { pub struct Template {
#[serde(flatten)] #[serde(flatten)]
location: Location, #[validate(custom(function = "validate_location_exists", use_context))]
pub location: Location,
#[serde(flatten)] #[serde(flatten)]
limits: Limits, pub limits: Limits,
#[serde(default)] #[serde(default)]
disable_capabilities: Vec<String>, pub disable_capabilities: Vec<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct Limits { pub struct Limits {
machine_count: WeightedList<u32>, pub machine_count: WeightedList<u32>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -79,95 +135,95 @@ pub enum Location {
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct Network { pub struct Network {
#[serde(default)] #[serde(default)]
model: Option<String>, pub model: Option<String>,
#[serde(default)] #[serde(default)]
ipv4: Option<NetworkIpv4>, pub ipv4: Option<NetworkIpv4>,
#[serde(default)] #[serde(default)]
ipv6: Option<NetworkIpv6>, pub ipv6: Option<NetworkIpv6>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct NetworkIpv4 { pub struct NetworkIpv4 {
allocation: String, pub allocation: String,
#[serde(default)] #[serde(default)]
gateway: Option<NetworkGateway>, pub gateway: Option<NetworkGateway>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct NetworkIpv6 { pub struct NetworkIpv6 {
allocation: String, pub allocation: String,
#[serde(default)] #[serde(default)]
gateway: Option<NetworkGateway>, pub gateway: Option<NetworkGateway>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct NetworkGateway { pub struct NetworkGateway {
translation: Translation, pub translation: Translation,
upnp: bool, pub upnp: bool,
network: Option<String>, pub network: Option<String>,
} }
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct Blueprint { pub struct Blueprint {
#[serde(default)] #[serde(default)]
model: WeightedList<String>, pub model: WeightedList<String>,
#[serde(default)] #[serde(default)]
ipv4: Option<BlueprintIpv4>, pub ipv4: Option<BlueprintIpv4>,
#[serde(default)] #[serde(default)]
ipv6: Option<BlueprintIpv6>, pub ipv6: Option<BlueprintIpv6>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct BlueprintIpv4 { pub struct BlueprintIpv4 {
#[serde(default)] #[serde(default)]
allocation: Option<String>, pub allocation: Option<String>,
prefix: u8, pub prefix: u8,
#[serde(default)] #[serde(default)]
gateway: Option<BlueprintGateway>, pub gateway: Option<BlueprintGateway>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct BlueprintIpv6 { pub struct BlueprintIpv6 {
#[serde(default)] #[serde(default)]
allocation: Option<String>, pub allocation: Option<String>,
prefix: u8, pub prefix: u8,
#[serde(default)] #[serde(default)]
gateway: Option<BlueprintGateway>, pub gateway: Option<BlueprintGateway>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct BlueprintGateway { pub struct BlueprintGateway {
translation: WeightedList<Translation>, pub translation: WeightedList<Translation>,
upnp: Probability, pub upnp: Probability,
network: Option<String>, pub network: Option<String>,
} }
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct Subnets { pub struct Subnets {
#[serde(default)] #[serde(default)]
subnet4: Vec<Ipv4Net>, pub subnet4: Vec<Ipv4Net>,
#[serde(default)] #[serde(default)]
subnet6: Vec<Ipv6Net>, pub subnet6: Vec<Ipv6Net>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct Distance { pub struct Distance {
min: f32, pub min: f32,
max: f32, pub max: f32,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct Distribution { pub struct Distribution {
mean: f32, pub mean: f32,
sigma: f32, pub sigma: f32,
skew: f32, pub skew: f32,
min: f32, pub min: f32,
max: f32, pub max: f32,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -185,52 +241,166 @@ impl Default for Translation {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct Model { pub struct Model {
latency: Distribution, pub latency: Distribution,
#[serde(default)] #[serde(default)]
distance: Option<Distance>, pub distance: Option<Distance>,
#[serde(default)] #[serde(default)]
loss: Probability, pub loss: Probability,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct Allocation { pub struct Allocation {
#[serde(flatten)] #[serde(flatten)]
subnets: Subnets, pub subnets: Subnets,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] struct ValidateContext<'a> {
config: &'a Config,
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(context = "ValidateContext<'v_a>")]
pub struct Config { pub struct Config {
seed: Option<u32>, pub seed: Option<u32>,
default_network: String, #[validate(
default_model: String, length(min = 1),
custom(function = "validate_network_exists", use_context)
)]
pub default_network: String,
#[validate(
length(min = 1),
custom(function = "validate_model_exists", use_context)
)]
pub default_model: String,
#[serde(default)] #[serde(default)]
profiles: HashMap<String, Profile>, #[validate(length(min = 1))]
pub profiles: HashMap<String, Profile>,
#[serde(default)] #[serde(default)]
machines: HashMap<String, Machine>, pub machines: HashMap<String, Machine>,
#[serde(default)] #[serde(default)]
templates: HashMap<String, Template>, pub templates: HashMap<String, Template>,
#[serde(default)] #[serde(default)]
networks: HashMap<String, Network>, pub networks: HashMap<String, Network>,
#[serde(default)] #[serde(default)]
blueprints: HashMap<String, Blueprint>, pub blueprints: HashMap<String, Blueprint>,
#[serde(default)] #[serde(default)]
allocations: HashMap<String, Allocation>, pub allocations: HashMap<String, Allocation>,
#[serde(default)] #[serde(default)]
models: HashMap<String, Model>, pub models: HashMap<String, Model>,
} }
impl Config { impl Config {
pub fn new<P: AsRef<Path>>(config_file: Option<P>) -> Result<Self, ConfigError> { pub fn new<P: AsRef<Path>>(config_file: Option<P>) -> Result<Self, ConfigError> {
let cfg = load_config(config_file)?; let cfg = load_config(config_file).map_err(ConfigError::ParseError)?;
// Generate config // Generate config
cfg.try_deserialize() let out: Self = cfg.try_deserialize().map_err(ConfigError::ParseError)?;
// Validate config
let context = ValidateContext { config: &out };
let mut errors = ValidationErrors::new();
if let Err(e) = out.validate_with_args(&context) {
errors = e;
}
errors.merge_self("profiles", validate_all_profiles(&out.profiles, &context));
errors.merge_self("machines", validate_all_machines(&out.machines, &context));
errors.merge_self(
"templates",
validate_all_templates(&out.templates, &context),
);
errors.merge_self("networks", validate_all_networks(&out.networks, &context));
errors.merge_self(
"blueprints",
validate_all_blueprints(&out.blueprints, &context),
);
errors.merge_self(
"allocation",
validate_all_allocations(&out.allocations, &context),
);
errors.merge_self("models", validate_all_models(&out.models, &context));
if !errors.is_empty() {
return Err(ConfigError::ValidateError(errors));
}
Ok(out)
} }
} }
fn load_default_config() -> Result<::config::Config, ConfigError> { fn validate_instances_exist(
value: &Vec<Instance>,
context: &ValidateContext,
) -> Result<(), ValidationError> {
Ok(())
}
fn validate_location_exists(
value: &Location,
context: &ValidateContext,
) -> Result<(), ValidationError> {
Ok(())
}
fn validate_network_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> {
Ok(())
}
fn validate_model_exists(value: &str, context: &ValidateContext) -> Result<(), ValidationError> {
Ok(())
}
fn validate_all_profiles(
value: &HashMap<String, Profile>,
context: &ValidateContext,
) -> Result<(), ValidationErrors> {
for x in value.values() {
x.validate_with_args(context)?
}
Ok(())
}
fn validate_all_machines(
value: &HashMap<String, Machine>,
context: &ValidateContext,
) -> Result<(), ValidationErrors> {
Ok(())
}
fn validate_all_templates(
value: &HashMap<String, Template>,
context: &ValidateContext,
) -> Result<(), ValidationErrors> {
Ok(())
}
fn validate_all_networks(
value: &HashMap<String, Network>,
context: &ValidateContext,
) -> Result<(), ValidationErrors> {
Ok(())
}
fn validate_all_blueprints(
value: &HashMap<String, Blueprint>,
context: &ValidateContext,
) -> Result<(), ValidationErrors> {
Ok(())
}
fn validate_all_allocations(
value: &HashMap<String, Allocation>,
context: &ValidateContext,
) -> Result<(), ValidationErrors> {
Ok(())
}
fn validate_all_models(
value: &HashMap<String, Model>,
context: &ValidateContext,
) -> Result<(), ValidationErrors> {
Ok(())
}
fn load_default_config() -> Result<::config::Config, ::config::ConfigError> {
::config::Config::builder() ::config::Config::builder()
.add_source(::config::File::from_str( .add_source(::config::File::from_str(
PREDEFINED_CONFIG, PREDEFINED_CONFIG,
@ -245,17 +415,17 @@ fn load_default_config() -> Result<::config::Config, ConfigError> {
fn load_config<P: AsRef<Path>>( fn load_config<P: AsRef<Path>>(
opt_config_file: Option<P>, opt_config_file: Option<P>,
) -> Result<::config::Config, ConfigError> { ) -> Result<::config::Config, ::config::ConfigError> {
let Some(config_file) = opt_config_file else { let Some(config_file) = opt_config_file else {
return load_default_config(); return load_default_config();
}; };
let config_path = config_file.as_ref(); let config_path = config_file.as_ref();
let Some(config_file_str) = config_path.to_str() else { let Some(config_file_str) = config_path.to_str() else {
return Err(ConfigError::Message( return Err(::config::ConfigError::Message(
"config file path is not valid UTF-8".to_owned(), "config file path is not valid UTF-8".to_owned(),
)); ));
}; };
::config::Config::builder() let config = ::config::Config::builder()
.add_source(::config::File::from_str( .add_source(::config::File::from_str(
PREDEFINED_CONFIG, PREDEFINED_CONFIG,
::config::FileFormat::Yaml, ::config::FileFormat::Yaml,
@ -264,5 +434,7 @@ fn load_config<P: AsRef<Path>>(
config_file_str, config_file_str,
::config::FileFormat::Yaml, ::config::FileFormat::Yaml,
)) ))
.build() .build()?;
Ok(config)
} }

View File

@ -0,0 +1,111 @@
use super::*;
#[derive(Debug)]
struct Machine {}
#[derive(Debug)]
struct MachineRegistryUnlockedInner {
config: config::Config,
}
#[derive(Debug, Default)]
struct ProfileState {
next_instance_index: usize,
}
#[derive(Debug)]
struct MachineRegistryInner {
machines_by_id: HashMap<MachineId, Machine>,
current_profile_state: HashMap<String, ProfileState>,
}
#[derive(Debug, Clone)]
pub enum MachineRegistryError {
InvalidMachineId,
ProfileNotFound,
ProfileComplete,
}
pub type MachineRegistryResult<T> = Result<T, MachineRegistryError>;
#[derive(Debug, Clone)]
pub struct MachineRegistry {
unlocked_inner: Arc<MachineRegistryUnlockedInner>,
inner: Arc<Mutex<MachineRegistryInner>>,
}
impl MachineRegistry {
///////////////////////////////////////////////////////////
/// Public Interface
pub fn new(config: config::Config) -> Self {
Self {
unlocked_inner: Arc::new(MachineRegistryUnlockedInner { config }),
inner: Arc::new(Mutex::new(MachineRegistryInner {
machines_by_id: HashMap::new(),
current_profile_state: HashMap::new(),
})),
}
}
pub async fn allocate(&self, profile: String) -> MachineRegistryResult<MachineId> {
// Get profile definition
let Some(profile_def) = self.unlocked_inner.config.profiles.get(&profile) else {
return Err(MachineRegistryError::ProfileNotFound);
};
// Get current profile state, creating one if we have not yet started executing the profile
let mut inner = self.inner.lock();
let current_profile_state = inner
.current_profile_state
.entry(profile)
.or_insert_with(|| ProfileState::default());
// Get the next instance from the definition
let Some(instance_def) = profile_def
.instances
.get(current_profile_state.next_instance_index)
else {
//
return Err(MachineRegistryError::ProfileComplete);
};
match instance_def {
config::Instance::Machine { machine } => {
self.create_machine(machine);
}
config::Instance::Template { template } => todo!(),
}
Ok(machine_id)
}
pub async fn release(&self, machine_id: MachineId) -> MachineRegistryResult<()> {}
///////////////////////////////////////////////////////////
/// Private Implementation
async fn create_machine(
&self,
machine_def: config::Machine,
) -> MachineRegistryResult<MachineId> {
//
}
fn weighted_choice<T: fmt::Debug + Clone>(
&self,
weighted_list: &config::WeightedList<T>,
) -> &T {
match weighted_list {
config::WeightedList::Single(x) => x,
config::WeightedList::List(vec) => {
let total_weight = vec
.iter()
.map(|x| match x {
config::Weighted::Weighted { item, weight } => weight,
config::Weighted::Unweighted(item) => 1.0,
})
.reduce(|acc, x| acc + x);
}
}
}
}

View File

@ -1,8 +1,12 @@
mod config; pub mod config;
mod machine_registry;
pub use config::*; mod server_processor;
use super::*; use super::*;
use machine_registry::*;
use server_processor::*;
use async_tungstenite::accept_async; use async_tungstenite::accept_async;
use futures_codec::{Bytes, BytesCodec, FramedRead, FramedWrite}; use futures_codec::{Bytes, BytesCodec, FramedRead, FramedWrite};
use futures_util::{stream::FuturesUnordered, AsyncReadExt, StreamExt, TryStreamExt}; use futures_util::{stream::FuturesUnordered, AsyncReadExt, StreamExt, TryStreamExt};
@ -31,9 +35,10 @@ enum RunLoopEvent {
#[derive(Debug)] #[derive(Debug)]
struct RouterServerUnlockedInner { struct RouterServerUnlockedInner {
config: Config, config: config::Config,
new_client_sender: flume::Sender<SendPinBoxFuture<RunLoopEvent>>, new_client_sender: flume::Sender<SendPinBoxFuture<RunLoopEvent>>,
new_client_receiver: flume::Receiver<SendPinBoxFuture<RunLoopEvent>>, new_client_receiver: flume::Receiver<SendPinBoxFuture<RunLoopEvent>>,
server_processor: ServerProcessor,
} }
#[derive(Debug)] #[derive(Debug)]
@ -57,15 +62,19 @@ impl RouterServer {
// Public Interface // Public Interface
/// Create a router server for virtual networking /// Create a router server for virtual networking
pub fn new(config: Config) -> Self { pub fn new(config: config::Config) -> Self {
// Make a channel to receive new clients // Make a channel to receive new clients
let (new_client_sender, new_client_receiver) = flume::unbounded(); let (new_client_sender, new_client_receiver) = flume::unbounded();
// Make a server processor to handle messages
let server_processor = ServerProcessor::new(config.clone());
Self { Self {
unlocked_inner: Arc::new(RouterServerUnlockedInner { unlocked_inner: Arc::new(RouterServerUnlockedInner {
config, config,
new_client_sender, new_client_sender,
new_client_receiver, new_client_receiver,
server_processor,
}), }),
inner: Arc::new(Mutex::new(RouterServerInner {})), inner: Arc::new(Mutex::new(RouterServerInner {})),
} }
@ -95,10 +104,11 @@ impl RouterServer {
let x = x; let x = x;
let cmd = from_bytes::<ServerProcessorCommand>(&x).map_err(io::Error::other)?; let cmd = from_bytes::<ServerProcessorCommand>(&x).map_err(io::Error::other)?;
self.clone() self.unlocked_inner
.process_command(cmd, outbound_sender.clone()) .server_processor
.await .enqueue_command(cmd, outbound_sender.clone());
.map_err(io::Error::other)
Ok(())
})); }));
let mut unord = FuturesUnordered::new(); let mut unord = FuturesUnordered::new();
@ -244,16 +254,14 @@ impl RouterServer {
let this = self.clone(); let this = self.clone();
let inbound_receiver_fut = system_boxed(async move { let inbound_receiver_fut = system_boxed(async move {
let fut = local_inbound_receiver local_inbound_receiver
.into_stream() .into_stream()
.map(Ok) .for_each(|cmd| async {
.try_for_each(|cmd| { this.unlocked_inner
this.clone() .server_processor
.process_command(cmd, local_outbound_sender.clone()) .enqueue_command(cmd, local_outbound_sender.clone());
}); })
if let Err(e) = fut.await { .await;
error!("{}", e);
}
RunLoopEvent::Done RunLoopEvent::Done
}); });
@ -273,6 +281,13 @@ impl RouterServer {
let mut need_new_client_fut = true; let mut need_new_client_fut = true;
// Add server processor to run loop
unord.push(
self.unlocked_inner
.server_processor
.run_loop_process_commands(),
);
loop { loop {
if need_new_client_fut { if need_new_client_fut {
let new_client_receiver = self.unlocked_inner.new_client_receiver.clone(); let new_client_receiver = self.unlocked_inner.new_client_receiver.clone();
@ -311,116 +326,4 @@ impl RouterServer {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
// Private Implementation // Private Implementation
async fn process_command(
self,
cmd: ServerProcessorCommand,
outbound_sender: flume::Sender<ServerProcessorEvent>,
) -> RouterServerResult<()> {
match cmd {
ServerProcessorCommand::Message(server_processor_message) => {
self.process_message(
server_processor_message.message_id,
server_processor_message.request,
outbound_sender,
)
.await
}
ServerProcessorCommand::CloseSocket {
machine_id,
socket_id,
} => {
self.process_close_socket(machine_id, socket_id, outbound_sender)
.await
}
}
}
async fn process_close_socket(
self,
machine_id: MachineId,
socket_id: SocketId,
outbound_sender: flume::Sender<ServerProcessorEvent>,
) -> RouterServerResult<()> {
//
Ok(())
}
async fn process_message(
self,
message_id: MessageId,
request: ServerProcessorRequest,
outbound_sender: flume::Sender<ServerProcessorEvent>,
) -> RouterServerResult<()> {
match request {
ServerProcessorRequest::AllocateMachine { profile } => todo!(),
ServerProcessorRequest::ReleaseMachine { machine_id } => todo!(),
ServerProcessorRequest::GetInterfaces { machine_id } => todo!(),
ServerProcessorRequest::TcpConnect {
machine_id,
local_address,
remote_address,
timeout_ms,
options,
} => todo!(),
ServerProcessorRequest::TcpBind {
machine_id,
local_address,
options,
} => todo!(),
ServerProcessorRequest::TcpAccept {
machine_id,
listen_socket_id,
} => todo!(),
ServerProcessorRequest::TcpShutdown {
machine_id,
socket_id,
} => todo!(),
ServerProcessorRequest::UdpBind {
machine_id,
local_address,
options,
} => todo!(),
ServerProcessorRequest::Send {
machine_id,
socket_id,
data,
} => todo!(),
ServerProcessorRequest::SendTo {
machine_id,
socket_id,
remote_address,
data,
} => todo!(),
ServerProcessorRequest::Recv {
machine_id,
socket_id,
len,
} => todo!(),
ServerProcessorRequest::RecvFrom {
machine_id,
socket_id,
len,
} => todo!(),
ServerProcessorRequest::GetRoutedLocalAddress {
machine_id,
address_type,
} => todo!(),
ServerProcessorRequest::FindGateway { machine_id } => todo!(),
ServerProcessorRequest::GetExternalAddress { gateway_id } => todo!(),
ServerProcessorRequest::AddPort {
gateway_id,
protocol,
external_port,
local_address,
lease_duration_ms,
description,
} => todo!(),
ServerProcessorRequest::RemovePort {
gateway_id,
protocol,
external_port,
} => todo!(),
ServerProcessorRequest::TXTQuery { name } => todo!(),
}
}
} }

View File

@ -0,0 +1,196 @@
use super::*;
struct ServerProcessorCommandRecord {
cmd: ServerProcessorCommand,
outbound_sender: flume::Sender<ServerProcessorEvent>,
}
#[derive(Debug)]
struct ServerProcessorInner {
//
}
#[derive(Debug)]
struct ServerProcessorUnlockedInner {
config: config::Config,
receiver: flume::Receiver<ServerProcessorCommandRecord>,
sender: flume::Sender<ServerProcessorCommandRecord>,
machine_registry: MachineRegistry,
}
#[derive(Debug, Clone)]
pub struct ServerProcessor {
unlocked_inner: Arc<ServerProcessorUnlockedInner>,
inner: Arc<Mutex<ServerProcessorInner>>,
}
impl ServerProcessor {
////////////////////////////////////////////////////////////////////////
// Public Interface
pub fn new(config: config::Config) -> Self {
let (sender, receiver) = flume::unbounded();
Self {
unlocked_inner: Arc::new(ServerProcessorUnlockedInner {
config: config.clone(),
sender,
receiver,
machine_registry: MachineRegistry::new(config),
}),
inner: Arc::new(Mutex::new(ServerProcessorInner {})),
}
}
pub fn enqueue_command(
&self,
cmd: ServerProcessorCommand,
outbound_sender: flume::Sender<ServerProcessorEvent>,
) {
if let Err(e) = self
.unlocked_inner
.sender
.send(ServerProcessorCommandRecord {
cmd,
outbound_sender,
})
{
eprintln!("Dropped command: {}", e);
}
}
pub fn run_loop_process_commands(&self) -> SendPinBoxFuture<RunLoopEvent> {
let receiver_stream = self.unlocked_inner.receiver.clone().into_stream();
let this = self.clone();
Box::pin(async move {
receiver_stream
.for_each_concurrent(None, |x| {
let this = this.clone();
async move {
if let Err(e) = this.process_command(x.cmd, x.outbound_sender).await {
eprintln!("Failed to process command: {}", e);
}
}
})
.await;
RunLoopEvent::Done
})
}
////////////////////////////////////////////////////////////////////////
// Private Implementation
async fn process_command(
self,
cmd: ServerProcessorCommand,
outbound_sender: flume::Sender<ServerProcessorEvent>,
) -> RouterServerResult<()> {
match cmd {
ServerProcessorCommand::Message(server_processor_message) => {
self.process_message(
server_processor_message.message_id,
server_processor_message.request,
outbound_sender,
)
.await
}
ServerProcessorCommand::CloseSocket {
machine_id,
socket_id,
} => {
self.process_close_socket(machine_id, socket_id, outbound_sender)
.await
}
}
}
async fn process_close_socket(
self,
machine_id: MachineId,
socket_id: SocketId,
outbound_sender: flume::Sender<ServerProcessorEvent>,
) -> RouterServerResult<()> {
//
Ok(())
}
async fn process_message(
self,
message_id: MessageId,
request: ServerProcessorRequest,
outbound_sender: flume::Sender<ServerProcessorEvent>,
) -> RouterServerResult<()> {
match request {
ServerProcessorRequest::AllocateMachine { profile } => todo!(),
ServerProcessorRequest::ReleaseMachine { machine_id } => todo!(),
ServerProcessorRequest::GetInterfaces { machine_id } => todo!(),
ServerProcessorRequest::TcpConnect {
machine_id,
local_address,
remote_address,
timeout_ms,
options,
} => todo!(),
ServerProcessorRequest::TcpBind {
machine_id,
local_address,
options,
} => todo!(),
ServerProcessorRequest::TcpAccept {
machine_id,
listen_socket_id,
} => todo!(),
ServerProcessorRequest::TcpShutdown {
machine_id,
socket_id,
} => todo!(),
ServerProcessorRequest::UdpBind {
machine_id,
local_address,
options,
} => todo!(),
ServerProcessorRequest::Send {
machine_id,
socket_id,
data,
} => todo!(),
ServerProcessorRequest::SendTo {
machine_id,
socket_id,
remote_address,
data,
} => todo!(),
ServerProcessorRequest::Recv {
machine_id,
socket_id,
len,
} => todo!(),
ServerProcessorRequest::RecvFrom {
machine_id,
socket_id,
len,
} => todo!(),
ServerProcessorRequest::GetRoutedLocalAddress {
machine_id,
address_type,
} => todo!(),
ServerProcessorRequest::FindGateway { machine_id } => todo!(),
ServerProcessorRequest::GetExternalAddress { gateway_id } => todo!(),
ServerProcessorRequest::AddPort {
gateway_id,
protocol,
external_port,
local_address,
lease_duration_ms,
description,
} => todo!(),
ServerProcessorRequest::RemovePort {
gateway_id,
protocol,
external_port,
} => todo!(),
ServerProcessorRequest::TXTQuery { name } => todo!(),
}
}
}

View File

@ -13,6 +13,10 @@ pub enum VirtualNetworkError {
InvalidMachineId, InvalidMachineId,
#[error("Invalid socket id")] #[error("Invalid socket id")]
InvalidSocketId, InvalidSocketId,
#[error("Missing profile")]
MissingProfile,
#[error("Profile complete")]
ProfileComplete,
#[error("Io error: {0}")] #[error("Io error: {0}")]
IoError(io::ErrorKind), IoError(io::ErrorKind),
} }