checkpoint merge of network-shim branch

This commit is contained in:
Christien Rioux 2025-02-10 03:06:41 +00:00
parent 079b665230
commit a2b0214b8e
276 changed files with 17493 additions and 7193 deletions

1912
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,13 @@ members = [
] ]
resolver = "2" resolver = "2"
[workspace.package]
repository = "https://gitlab.com/veilid/veilid"
authors = ["Veilid Team <contact@veilid.com>"]
license = "MPL-2.0"
edition = "2021"
rust-version = "1.81.0"
[patch.crates-io] [patch.crates-io]
cursive = { git = "https://gitlab.com/veilid/cursive.git" } cursive = { git = "https://gitlab.com/veilid/cursive.git" }
cursive_core = { git = "https://gitlab.com/veilid/cursive.git" } cursive_core = { git = "https://gitlab.com/veilid/cursive.git" }
@ -26,3 +33,31 @@ lto = true
[profile.dev.package.backtrace] [profile.dev.package.backtrace]
opt-level = 3 opt-level = 3
[profile.dev.package.argon2]
opt-level = 3
debug-assertions = false
[profile.dev.package.ed25519-dalek]
opt-level = 3
debug-assertions = false
[profile.dev.package.x25519-dalek]
opt-level = 3
debug-assertions = false
[profile.dev.package.curve25519-dalek]
opt-level = 3
debug-assertions = false
[profile.dev.package.chacha20poly1305]
opt-level = 3
debug-assertions = false
[profile.dev.package.blake3]
opt-level = 3
debug-assertions = false
[profile.dev.package.chacha20]
opt-level = 3
debug-assertions = false

View File

@ -36,7 +36,7 @@ command line without it. If you do so, you may skip to
[Run Veilid setup script](#Run Veilid setup script). [Run Veilid setup script](#Run Veilid setup script).
- build-tools;34.0.0 - build-tools;34.0.0
- ndk;26.3.11579264 - ndk;27.0.12077973
- cmake;3.22.1 - cmake;3.22.1
- platform-tools - platform-tools
- platforms;android-34 - platforms;android-34
@ -58,7 +58,7 @@ the command line to install the requisite package versions:
sdkmanager --install "platform-tools" sdkmanager --install "platform-tools"
sdkmanager --install "platforms;android-34" sdkmanager --install "platforms;android-34"
sdkmanager --install "build-tools;34.0.0" sdkmanager --install "build-tools;34.0.0"
sdkmanager --install "ndk;26.3.11579264" sdkmanager --install "ndk;27.0.12077973"
sdkmanager --install "cmake;3.22.1" sdkmanager --install "cmake;3.22.1"
``` ```
@ -110,7 +110,7 @@ You will need to use Android Studio [here](https://developer.android.com/studio)
to maintain your Android dependencies. Use the SDK Manager in the IDE to install the following packages (use package details view to select version): to maintain your Android dependencies. Use the SDK Manager in the IDE to install the following packages (use package details view to select version):
- Android SDK Build Tools (34.0.0) - Android SDK Build Tools (34.0.0)
- NDK (Side-by-side) (26.3.11579264) - NDK (Side-by-side) (27.0.12077973)
- Cmake (3.22.1) - Cmake (3.22.1)
- Android SDK 34 - Android SDK 34
- Android SDK Command Line Tools (latest) (7.0/latest) - Android SDK Command Line Tools (latest) (7.0/latest)

View File

@ -18,7 +18,7 @@ FROM ubuntu:18.04
ENV ZIG_VERSION=0.13.0 ENV ZIG_VERSION=0.13.0
ENV CMAKE_VERSION_MINOR=3.30 ENV CMAKE_VERSION_MINOR=3.30
ENV CMAKE_VERSION_PATCH=3.30.1 ENV CMAKE_VERSION_PATCH=3.30.1
ENV WASM_BINDGEN_CLI_VERSION=0.2.93 ENV WASM_BINDGEN_CLI_VERSION=0.2.100
ENV RUST_VERSION=1.81.0 ENV RUST_VERSION=1.81.0
ENV RUSTUP_HOME=/usr/local/rustup ENV RUSTUP_HOME=/usr/local/rustup
ENV RUSTUP_DIST_SERVER=https://static.rust-lang.org ENV RUSTUP_DIST_SERVER=https://static.rust-lang.org
@ -82,7 +82,7 @@ deps-android:
RUN mkdir /Android; mkdir /Android/Sdk RUN mkdir /Android; mkdir /Android/Sdk
RUN curl -o /Android/cmdline-tools.zip https://dl.google.com/android/repository/commandlinetools-linux-9123335_latest.zip RUN curl -o /Android/cmdline-tools.zip https://dl.google.com/android/repository/commandlinetools-linux-9123335_latest.zip
RUN cd /Android; unzip /Android/cmdline-tools.zip RUN cd /Android; unzip /Android/cmdline-tools.zip
RUN yes | /Android/cmdline-tools/bin/sdkmanager --sdk_root=/Android/Sdk build-tools\;34.0.0 ndk\;26.3.11579264 cmake\;3.22.1 platform-tools platforms\;android-34 cmdline-tools\;latest RUN yes | /Android/cmdline-tools/bin/sdkmanager --sdk_root=/Android/Sdk build-tools\;34.0.0 ndk\;27.0.12077973 cmake\;3.22.1 platform-tools platforms\;android-34 cmdline-tools\;latest
RUN rm -rf /Android/cmdline-tools RUN rm -rf /Android/cmdline-tools
RUN apt-get clean RUN apt-get clean
@ -170,7 +170,7 @@ build-linux-arm64:
build-android: build-android:
FROM +code-android FROM +code-android
WORKDIR /veilid/veilid-core WORKDIR /veilid/veilid-core
ENV PATH=$PATH:/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/ ENV PATH=$PATH:/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/
RUN cargo build --target aarch64-linux-android --release RUN cargo build --target aarch64-linux-android --release
RUN cargo build --target armv7-linux-androideabi --release RUN cargo build --target armv7-linux-androideabi --release
RUN cargo build --target i686-linux-android --release RUN cargo build --target i686-linux-android --release

View File

@ -43,14 +43,14 @@ while true; do
curl -o $HOME/Android/cmdline-tools.zip https://dl.google.com/android/repository/commandlinetools-linux-9123335_latest.zip curl -o $HOME/Android/cmdline-tools.zip https://dl.google.com/android/repository/commandlinetools-linux-9123335_latest.zip
cd $HOME/Android cd $HOME/Android
unzip $HOME/Android/cmdline-tools.zip unzip $HOME/Android/cmdline-tools.zip
$HOME/Android/cmdline-tools/bin/sdkmanager --sdk_root=$HOME/Android/Sdk build-tools\;34.0.0 ndk\;26.3.11579264 cmake\;3.22.1 platform-tools platforms\;android-34 cmdline-tools\;latest emulator $HOME/Android/cmdline-tools/bin/sdkmanager --sdk_root=$HOME/Android/Sdk build-tools\;34.0.0 ndk\;27.0.12077973 cmake\;3.22.1 platform-tools platforms\;android-34 cmdline-tools\;latest emulator
cd $HOME cd $HOME
rm -rf $HOME/Android/cmdline-tools $HOME/Android/cmdline-tools.zip rm -rf $HOME/Android/cmdline-tools $HOME/Android/cmdline-tools.zip
# Add environment variables # Add environment variables
cat >>$HOME/.profile <<END cat >>$HOME/.profile <<END
source "\$HOME/.cargo/env" source "\$HOME/.cargo/env"
export PATH=\$PATH:\$HOME/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin:\$HOME/Android/Sdk/platform-tools:\$HOME/Android/Sdk/cmdline-tools/latest/bin export PATH=\$PATH:\$HOME/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin:\$HOME/Android/Sdk/platform-tools:\$HOME/Android/Sdk/cmdline-tools/latest/bin
export ANDROID_HOME=\$HOME/Android/Sdk export ANDROID_HOME=\$HOME/Android/Sdk
END END
break break

View File

@ -42,7 +42,7 @@ while true; do
fi fi
# ensure ndk is installed # ensure ndk is installed
ANDROID_NDK_HOME="$ANDROID_HOME/ndk/26.3.11579264" ANDROID_NDK_HOME="$ANDROID_HOME/ndk/27.0.12077973"
if [ -f "$ANDROID_NDK_HOME/ndk-build" ]; then if [ -f "$ANDROID_NDK_HOME/ndk-build" ]; then
echo '[X] Android NDK is installed at the location $ANDROID_NDK_HOME' echo '[X] Android NDK is installed at the location $ANDROID_NDK_HOME'
else else

View File

@ -31,10 +31,10 @@ while true; do
fi fi
# ensure Android SDK packages are installed # ensure Android SDK packages are installed
$ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager build-tools\;34.0.0 ndk\;26.3.11579264 cmake\;3.22.1 platform-tools platforms\;android-34 $ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager build-tools\;34.0.0 ndk\;27.0.12077973 cmake\;3.22.1 platform-tools platforms\;android-34
# ensure ndk is installed # ensure ndk is installed
ANDROID_NDK_HOME="$ANDROID_HOME/ndk/26.3.11579264" ANDROID_NDK_HOME="$ANDROID_HOME/ndk/27.0.12077973"
if [ -f "$ANDROID_NDK_HOME/ndk-build" ]; then if [ -f "$ANDROID_NDK_HOME/ndk-build" ]; then
echo '[X] Android NDK is installed at the location $ANDROID_NDK_HOME' echo '[X] Android NDK is installed at the location $ANDROID_NDK_HOME'
else else

View File

@ -27,6 +27,7 @@ logging:
enabled: false enabled: false
testing: testing:
subnode_index: 0 subnode_index: 0
subnode_count: 1
core: core:
protected_store: protected_store:
allow_insecure_fallback: true allow_insecure_fallback: true

View File

@ -138,6 +138,7 @@ otlp:
```yaml ```yaml
testing: testing:
subnode_index: 0 subnode_index: 0
subnode_count: 1
``` ```
### core ### core

View File

@ -1,8 +1,8 @@
[target.aarch64-linux-android] [target.aarch64-linux-android]
linker = "/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android34-clang" linker = "/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android34-clang"
[target.armv7-linux-androideabi] [target.armv7-linux-androideabi]
linker = "/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/armv7a-linux-androideabi33-clang" linker = "/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/armv7a-linux-androideabi33-clang"
[target.x86_64-linux-android] [target.x86_64-linux-android]
linker = "/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/x86_64-linux-android34-clang" linker = "/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/x86_64-linux-android34-clang"
[target.i686-linux-android] [target.i686-linux-android]
linker = "/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/i686-linux-android34-clang" linker = "/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/i686-linux-android34-clang"

View File

@ -4,12 +4,12 @@ name = "veilid-cli"
version = "0.4.1" version = "0.4.1"
# --- # ---
description = "Client application for connecting to a Veilid headless node" description = "Client application for connecting to a Veilid headless node"
repository = "https://gitlab.com/veilid/veilid"
authors = ["Veilid Team <contact@veilid.com>"]
edition = "2021"
license = "MPL-2.0"
resolver = "2" resolver = "2"
rust-version = "1.81.0" repository.workspace = true
authors.workspace = true
license.workspace = true
edition.workspace = true
rust-version.workspace = true
[[bin]] [[bin]]
name = "veilid-cli" name = "veilid-cli"
@ -17,6 +17,8 @@ path = "src/main.rs"
[features] [features]
default = ["rt-tokio"] default = ["rt-tokio"]
default-async-std = ["rt-async-std"]
rt-async-std = [ rt-async-std = [
"async-std", "async-std",
"veilid-tools/rt-async-std", "veilid-tools/rt-async-std",

View File

@ -223,13 +223,12 @@ impl ClientApiConnection {
trace!("ClientApiConnection::handle_tcp_connection"); trace!("ClientApiConnection::handle_tcp_connection");
// Connect the TCP socket // Connect the TCP socket
let stream = TcpStream::connect(connect_addr) let stream = connect_async_tcp_stream(None, connect_addr, 10_000)
.await .await
.map_err(map_to_string)?
.into_timeout_error()
.map_err(map_to_string)?; .map_err(map_to_string)?;
// If it succeed, disable nagle algorithm
stream.set_nodelay(true).map_err(map_to_string)?;
// State we connected // State we connected
let comproc = self.inner.lock().comproc.clone(); let comproc = self.inner.lock().comproc.clone();
comproc.set_connection_state(ConnectionState::ConnectedTCP( comproc.set_connection_state(ConnectionState::ConnectedTCP(
@ -239,16 +238,8 @@ impl ClientApiConnection {
// Split into reader and writer halves // Split into reader and writer halves
// with line buffering on the reader // with line buffering on the reader
cfg_if! { let (reader, writer) = split_async_tcp_stream(stream);
if #[cfg(feature="rt-async-std")] { let reader = BufReader::new(reader);
use futures::AsyncReadExt;
let (reader, writer) = stream.split();
let reader = BufReader::new(reader);
} else {
let (reader, writer) = stream.into_split();
let reader = BufReader::new(reader);
}
}
self.clone().run_json_api_processor(reader, writer).await self.clone().run_json_api_processor(reader, writer).await
} }

View File

@ -57,6 +57,7 @@ struct CommandProcessorInner {
#[derive(Clone)] #[derive(Clone)]
pub struct CommandProcessor { pub struct CommandProcessor {
inner: Arc<Mutex<CommandProcessorInner>>, inner: Arc<Mutex<CommandProcessorInner>>,
settings: Arc<Settings>,
} }
impl CommandProcessor { impl CommandProcessor {
@ -75,6 +76,7 @@ impl CommandProcessor {
last_call_id: None, last_call_id: None,
enable_app_messages: false, enable_app_messages: false,
})), })),
settings: Arc::new(settings.clone()),
} }
} }
pub fn set_client_api_connection(&self, capi: ClientApiConnection) { pub fn set_client_api_connection(&self, capi: ClientApiConnection) {
@ -186,6 +188,54 @@ Core Debug Commands:
Ok(()) Ok(())
} }
pub fn cmd_connect(&self, rest: Option<String>, callback: UICallback) -> Result<(), String> {
trace!("CommandProcessor::cmd_connect");
let capi = self.capi();
let ui = self.ui_sender();
let this = self.clone();
spawn_detached_local("cmd connect", async move {
capi.disconnect().await;
if let Some(rest) = rest {
if let Ok(subnode_index) = u16::from_str(&rest) {
let ipc_path = this
.settings
.resolve_ipc_path(this.settings.ipc_path.clone(), subnode_index);
this.set_ipc_path(ipc_path);
this.set_network_address(None);
} else if let Some(ipc_path) =
this.settings.resolve_ipc_path(Some(rest.clone().into()), 0)
{
this.set_ipc_path(Some(ipc_path));
this.set_network_address(None);
} else if let Ok(Some(network_address)) =
this.settings.resolve_network_address(Some(rest.clone()))
{
if let Some(addr) = network_address.first() {
this.set_network_address(Some(*addr));
this.set_ipc_path(None);
} else {
ui.add_node_event(
Level::Error,
&format!("Invalid network address: {}", rest),
);
}
} else {
ui.add_node_event(
Level::Error,
&format!("Invalid connection string: {}", rest),
);
}
}
this.start_connection();
ui.send_callback(callback);
});
Ok(())
}
pub fn cmd_debug(&self, command_line: String, callback: UICallback) -> Result<(), String> { pub fn cmd_debug(&self, command_line: String, callback: UICallback) -> Result<(), String> {
trace!("CommandProcessor::cmd_debug"); trace!("CommandProcessor::cmd_debug");
let capi = self.capi(); let capi = self.capi();
@ -331,6 +381,7 @@ Core Debug Commands:
"exit" => self.cmd_exit(callback), "exit" => self.cmd_exit(callback),
"quit" => self.cmd_exit(callback), "quit" => self.cmd_exit(callback),
"disconnect" => self.cmd_disconnect(callback), "disconnect" => self.cmd_disconnect(callback),
"connect" => self.cmd_connect(rest, callback),
"shutdown" => self.cmd_shutdown(callback), "shutdown" => self.cmd_shutdown(callback),
"change_log_level" => self.cmd_change_log_level(rest, callback), "change_log_level" => self.cmd_change_log_level(rest, callback),
"change_log_ignore" => self.cmd_change_log_ignore(rest, callback), "change_log_ignore" => self.cmd_change_log_ignore(rest, callback),

View File

@ -28,10 +28,11 @@ pub struct InteractiveUIInner {
#[derive(Clone)] #[derive(Clone)]
pub struct InteractiveUI { pub struct InteractiveUI {
inner: Arc<Mutex<InteractiveUIInner>>, inner: Arc<Mutex<InteractiveUIInner>>,
_settings: Arc<Settings>,
} }
impl InteractiveUI { impl InteractiveUI {
pub fn new(_settings: &Settings) -> (Self, InteractiveUISender) { pub fn new(settings: &Settings) -> (Self, InteractiveUISender) {
let (cssender, csreceiver) = flume::unbounded::<ConnectionState>(); let (cssender, csreceiver) = flume::unbounded::<ConnectionState>();
let term = Term::stdout(); let term = Term::stdout();
@ -45,9 +46,10 @@ impl InteractiveUI {
error: None, error: None,
done: Some(StopSource::new()), done: Some(StopSource::new()),
connection_state_receiver: csreceiver, connection_state_receiver: csreceiver,
log_enabled: false, log_enabled: true,
enable_color, enable_color,
})), })),
_settings: Arc::new(settings.clone()),
}; };
let ui_sender = InteractiveUISender { let ui_sender = InteractiveUISender {
@ -169,7 +171,6 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e); eprintln!("Error: {:?}", e);
self.inner.lock().done.take(); self.inner.lock().done.take();
} }
self.inner.lock().log_enabled = true;
} }
} else if line == "log warn" { } else if line == "log warn" {
let opt_cmdproc = self.inner.lock().cmdproc.clone(); let opt_cmdproc = self.inner.lock().cmdproc.clone();
@ -181,7 +182,6 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e); eprintln!("Error: {:?}", e);
self.inner.lock().done.take(); self.inner.lock().done.take();
} }
self.inner.lock().log_enabled = true;
} }
} else if line == "log info" { } else if line == "log info" {
let opt_cmdproc = self.inner.lock().cmdproc.clone(); let opt_cmdproc = self.inner.lock().cmdproc.clone();
@ -193,7 +193,6 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e); eprintln!("Error: {:?}", e);
self.inner.lock().done.take(); self.inner.lock().done.take();
} }
self.inner.lock().log_enabled = true;
} }
} else if line == "log debug" || line == "log" { } else if line == "log debug" || line == "log" {
let opt_cmdproc = self.inner.lock().cmdproc.clone(); let opt_cmdproc = self.inner.lock().cmdproc.clone();
@ -205,6 +204,8 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e); eprintln!("Error: {:?}", e);
self.inner.lock().done.take(); self.inner.lock().done.take();
} }
}
if line == "log" {
self.inner.lock().log_enabled = true; self.inner.lock().log_enabled = true;
} }
} else if line == "log trace" { } else if line == "log trace" {
@ -217,7 +218,6 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e); eprintln!("Error: {:?}", e);
self.inner.lock().done.take(); self.inner.lock().done.take();
} }
self.inner.lock().log_enabled = true;
} }
} else if line == "log off" { } else if line == "log off" {
let opt_cmdproc = self.inner.lock().cmdproc.clone(); let opt_cmdproc = self.inner.lock().cmdproc.clone();
@ -229,9 +229,27 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e); eprintln!("Error: {:?}", e);
self.inner.lock().done.take(); self.inner.lock().done.take();
} }
self.inner.lock().log_enabled = false;
} }
} else if line == "log hide" || line == "log disable" {
self.inner.lock().log_enabled = false;
} else if line == "log show" || line == "log enable" {
self.inner.lock().log_enabled = true;
} else if !line.is_empty() { } else if !line.is_empty() {
if line == "help" {
let _ = writeln!(
stdout,
r#"
Interactive Mode Commands:
help - Display this help
clear - Clear the screen
log [level] - Set the client api log level for the node to one of: error,warn,info,debug,trace,off
hide|disable - Turn off viewing the log without changing the log level for the node
show|enable - Turn on viewing the log without changing the log level for the node
- With no option, 'log' turns on viewing the log and sets the level to 'debug'
"#
);
}
let cmdproc = self.inner.lock().cmdproc.clone(); let cmdproc = self.inner.lock().cmdproc.clone();
if let Some(cmdproc) = &cmdproc { if let Some(cmdproc) = &cmdproc {
if let Err(e) = cmdproc.run_command( if let Err(e) = cmdproc.run_command(

View File

@ -3,7 +3,7 @@
#![deny(unused_must_use)] #![deny(unused_must_use)]
#![recursion_limit = "256"] #![recursion_limit = "256"]
use crate::{settings::NamedSocketAddrs, tools::*, ui::*}; use crate::{tools::*, ui::*};
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use flexi_logger::*; use flexi_logger::*;
@ -37,7 +37,7 @@ struct CmdlineArgs {
ipc_path: Option<PathBuf>, ipc_path: Option<PathBuf>,
/// Subnode index to use when connecting /// Subnode index to use when connecting
#[arg(short('n'), long, default_value = "0")] #[arg(short('n'), long, default_value = "0")]
subnode_index: usize, subnode_index: u16,
/// Address to connect to /// Address to connect to
#[arg(long, short = 'a')] #[arg(long, short = 'a')]
address: Option<String>, address: Option<String>,
@ -47,9 +47,9 @@ struct CmdlineArgs {
/// Specify a configuration file to use /// Specify a configuration file to use
#[arg(short = 'c', long, value_name = "FILE")] #[arg(short = 'c', long, value_name = "FILE")]
config_file: Option<PathBuf>, config_file: Option<PathBuf>,
/// log level /// Log level for the CLI itself (not for the Veilid node)
#[arg(value_enum)] #[arg(long, value_enum)]
log_level: Option<LogLevel>, cli_log_level: Option<LogLevel>,
/// interactive /// interactive
#[arg(long, short = 'i', group = "execution_mode")] #[arg(long, short = 'i', group = "execution_mode")]
interactive: bool, interactive: bool,
@ -93,11 +93,11 @@ fn main() -> Result<(), String> {
.map_err(|e| format!("configuration is invalid: {}", e))?; .map_err(|e| format!("configuration is invalid: {}", e))?;
// Set config from command line // Set config from command line
if let Some(LogLevel::Debug) = args.log_level { if let Some(LogLevel::Debug) = args.cli_log_level {
settings.logging.level = settings::LogLevel::Debug; settings.logging.level = settings::LogLevel::Debug;
settings.logging.terminal.enabled = true; settings.logging.terminal.enabled = true;
} }
if let Some(LogLevel::Trace) = args.log_level { if let Some(LogLevel::Trace) = args.cli_log_level {
settings.logging.level = settings::LogLevel::Trace; settings.logging.level = settings::LogLevel::Trace;
settings.logging.terminal.enabled = true; settings.logging.terminal.enabled = true;
} }
@ -248,59 +248,14 @@ fn main() -> Result<(), String> {
// Determine IPC path to try // Determine IPC path to try
let mut client_api_ipc_path = None; let mut client_api_ipc_path = None;
if enable_ipc { if enable_ipc {
cfg_if::cfg_if! { client_api_ipc_path = settings.resolve_ipc_path(args.ipc_path, args.subnode_index);
if #[cfg(windows)] { if client_api_ipc_path.is_some() {
if let Some(ipc_path) = args.ipc_path.or(settings.ipc_path.clone()) { enable_network = false;
if is_ipc_socket_path(&ipc_path) {
// try direct path
enable_network = false;
client_api_ipc_path = Some(ipc_path);
} else {
// try subnode index inside path
let ipc_path = ipc_path.join(args.subnode_index.to_string());
if is_ipc_socket_path(&ipc_path) {
// subnode indexed path exists
enable_network = false;
client_api_ipc_path = Some(ipc_path);
}
}
}
} else {
if let Some(ipc_path) = args.ipc_path.or(settings.ipc_path.clone()) {
if is_ipc_socket_path(&ipc_path) {
// try direct path
enable_network = false;
client_api_ipc_path = Some(ipc_path);
} else if ipc_path.exists() && ipc_path.is_dir() {
// try subnode index inside path
let ipc_path = ipc_path.join(args.subnode_index.to_string());
if is_ipc_socket_path(&ipc_path) {
// subnode indexed path exists
enable_network = false;
client_api_ipc_path = Some(ipc_path);
}
}
}
}
} }
} }
let mut client_api_network_addresses = None; let mut client_api_network_addresses = None;
if enable_network { if enable_network {
let args_address = if let Some(args_address) = args.address { client_api_network_addresses = settings.resolve_network_address(args.address)?;
match NamedSocketAddrs::try_from(args_address) {
Ok(v) => Some(v),
Err(e) => {
return Err(format!("Invalid server address: {}", e));
}
}
} else {
None
};
if let Some(address_arg) = args_address.or(settings.address.clone()) {
client_api_network_addresses = Some(address_arg.addrs);
} else if let Some(address) = settings.address.clone() {
client_api_network_addresses = Some(address.addrs.clone());
}
} }
// Create command processor // Create command processor

View File

@ -1,5 +1,6 @@
use directories::*; use directories::*;
use crate::tools::*;
use serde_derive::*; use serde_derive::*;
use std::ffi::OsStr; use std::ffi::OsStr;
use std::net::{SocketAddr, ToSocketAddrs}; use std::net::{SocketAddr, ToSocketAddrs};
@ -118,7 +119,7 @@ pub fn convert_loglevel(log_level: LogLevel) -> log::LevelFilter {
} }
} }
#[derive(Debug, Clone)] #[derive(Clone, Debug)]
pub struct NamedSocketAddrs { pub struct NamedSocketAddrs {
pub _name: String, pub _name: String,
pub addrs: Vec<SocketAddr>, pub addrs: Vec<SocketAddr>,
@ -148,26 +149,26 @@ impl<'de> serde::Deserialize<'de> for NamedSocketAddrs {
} }
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Terminal { pub struct Terminal {
pub enabled: bool, pub enabled: bool,
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct File { pub struct File {
pub enabled: bool, pub enabled: bool,
pub directory: String, pub directory: String,
pub append: bool, pub append: bool,
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Logging { pub struct Logging {
pub terminal: Terminal, pub terminal: Terminal,
pub file: File, pub file: File,
pub level: LogLevel, pub level: LogLevel,
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Colors { pub struct Colors {
pub background: String, pub background: String,
pub shadow: String, pub shadow: String,
@ -182,7 +183,7 @@ pub struct Colors {
pub highlight_text: String, pub highlight_text: String,
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct LogColors { pub struct LogColors {
pub trace: String, pub trace: String,
pub debug: String, pub debug: String,
@ -191,7 +192,7 @@ pub struct LogColors {
pub error: String, pub error: String,
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Theme { pub struct Theme {
pub shadow: bool, pub shadow: bool,
pub borders: String, pub borders: String,
@ -199,24 +200,24 @@ pub struct Theme {
pub log_colors: LogColors, pub log_colors: LogColors,
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct NodeLog { pub struct NodeLog {
pub scrollback: usize, pub scrollback: usize,
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct CommandLine { pub struct CommandLine {
pub history_size: usize, pub history_size: usize,
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Interface { pub struct Interface {
pub theme: Theme, pub theme: Theme,
pub node_log: NodeLog, pub node_log: NodeLog,
pub command_line: CommandLine, pub command_line: CommandLine,
} }
#[derive(Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Settings { pub struct Settings {
pub enable_ipc: bool, pub enable_ipc: bool,
pub ipc_path: Option<PathBuf>, pub ipc_path: Option<PathBuf>,
@ -229,6 +230,90 @@ pub struct Settings {
} }
impl Settings { impl Settings {
//////////////////////////////////////////////////////////////////////////////////
pub fn new(config_file: Option<&OsStr>) -> Result<Self, config::ConfigError> {
// Load the default config
let mut cfg = load_default_config()?;
// Merge in the config file if we have one
if let Some(config_file) = config_file {
let config_file_path = Path::new(config_file);
// If the user specifies a config file on the command line then it must exist
cfg = load_config(cfg, config_file_path)?;
}
// Generate config
cfg.try_deserialize()
}
pub fn resolve_ipc_path(
&self,
ipc_path: Option<PathBuf>,
subnode_index: u16,
) -> Option<PathBuf> {
let mut client_api_ipc_path = None;
// Determine IPC path to try
cfg_if::cfg_if! {
if #[cfg(windows)] {
if let Some(ipc_path) = ipc_path.or(self.ipc_path.clone()) {
if is_ipc_socket_path(&ipc_path) {
// try direct path
enable_network = false;
client_api_ipc_path = Some(ipc_path);
} else {
// try subnode index inside path
let ipc_path = ipc_path.join(subnode_index.to_string());
if is_ipc_socket_path(&ipc_path) {
// subnode indexed path exists
client_api_ipc_path = Some(ipc_path);
}
}
}
} else {
if let Some(ipc_path) = ipc_path.or(self.ipc_path.clone()) {
if is_ipc_socket_path(&ipc_path) {
// try direct path
client_api_ipc_path = Some(ipc_path);
} else if ipc_path.exists() && ipc_path.is_dir() {
// try subnode index inside path
let ipc_path = ipc_path.join(subnode_index.to_string());
if is_ipc_socket_path(&ipc_path) {
// subnode indexed path exists
client_api_ipc_path = Some(ipc_path);
}
}
}
}
}
client_api_ipc_path
}
pub fn resolve_network_address(
&self,
address: Option<String>,
) -> Result<Option<Vec<SocketAddr>>, String> {
let mut client_api_network_addresses = None;
let args_address = if let Some(args_address) = address {
match NamedSocketAddrs::try_from(args_address) {
Ok(v) => Some(v),
Err(e) => {
return Err(format!("Invalid server address: {}", e));
}
}
} else {
None
};
if let Some(address_arg) = args_address.or(self.address.clone()) {
client_api_network_addresses = Some(address_arg.addrs);
} else if let Some(address) = self.address.clone() {
client_api_network_addresses = Some(address.addrs.clone());
}
Ok(client_api_network_addresses)
}
////////////////////////////////////////////////////////////////////////////
#[cfg_attr(windows, expect(dead_code))] #[cfg_attr(windows, expect(dead_code))]
fn get_server_default_directory(subpath: &str) -> PathBuf { fn get_server_default_directory(subpath: &str) -> PathBuf {
#[cfg(unix)] #[cfg(unix)]
@ -284,21 +369,6 @@ impl Settings {
default_log_directory default_log_directory
} }
pub fn new(config_file: Option<&OsStr>) -> Result<Self, config::ConfigError> {
// Load the default config
let mut cfg = load_default_config()?;
// Merge in the config file if we have one
if let Some(config_file) = config_file {
let config_file_path = Path::new(config_file);
// If the user specifies a config file on the command line then it must exist
cfg = load_config(cfg, config_file_path)?;
}
// Generate config
cfg.try_deserialize()
}
} }
#[test] #[test]

View File

@ -8,12 +8,10 @@ use core::str::FromStr;
cfg_if! { cfg_if! {
if #[cfg(feature="rt-async-std")] { if #[cfg(feature="rt-async-std")] {
pub use async_std::net::TcpStream;
pub fn block_on<F: Future<Output = T>, T>(f: F) -> T { pub fn block_on<F: Future<Output = T>, T>(f: F) -> T {
async_std::task::block_on(f) async_std::task::block_on(f)
} }
} else if #[cfg(feature="rt-tokio")] { } else if #[cfg(feature="rt-tokio")] {
pub use tokio::net::TcpStream;
pub fn block_on<F: Future<Output = T>, T>(f: F) -> T { pub fn block_on<F: Future<Output = T>, T>(f: F) -> T {
let rt = tokio::runtime::Runtime::new().unwrap(); let rt = tokio::runtime::Runtime::new().unwrap();
let local = tokio::task::LocalSet::new(); let local = tokio::task::LocalSet::new();

View File

@ -4,13 +4,13 @@ name = "veilid-core"
version = "0.4.1" version = "0.4.1"
# --- # ---
description = "Core library used to create a Veilid node and operate it as part of an application" description = "Core library used to create a Veilid node and operate it as part of an application"
repository = "https://gitlab.com/veilid/veilid"
authors = ["Veilid Team <contact@veilid.com>"]
edition = "2021"
build = "build.rs" build = "build.rs"
license = "MPL-2.0"
resolver = "2" resolver = "2"
rust-version = "1.81.0" repository.workspace = true
authors.workspace = true
license.workspace = true
edition.workspace = true
rust-version.workspace = true
[lib] [lib]
crate-type = ["cdylib", "staticlib", "rlib"] crate-type = ["cdylib", "staticlib", "rlib"]
@ -56,6 +56,8 @@ veilid_core_ios_tests = ["dep:tracing-oslog"]
debug-locks = ["veilid-tools/debug-locks"] debug-locks = ["veilid-tools/debug-locks"]
unstable-blockstore = [] unstable-blockstore = []
unstable-tunnels = [] unstable-tunnels = []
virtual-network = ["veilid-tools/virtual-network"]
virtual-network-server = ["veilid-tools/virtual-network-server"]
# GeoIP # GeoIP
geolocation = ["maxminddb", "reqwest"] geolocation = ["maxminddb", "reqwest"]
@ -133,8 +135,8 @@ hickory-resolver = { version = "0.24.1", optional = true }
# Serialization # Serialization
capnp = { version = "0.19.6", default-features = false, features = ["alloc"] } capnp = { version = "0.19.6", default-features = false, features = ["alloc"] }
serde = { version = "1.0.204", features = ["derive", "rc"] } serde = { version = "1.0.214", features = ["derive", "rc"] }
serde_json = { version = "1.0.120" } serde_json = { version = "1.0.132" }
serde-big-array = "0.5.1" serde-big-array = "0.5.1"
json = "0.12.4" json = "0.12.4"
data-encoding = { version = "2.6.0" } data-encoding = { version = "2.6.0" }
@ -148,7 +150,7 @@ sanitize-filename = "0.5.0"
# Dependencies for native builds only # Dependencies for native builds only
# Linux, Windows, Mac, iOS, Android # Linux, Windows, Mac, iOS, Android
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] [target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
# Tools # Tools
config = { version = "0.13.4", default-features = false, features = ["yaml"] } config = { version = "0.13.4", default-features = false, features = ["yaml"] }
@ -164,7 +166,6 @@ sysinfo = { version = "^0.30.13", default-features = false }
tokio = { version = "1.38.1", features = ["full"], optional = true } tokio = { version = "1.38.1", features = ["full"], optional = true }
tokio-util = { version = "0.7.11", features = ["compat"], optional = true } tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
tokio-stream = { version = "0.1.15", features = ["net"], optional = true } tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
async-io = { version = "1.13.0" }
futures-util = { version = "0.3.30", default-features = false, features = [ futures-util = { version = "0.3.30", default-features = false, features = [
"async-await", "async-await",
"sink", "sink",
@ -184,10 +185,9 @@ webpki = "0.22.4"
webpki-roots = "0.25.4" webpki-roots = "0.25.4"
rustls = "0.21.12" rustls = "0.21.12"
rustls-pemfile = "1.0.4" rustls-pemfile = "1.0.4"
socket2 = { version = "0.5.7", features = ["all"] }
# Dependencies for WASM builds only # Dependencies for WASM builds only
[target.'cfg(target_arch = "wasm32")'.dependencies] [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
veilid-tools = { version = "0.4.1", path = "../veilid-tools", default-features = false, features = [ veilid-tools = { version = "0.4.1", path = "../veilid-tools", default-features = false, features = [
"rt-wasm-bindgen", "rt-wasm-bindgen",
@ -222,7 +222,7 @@ tracing-wasm = "0.2.1"
keyvaluedb-web = "0.1.2" keyvaluedb-web = "0.1.2"
### Configuration for WASM32 'web-sys' crate ### Configuration for WASM32 'web-sys' crate
[target.'cfg(target_arch = "wasm32")'.dependencies.web-sys] [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies.web-sys]
version = "0.3.69" version = "0.3.69"
features = [ features = [
'Document', 'Document',
@ -260,12 +260,12 @@ tracing-oslog = { version = "0.1.2", optional = true }
### DEV DEPENDENCIES ### DEV DEPENDENCIES
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] [target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
simplelog = { version = "0.12.2", features = ["test"] } simplelog = { version = "0.12.2", features = ["test"] }
serial_test = "2.0.0" serial_test = "2.0.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
[target.'cfg(target_arch = "wasm32")'.dev-dependencies] [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies]
serial_test = { version = "2.0.0", default-features = false, features = [ serial_test = { version = "2.0.0", default-features = false, features = [
"async", "async",
] } ] }

View File

@ -178,7 +178,7 @@ fn fix_android_emulator() {
.or(env::var("ANDROID_SDK_ROOT")) .or(env::var("ANDROID_SDK_ROOT"))
.expect("ANDROID_HOME or ANDROID_SDK_ROOT not set"); .expect("ANDROID_HOME or ANDROID_SDK_ROOT not set");
let lib_path = glob(&format!( let lib_path = glob(&format!(
"{android_home}/ndk/26.3.11579264/**/lib{missing_library}.a" "{android_home}/ndk/27.0.12077973/**/lib{missing_library}.a"
)) ))
.expect("failed to glob") .expect("failed to glob")
.next() .next()

View File

@ -1,54 +1,43 @@
use crate::*; use crate::{network_manager::StartupDisposition, *};
use crypto::Crypto; use routing_table::RoutingTableHealth;
use network_manager::*;
use routing_table::*;
use storage_manager::*;
#[derive(Debug, Clone)]
pub struct AttachmentManagerStartupContext {
pub startup_lock: Arc<StartupLock>,
}
impl AttachmentManagerStartupContext {
pub fn new() -> Self {
Self {
startup_lock: Arc::new(StartupLock::new()),
}
}
}
impl Default for AttachmentManagerStartupContext {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
struct AttachmentManagerInner { struct AttachmentManagerInner {
last_attachment_state: AttachmentState, last_attachment_state: AttachmentState,
last_routing_table_health: Option<RoutingTableHealth>, last_routing_table_health: Option<RoutingTableHealth>,
maintain_peers: bool, maintain_peers: bool,
started_ts: Timestamp, started_ts: Timestamp,
attach_ts: Option<Timestamp>, attach_ts: Option<Timestamp>,
update_callback: Option<UpdateCallback>,
attachment_maintainer_jh: Option<MustJoinHandle<()>>, attachment_maintainer_jh: Option<MustJoinHandle<()>>,
} }
struct AttachmentManagerUnlockedInner { #[derive(Debug)]
_event_bus: EventBus, pub struct AttachmentManager {
config: VeilidConfig, registry: VeilidComponentRegistry,
network_manager: NetworkManager, inner: Mutex<AttachmentManagerInner>,
startup_context: AttachmentManagerStartupContext,
} }
#[derive(Clone)] impl_veilid_component!(AttachmentManager);
pub struct AttachmentManager {
inner: Arc<Mutex<AttachmentManagerInner>>,
unlocked_inner: Arc<AttachmentManagerUnlockedInner>,
}
impl AttachmentManager { impl AttachmentManager {
fn new_unlocked_inner(
event_bus: EventBus,
config: VeilidConfig,
storage_manager: StorageManager,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
) -> AttachmentManagerUnlockedInner {
AttachmentManagerUnlockedInner {
_event_bus: event_bus.clone(),
config: config.clone(),
network_manager: NetworkManager::new(
event_bus,
config,
storage_manager,
table_store,
#[cfg(feature = "unstable-blockstore")]
block_store,
crypto,
),
}
}
fn new_inner() -> AttachmentManagerInner { fn new_inner() -> AttachmentManagerInner {
AttachmentManagerInner { AttachmentManagerInner {
last_attachment_state: AttachmentState::Detached, last_attachment_state: AttachmentState::Detached,
@ -56,52 +45,35 @@ impl AttachmentManager {
maintain_peers: false, maintain_peers: false,
started_ts: Timestamp::now(), started_ts: Timestamp::now(),
attach_ts: None, attach_ts: None,
update_callback: None,
attachment_maintainer_jh: None, attachment_maintainer_jh: None,
} }
} }
pub fn new( pub fn new(
event_bus: EventBus, registry: VeilidComponentRegistry,
config: VeilidConfig, startup_context: AttachmentManagerStartupContext,
storage_manager: StorageManager,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
) -> Self { ) -> Self {
Self { Self {
inner: Arc::new(Mutex::new(Self::new_inner())), registry,
unlocked_inner: Arc::new(Self::new_unlocked_inner( inner: Mutex::new(Self::new_inner()),
event_bus, startup_context,
config,
storage_manager,
table_store,
#[cfg(feature = "unstable-blockstore")]
block_store,
crypto,
)),
} }
} }
pub fn config(&self) -> VeilidConfig { pub fn is_attached(&self) -> bool {
self.unlocked_inner.config.clone() let s = self.inner.lock().last_attachment_state;
!matches!(s, AttachmentState::Detached | AttachmentState::Detaching)
} }
pub fn network_manager(&self) -> NetworkManager { #[allow(dead_code)]
self.unlocked_inner.network_manager.clone() pub fn is_detached(&self) -> bool {
let s = self.inner.lock().last_attachment_state;
matches!(s, AttachmentState::Detached)
} }
// pub fn is_attached(&self) -> bool { #[allow(dead_code)]
// let s = self.inner.lock().last_attachment_state; pub fn get_attach_timestamp(&self) -> Option<Timestamp> {
// !matches!(s, AttachmentState::Detached | AttachmentState::Detaching) self.inner.lock().attach_ts
// } }
// pub fn is_detached(&self) -> bool {
// let s = self.inner.lock().last_attachment_state;
// matches!(s, AttachmentState::Detached)
// }
// pub fn get_attach_timestamp(&self) -> Option<Timestamp> {
// self.inner.lock().attach_ts
// }
fn translate_routing_table_health( fn translate_routing_table_health(
health: &RoutingTableHealth, health: &RoutingTableHealth,
@ -155,11 +127,6 @@ impl AttachmentManager {
inner.last_attachment_state = inner.last_attachment_state =
AttachmentManager::translate_routing_table_health(&health, routing_table_config); AttachmentManager::translate_routing_table_health(&health, routing_table_config);
// If we don't have an update callback yet for some reason, just return now
let Some(update_callback) = inner.update_callback.clone() else {
return;
};
// Send update if one of: // Send update if one of:
// * the attachment state has changed // * the attachment state has changed
// * routing domain readiness has changed // * routing domain readiness has changed
@ -172,7 +139,7 @@ impl AttachmentManager {
}) })
.unwrap_or(true); .unwrap_or(true);
if send_update { if send_update {
Some((update_callback, Self::get_veilid_state_inner(&inner))) Some(Self::get_veilid_state_inner(&inner))
} else { } else {
None None
} }
@ -180,15 +147,14 @@ impl AttachmentManager {
// Send the update outside of the lock // Send the update outside of the lock
if let Some(update) = opt_update { if let Some(update) = opt_update {
(update.0)(VeilidUpdate::Attachment(update.1)); (self.update_callback())(VeilidUpdate::Attachment(update));
} }
} }
fn update_attaching_detaching_state(&self, state: AttachmentState) { fn update_attaching_detaching_state(&self, state: AttachmentState) {
let uptime; let uptime;
let attached_uptime; let attached_uptime;
{
let update_callback = {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
// Clear routing table health so when we start measuring it we start from scratch // Clear routing table health so when we start measuring it we start from scratch
@ -211,29 +177,98 @@ impl AttachmentManager {
let now = Timestamp::now(); let now = Timestamp::now();
uptime = now - inner.started_ts; uptime = now - inner.started_ts;
attached_uptime = inner.attach_ts.map(|ts| now - ts); attached_uptime = inner.attach_ts.map(|ts| now - ts);
// Get callback
inner.update_callback.clone()
}; };
// Send update // Send update
if let Some(update_callback) = update_callback { (self.update_callback())(VeilidUpdate::Attachment(Box::new(VeilidStateAttachment {
update_callback(VeilidUpdate::Attachment(Box::new(VeilidStateAttachment { state,
state, public_internet_ready: false,
public_internet_ready: false, local_network_ready: false,
local_network_ready: false, uptime,
uptime, attached_uptime,
attached_uptime, })))
}))) }
async fn startup(&self) -> EyreResult<StartupDisposition> {
let guard = self.startup_context.startup_lock.startup()?;
let rpc_processor = self.rpc_processor();
let network_manager = self.network_manager();
// Startup network manager
network_manager.startup().await?;
// Startup rpc processor
if let Err(e) = rpc_processor.startup().await {
network_manager.shutdown().await;
return Err(e);
} }
// Startup routing table
let routing_table = self.routing_table();
if let Err(e) = routing_table.startup().await {
rpc_processor.shutdown().await;
network_manager.shutdown().await;
return Err(e);
}
// Startup successful
guard.success();
// Inform api clients that things have changed
log_net!(debug "sending network state update to api clients");
network_manager.send_network_update();
Ok(StartupDisposition::Success)
}
async fn shutdown(&self) {
let guard = self
.startup_context
.startup_lock
.shutdown()
.await
.expect("should be started up");
let routing_table = self.routing_table();
let rpc_processor = self.rpc_processor();
let network_manager = self.network_manager();
// Shutdown RoutingTable
routing_table.shutdown().await;
// Shutdown NetworkManager
network_manager.shutdown().await;
// Shutdown RPCProcessor
rpc_processor.shutdown().await;
// Shutdown successful
guard.success();
// send update
log_net!(debug "sending network state update to api clients");
network_manager.send_network_update();
}
async fn tick(&self) -> EyreResult<()> {
// Run the network manager tick
let network_manager = self.network_manager();
network_manager.tick().await?;
// Run the routing table tick
let routing_table = self.routing_table();
routing_table.tick().await?;
Ok(())
} }
#[instrument(parent = None, level = "debug", skip_all)] #[instrument(parent = None, level = "debug", skip_all)]
async fn attachment_maintainer(self) { async fn attachment_maintainer(&self) {
log_net!(debug "attachment starting"); log_net!(debug "attachment starting");
self.update_attaching_detaching_state(AttachmentState::Attaching); self.update_attaching_detaching_state(AttachmentState::Attaching);
let netman = self.network_manager(); let network_manager = self.network_manager();
let mut restart; let mut restart;
let mut restart_delay; let mut restart_delay;
@ -241,9 +276,9 @@ impl AttachmentManager {
restart = false; restart = false;
restart_delay = 1; restart_delay = 1;
match netman.startup().await { match self.startup().await {
Err(err) => { Err(err) => {
error!("network startup failed: {}", err); error!("attachment startup failed: {}", err);
restart = true; restart = true;
} }
Ok(StartupDisposition::BindRetry) => { Ok(StartupDisposition::BindRetry) => {
@ -257,15 +292,15 @@ impl AttachmentManager {
while self.inner.lock().maintain_peers { while self.inner.lock().maintain_peers {
// tick network manager // tick network manager
let next_tick_ts = get_timestamp() + 1_000_000u64; let next_tick_ts = get_timestamp() + 1_000_000u64;
if let Err(err) = netman.tick().await { if let Err(err) = self.tick().await {
error!("Error in network manager: {}", err); error!("Error in attachment tick: {}", err);
self.inner.lock().maintain_peers = false; self.inner.lock().maintain_peers = false;
restart = true; restart = true;
break; break;
} }
// see if we need to restart the network // see if we need to restart the network
if netman.network_needs_restart() { if network_manager.network_needs_restart() {
info!("Restarting network"); info!("Restarting network");
restart = true; restart = true;
break; break;
@ -288,8 +323,8 @@ impl AttachmentManager {
log_net!(debug "attachment stopping"); log_net!(debug "attachment stopping");
} }
log_net!(debug "stopping network"); log_net!(debug "shutting down attachment");
netman.shutdown().await; self.shutdown().await;
} }
} }
@ -313,25 +348,24 @@ impl AttachmentManager {
} }
#[instrument(level = "debug", skip_all, err)] #[instrument(level = "debug", skip_all, err)]
pub async fn init(&self, update_callback: UpdateCallback) -> EyreResult<()> { pub async fn init_async(&self) -> EyreResult<()> {
{ Ok(())
let mut inner = self.inner.lock(); }
inner.update_callback = Some(update_callback.clone());
}
self.network_manager().init(update_callback).await?;
#[instrument(level = "debug", skip_all, err)]
pub async fn post_init_async(&self) -> EyreResult<()> {
Ok(()) Ok(())
} }
#[instrument(level = "debug", skip_all)] #[instrument(level = "debug", skip_all)]
pub async fn terminate(&self) { pub async fn pre_terminate_async(&self) {
// Ensure we detached // Ensure we detached
self.detach().await; self.detach().await;
self.network_manager().terminate().await;
self.inner.lock().update_callback = None;
} }
#[instrument(level = "debug", skip_all)]
pub async fn terminate_async(&self) {}
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub async fn attach(&self) -> bool { pub async fn attach(&self) -> bool {
// Create long-running connection maintenance routine // Create long-running connection maintenance routine
@ -340,10 +374,11 @@ impl AttachmentManager {
return false; return false;
} }
inner.maintain_peers = true; inner.maintain_peers = true;
inner.attachment_maintainer_jh = Some(spawn( let registry = self.registry();
"attachment maintainer", inner.attachment_maintainer_jh = Some(spawn("attachment maintainer", async move {
self.clone().attachment_maintainer(), let this = registry.attachment_manager();
)); this.attachment_maintainer().await;
}));
true true
} }

View File

@ -0,0 +1,336 @@
use std::marker::PhantomData;
use super::*;
pub trait AsAnyArcSendSync {
fn as_any_arc_send_sync(self: Arc<Self>) -> Arc<dyn core::any::Any + Send + Sync>;
}
impl<T: Send + Sync + 'static> AsAnyArcSendSync for T {
fn as_any_arc_send_sync(self: Arc<Self>) -> Arc<dyn core::any::Any + Send + Sync> {
self
}
}
pub trait VeilidComponent:
AsAnyArcSendSync + VeilidComponentRegistryAccessor + core::fmt::Debug
{
fn init(&self) -> SendPinBoxFutureLifetime<'_, EyreResult<()>>;
fn post_init(&self) -> SendPinBoxFutureLifetime<'_, EyreResult<()>>;
fn pre_terminate(&self) -> SendPinBoxFutureLifetime<'_, ()>;
fn terminate(&self) -> SendPinBoxFutureLifetime<'_, ()>;
}
pub trait VeilidComponentRegistryAccessor {
fn registry(&self) -> VeilidComponentRegistry;
fn config(&self) -> VeilidConfig {
self.registry().config.clone()
}
fn update_callback(&self) -> UpdateCallback {
self.registry().config.update_callback()
}
fn event_bus(&self) -> EventBus {
self.registry().event_bus.clone()
}
}
pub struct VeilidComponentGuard<'a, T: VeilidComponent + Send + Sync + 'static> {
component: Arc<T>,
_phantom: core::marker::PhantomData<&'a T>,
}
impl<'a, T> core::ops::Deref for VeilidComponentGuard<'a, T>
where
T: VeilidComponent + Send + Sync + 'static,
{
type Target = T;
fn deref(&self) -> &Self::Target {
&self.component
}
}
#[derive(Debug)]
struct VeilidComponentRegistryInner {
type_map: HashMap<core::any::TypeId, Arc<dyn VeilidComponent + Send + Sync>>,
init_order: Vec<core::any::TypeId>,
mock: bool,
}
#[derive(Clone, Debug)]
pub struct VeilidComponentRegistry {
inner: Arc<Mutex<VeilidComponentRegistryInner>>,
config: VeilidConfig,
event_bus: EventBus,
init_lock: Arc<AsyncMutex<bool>>,
}
impl VeilidComponentRegistry {
pub fn new(config: VeilidConfig) -> Self {
Self {
inner: Arc::new(Mutex::new(VeilidComponentRegistryInner {
type_map: HashMap::new(),
init_order: Vec::new(),
mock: false,
})),
config,
event_bus: EventBus::new(),
init_lock: Arc::new(AsyncMutex::new(false)),
}
}
pub fn enable_mock(&self) {
let mut inner = self.inner.lock();
inner.mock = true;
}
pub fn register<
T: VeilidComponent + Send + Sync + 'static,
F: FnOnce(VeilidComponentRegistry) -> T,
>(
&self,
component_constructor: F,
) {
let component = Arc::new(component_constructor(self.clone()));
let component_type_id = core::any::TypeId::of::<T>();
let mut inner = self.inner.lock();
assert!(
inner
.type_map
.insert(component_type_id, component)
.is_none(),
"should not register same component twice"
);
inner.init_order.push(component_type_id);
}
pub fn register_with_context<
C,
T: VeilidComponent + Send + Sync + 'static,
F: FnOnce(VeilidComponentRegistry, C) -> T,
>(
&self,
component_constructor: F,
context: C,
) {
let component = Arc::new(component_constructor(self.clone(), context));
let component_type_id = core::any::TypeId::of::<T>();
let mut inner = self.inner.lock();
assert!(
inner
.type_map
.insert(component_type_id, component)
.is_none(),
"should not register same component twice"
);
inner.init_order.push(component_type_id);
}
pub async fn init(&self) -> EyreResult<()> {
let Some(mut _init_guard) = asyncmutex_try_lock!(self.init_lock) else {
bail!("init should only happen one at a time");
};
if *_init_guard {
bail!("already initialized");
}
// Event bus starts up early
self.event_bus.startup().await?;
// Process components in initialization order
let init_order = self.get_init_order();
let mut initialized = vec![];
for component in init_order {
if let Err(e) = component.init().await {
self.terminate_inner(initialized).await;
self.event_bus.shutdown().await;
return Err(e);
}
initialized.push(component);
}
*_init_guard = true;
Ok(())
}
pub async fn post_init(&self) -> EyreResult<()> {
let Some(mut _init_guard) = asyncmutex_try_lock!(self.init_lock) else {
bail!("init should only happen one at a time");
};
if !*_init_guard {
bail!("not initialized");
}
let init_order = self.get_init_order();
let mut post_initialized = vec![];
for component in init_order {
if let Err(e) = component.post_init().await {
self.pre_terminate_inner(post_initialized).await;
return Err(e);
}
post_initialized.push(component)
}
Ok(())
}
pub async fn pre_terminate(&self) {
let Some(mut _init_guard) = asyncmutex_try_lock!(self.init_lock) else {
panic!("terminate should only happen one at a time");
};
if !*_init_guard {
panic!("not initialized");
}
let init_order = self.get_init_order();
self.pre_terminate_inner(init_order).await;
}
pub async fn terminate(&self) {
let Some(mut _init_guard) = asyncmutex_try_lock!(self.init_lock) else {
panic!("terminate should only happen one at a time");
};
if !*_init_guard {
panic!("not initialized");
}
// Terminate components in reverse initialization order
let init_order = self.get_init_order();
self.terminate_inner(init_order).await;
// Event bus shuts down last
self.event_bus.shutdown().await;
*_init_guard = false;
}
async fn pre_terminate_inner(
&self,
pre_initialized: Vec<Arc<dyn VeilidComponent + Send + Sync>>,
) {
for component in pre_initialized.iter().rev() {
component.pre_terminate().await;
}
}
async fn terminate_inner(&self, initialized: Vec<Arc<dyn VeilidComponent + Send + Sync>>) {
for component in initialized.iter().rev() {
component.terminate().await;
}
}
fn get_init_order(&self) -> Vec<Arc<dyn VeilidComponent + Send + Sync>> {
let inner = self.inner.lock();
inner
.init_order
.iter()
.map(|id| inner.type_map.get(id).unwrap().clone())
.collect::<Vec<_>>()
}
//////////////////////////////////////////////////////////////
pub fn lookup<'a, T: VeilidComponent + Send + Sync + 'static>(
&self,
) -> Option<VeilidComponentGuard<'a, T>> {
let inner = self.inner.lock();
let component_type_id = core::any::TypeId::of::<T>();
let component_dyn = inner.type_map.get(&component_type_id)?.clone();
let component = component_dyn
.as_any_arc_send_sync()
.downcast::<T>()
.unwrap();
Some(VeilidComponentGuard {
component,
_phantom: PhantomData {},
})
}
}
impl VeilidComponentRegistryAccessor for VeilidComponentRegistry {
fn registry(&self) -> VeilidComponentRegistry {
self.clone()
}
}
////////////////////////////////////////////////////////////////////
macro_rules! impl_veilid_component_registry_accessor {
($struct_name:ident) => {
impl VeilidComponentRegistryAccessor for $struct_name {
fn registry(&self) -> VeilidComponentRegistry {
self.registry.clone()
}
}
};
}
pub(crate) use impl_veilid_component_registry_accessor;
/////////////////////////////////////////////////////////////////////
macro_rules! impl_veilid_component {
($component_name:ident) => {
impl_veilid_component_registry_accessor!($component_name);
impl VeilidComponent for $component_name {
fn init(&self) -> SendPinBoxFutureLifetime<'_, EyreResult<()>> {
Box::pin(async { self.init_async().await })
}
fn post_init(&self) -> SendPinBoxFutureLifetime<'_, EyreResult<()>> {
Box::pin(async { self.post_init_async().await })
}
fn pre_terminate(&self) -> SendPinBoxFutureLifetime<'_, ()> {
Box::pin(async { self.pre_terminate_async().await })
}
fn terminate(&self) -> SendPinBoxFutureLifetime<'_, ()> {
Box::pin(async { self.terminate_async().await })
}
}
};
}
pub(crate) use impl_veilid_component;
/////////////////////////////////////////////////////////////////////
// Utility macro for setting up a background TickTask
// Should be called during new/construction of a component with background tasks
// and before any post-init 'tick' operations are started
macro_rules! impl_setup_task {
($this:expr, $this_type:ty, $task_name:ident, $task_routine:ident ) => {{
let registry = $this.registry();
$this.$task_name.set_routine(move |s, l, t| {
let registry = registry.clone();
Box::pin(async move {
let this = registry.lookup::<$this_type>().unwrap();
this.$task_routine(s, Timestamp::new(l), Timestamp::new(t))
.await
})
});
}};
}
pub(crate) use impl_setup_task;
// Utility macro for setting up an event bus handler
// Should be called after init, during post-init or later
// Subscription should be unsubscribed before termination
macro_rules! impl_subscribe_event_bus {
($this:expr, $this_type:ty, $event_handler:ident ) => {{
let registry = $this.registry();
$this.event_bus().subscribe(move |evt| {
let registry = registry.clone();
Box::pin(async move {
let this = registry.lookup::<$this_type>().unwrap();
this.$event_handler(evt);
})
})
}};
}
pub(crate) use impl_subscribe_event_bus;

View File

@ -1,242 +1,26 @@
use crate::attachment_manager::*; use crate::attachment_manager::{AttachmentManager, AttachmentManagerStartupContext};
use crate::crypto::Crypto; use crate::crypto::Crypto;
use crate::logging::*; use crate::logging::*;
use crate::storage_manager::*; use crate::network_manager::{NetworkManager, NetworkManagerStartupContext};
use crate::routing_table::RoutingTable;
use crate::rpc_processor::{RPCProcessor, RPCProcessorStartupContext};
use crate::storage_manager::StorageManager;
use crate::veilid_api::*; use crate::veilid_api::*;
use crate::veilid_config::*; use crate::veilid_config::*;
use crate::*; use crate::*;
pub type UpdateCallback = Arc<dyn Fn(VeilidUpdate) + Send + Sync>; pub type UpdateCallback = Arc<dyn Fn(VeilidUpdate) + Send + Sync>;
/// Internal services startup mechanism. type InitKey = (String, String);
/// Ensures that everything is started up, and shut down in the right order
/// and provides an atomic state for if the system is properly operational.
struct StartupShutdownContext {
pub config: VeilidConfig,
pub update_callback: UpdateCallback,
pub event_bus: Option<EventBus>,
pub protected_store: Option<ProtectedStore>,
pub table_store: Option<TableStore>,
#[cfg(feature = "unstable-blockstore")]
pub block_store: Option<BlockStore>,
pub crypto: Option<Crypto>,
pub attachment_manager: Option<AttachmentManager>,
pub storage_manager: Option<StorageManager>,
}
impl StartupShutdownContext {
pub fn new_empty(config: VeilidConfig, update_callback: UpdateCallback) -> Self {
Self {
config,
update_callback,
event_bus: None,
protected_store: None,
table_store: None,
#[cfg(feature = "unstable-blockstore")]
block_store: None,
crypto: None,
attachment_manager: None,
storage_manager: None,
}
}
#[allow(clippy::too_many_arguments)]
pub fn new_full(
config: VeilidConfig,
update_callback: UpdateCallback,
event_bus: EventBus,
protected_store: ProtectedStore,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
attachment_manager: AttachmentManager,
storage_manager: StorageManager,
) -> Self {
Self {
config,
update_callback,
event_bus: Some(event_bus),
protected_store: Some(protected_store),
table_store: Some(table_store),
#[cfg(feature = "unstable-blockstore")]
block_store: Some(block_store),
crypto: Some(crypto),
attachment_manager: Some(attachment_manager),
storage_manager: Some(storage_manager),
}
}
#[instrument(level = "trace", target = "core_context", err, skip_all)]
pub async fn startup(&mut self) -> EyreResult<()> {
info!("Veilid API starting up");
info!("init api tracing");
let (program_name, namespace) = {
let config = self.config.get();
(config.program_name.clone(), config.namespace.clone())
};
ApiTracingLayer::add_callback(program_name, namespace, self.update_callback.clone())
.await?;
// Add the event bus
let event_bus = EventBus::new();
if let Err(e) = event_bus.startup().await {
error!("failed to start up event bus: {}", e);
self.shutdown().await;
return Err(e.into());
}
self.event_bus = Some(event_bus.clone());
// Set up protected store
let protected_store = ProtectedStore::new(event_bus.clone(), self.config.clone());
if let Err(e) = protected_store.init().await {
error!("failed to init protected store: {}", e);
self.shutdown().await;
return Err(e);
}
self.protected_store = Some(protected_store.clone());
// Set up tablestore and crypto system
let table_store = TableStore::new(
event_bus.clone(),
self.config.clone(),
protected_store.clone(),
);
let crypto = Crypto::new(event_bus.clone(), self.config.clone(), table_store.clone());
table_store.set_crypto(crypto.clone());
// Initialize table store first, so crypto code can load caches
// Tablestore can use crypto during init, just not any cached operations or things
// that require flushing back to the tablestore
if let Err(e) = table_store.init().await {
error!("failed to init table store: {}", e);
self.shutdown().await;
return Err(e);
}
self.table_store = Some(table_store.clone());
// Set up crypto
if let Err(e) = crypto.init().await {
error!("failed to init crypto: {}", e);
self.shutdown().await;
return Err(e);
}
self.crypto = Some(crypto.clone());
// Set up block store
#[cfg(feature = "unstable-blockstore")]
{
let block_store = BlockStore::new(event_bus.clone(), self.config.clone());
if let Err(e) = block_store.init().await {
error!("failed to init block store: {}", e);
self.shutdown().await;
return Err(e);
}
self.block_store = Some(block_store.clone());
}
// Set up storage manager
let update_callback = self.update_callback.clone();
let storage_manager = StorageManager::new(
event_bus.clone(),
self.config.clone(),
self.crypto.clone().unwrap(),
self.table_store.clone().unwrap(),
#[cfg(feature = "unstable-blockstore")]
self.block_store.clone().unwrap(),
);
if let Err(e) = storage_manager.init(update_callback).await {
error!("failed to init storage manager: {}", e);
self.shutdown().await;
return Err(e);
}
self.storage_manager = Some(storage_manager.clone());
// Set up attachment manager
let update_callback = self.update_callback.clone();
let attachment_manager = AttachmentManager::new(
event_bus.clone(),
self.config.clone(),
storage_manager,
table_store,
#[cfg(feature = "unstable-blockstore")]
block_store,
crypto,
);
if let Err(e) = attachment_manager.init(update_callback).await {
error!("failed to init attachment manager: {}", e);
self.shutdown().await;
return Err(e);
}
self.attachment_manager = Some(attachment_manager);
info!("Veilid API startup complete");
Ok(())
}
#[instrument(level = "trace", target = "core_context", skip_all)]
pub async fn shutdown(&mut self) {
info!("Veilid API shutting down");
if let Some(attachment_manager) = &mut self.attachment_manager {
attachment_manager.terminate().await;
}
if let Some(storage_manager) = &mut self.storage_manager {
storage_manager.terminate().await;
}
#[cfg(feature = "unstable-blockstore")]
if let Some(block_store) = &mut self.block_store {
block_store.terminate().await;
}
if let Some(crypto) = &mut self.crypto {
crypto.terminate().await;
}
if let Some(table_store) = &mut self.table_store {
table_store.terminate().await;
}
if let Some(protected_store) = &mut self.protected_store {
protected_store.terminate().await;
}
if let Some(event_bus) = &mut self.event_bus {
event_bus.shutdown().await;
}
info!("Veilid API shutdown complete");
// api logger terminate is idempotent
let (program_name, namespace) = {
let config = self.config.get();
(config.program_name.clone(), config.namespace.clone())
};
if let Err(e) = ApiTracingLayer::remove_callback(program_name, namespace).await {
error!("Error removing callback from ApiTracingLayer: {}", e);
}
// send final shutdown update
(self.update_callback)(VeilidUpdate::Shutdown);
}
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
pub struct VeilidCoreContext { #[derive(Clone, Debug)]
pub config: VeilidConfig, pub(crate) struct VeilidCoreContext {
pub update_callback: UpdateCallback, registry: VeilidComponentRegistry,
// Event bus
pub event_bus: EventBus,
// Services
pub storage_manager: StorageManager,
pub protected_store: ProtectedStore,
pub table_store: TableStore,
#[cfg(feature = "unstable-blockstore")]
pub block_store: BlockStore,
pub crypto: Crypto,
pub attachment_manager: AttachmentManager,
} }
impl_veilid_component_registry_accessor!(VeilidCoreContext);
impl VeilidCoreContext { impl VeilidCoreContext {
#[instrument(level = "trace", target = "core_context", err, skip_all)] #[instrument(level = "trace", target = "core_context", err, skip_all)]
async fn new_with_config_callback( async fn new_with_config_callback(
@ -244,10 +28,9 @@ impl VeilidCoreContext {
config_callback: ConfigCallback, config_callback: ConfigCallback,
) -> VeilidAPIResult<VeilidCoreContext> { ) -> VeilidAPIResult<VeilidCoreContext> {
// Set up config from callback // Set up config from callback
let mut config = VeilidConfig::new(); let config = VeilidConfig::new_from_callback(config_callback, update_callback)?;
config.setup(config_callback, update_callback.clone())?;
Self::new_common(update_callback, config).await Self::new_common(config).await
} }
#[instrument(level = "trace", target = "core_context", err, skip_all)] #[instrument(level = "trace", target = "core_context", err, skip_all)]
@ -256,16 +39,12 @@ impl VeilidCoreContext {
config_inner: VeilidConfigInner, config_inner: VeilidConfigInner,
) -> VeilidAPIResult<VeilidCoreContext> { ) -> VeilidAPIResult<VeilidCoreContext> {
// Set up config from json // Set up config from json
let mut config = VeilidConfig::new(); let config = VeilidConfig::new_from_config(config_inner, update_callback);
config.setup_from_config(config_inner, update_callback.clone())?; Self::new_common(config).await
Self::new_common(update_callback, config).await
} }
#[instrument(level = "trace", target = "core_context", err, skip_all)] #[instrument(level = "trace", target = "core_context", err, skip_all)]
async fn new_common( async fn new_common(config: VeilidConfig) -> VeilidAPIResult<VeilidCoreContext> {
update_callback: UpdateCallback,
config: VeilidConfig,
) -> VeilidAPIResult<VeilidCoreContext> {
cfg_if! { cfg_if! {
if #[cfg(target_os = "android")] { if #[cfg(target_os = "android")] {
if !crate::intf::android::is_android_ready() { if !crate::intf::android::is_android_ready() {
@ -274,45 +53,134 @@ impl VeilidCoreContext {
} }
} }
let mut sc = StartupShutdownContext::new_empty(config.clone(), update_callback); info!("Veilid API starting up");
sc.startup().await.map_err(VeilidAPIError::generic)?;
Ok(VeilidCoreContext { let (program_name, namespace, update_callback) = {
config: sc.config, let cfginner = config.get();
update_callback: sc.update_callback, (
event_bus: sc.event_bus.unwrap(), cfginner.program_name.clone(),
storage_manager: sc.storage_manager.unwrap(), cfginner.namespace.clone(),
protected_store: sc.protected_store.unwrap(), config.update_callback(),
table_store: sc.table_store.unwrap(), )
#[cfg(feature = "unstable-blockstore")] };
block_store: sc.block_store.unwrap(),
crypto: sc.crypto.unwrap(), ApiTracingLayer::add_callback(program_name, namespace, update_callback.clone()).await?;
attachment_manager: sc.attachment_manager.unwrap(),
}) // Create component registry
let registry = VeilidComponentRegistry::new(config);
// Register all components
registry.register(ProtectedStore::new);
registry.register(Crypto::new);
registry.register(TableStore::new);
#[cfg(feature = "unstable-blockstore")]
registry.register(BlockStore::new);
registry.register(StorageManager::new);
registry.register(RoutingTable::new);
registry
.register_with_context(NetworkManager::new, NetworkManagerStartupContext::default());
registry.register_with_context(RPCProcessor::new, RPCProcessorStartupContext::default());
registry.register_with_context(
AttachmentManager::new,
AttachmentManagerStartupContext::default(),
);
// Run initialization
// This should make the majority of subsystems functional
registry.init().await.map_err(VeilidAPIError::internal)?;
// Run post-initialization
// This should resolve any inter-subsystem dependencies
// required for background processes that utilize multiple subsystems
// Background processes also often require registry lookup of the
// current subsystem, which is not available until after init succeeds
if let Err(e) = registry.post_init().await {
registry.terminate().await;
return Err(VeilidAPIError::internal(e));
}
info!("Veilid API startup complete");
Ok(Self { registry })
} }
#[instrument(level = "trace", target = "core_context", skip_all)] #[instrument(level = "trace", target = "core_context", skip_all)]
async fn shutdown(self) { async fn shutdown(self) {
let mut sc = StartupShutdownContext::new_full( info!("Veilid API shutdown complete");
self.config.clone(),
self.update_callback.clone(), let (program_name, namespace, update_callback) = {
self.event_bus, let config = self.registry.config();
self.protected_store, let cfginner = config.get();
self.table_store, (
#[cfg(feature = "unstable-blockstore")] cfginner.program_name.clone(),
self.block_store, cfginner.namespace.clone(),
self.crypto, config.update_callback(),
self.attachment_manager, )
self.storage_manager, };
);
sc.shutdown().await; // Run pre-termination
// This should shut down background processes that may require the existence of
// other subsystems that may not exist during final termination
self.registry.pre_terminate().await;
// Run termination
// This should finish any shutdown operations for the subsystems
self.registry.terminate().await;
if let Err(e) = ApiTracingLayer::remove_callback(program_name, namespace).await {
error!("Error removing callback from ApiTracingLayer: {}", e);
}
// send final shutdown update
update_callback(VeilidUpdate::Shutdown);
}
}
/////////////////////////////////////////////////////////////////////////////
pub trait RegisteredComponents {
fn protected_store<'a>(&self) -> VeilidComponentGuard<'a, ProtectedStore>;
fn crypto<'a>(&self) -> VeilidComponentGuard<'a, Crypto>;
fn table_store<'a>(&self) -> VeilidComponentGuard<'a, TableStore>;
fn storage_manager<'a>(&self) -> VeilidComponentGuard<'a, StorageManager>;
fn routing_table<'a>(&self) -> VeilidComponentGuard<'a, RoutingTable>;
fn network_manager<'a>(&self) -> VeilidComponentGuard<'a, NetworkManager>;
fn rpc_processor<'a>(&self) -> VeilidComponentGuard<'a, RPCProcessor>;
fn attachment_manager<'a>(&self) -> VeilidComponentGuard<'a, AttachmentManager>;
}
impl<T: VeilidComponentRegistryAccessor> RegisteredComponents for T {
fn protected_store<'a>(&self) -> VeilidComponentGuard<'a, ProtectedStore> {
self.registry().lookup::<ProtectedStore>().unwrap()
}
fn crypto<'a>(&self) -> VeilidComponentGuard<'a, Crypto> {
self.registry().lookup::<Crypto>().unwrap()
}
fn table_store<'a>(&self) -> VeilidComponentGuard<'a, TableStore> {
self.registry().lookup::<TableStore>().unwrap()
}
fn storage_manager<'a>(&self) -> VeilidComponentGuard<'a, StorageManager> {
self.registry().lookup::<StorageManager>().unwrap()
}
fn routing_table<'a>(&self) -> VeilidComponentGuard<'a, RoutingTable> {
self.registry().lookup::<RoutingTable>().unwrap()
}
fn network_manager<'a>(&self) -> VeilidComponentGuard<'a, NetworkManager> {
self.registry().lookup::<NetworkManager>().unwrap()
}
fn rpc_processor<'a>(&self) -> VeilidComponentGuard<'a, RPCProcessor> {
self.registry().lookup::<RPCProcessor>().unwrap()
}
fn attachment_manager<'a>(&self) -> VeilidComponentGuard<'a, AttachmentManager> {
self.registry().lookup::<AttachmentManager>().unwrap()
} }
} }
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
lazy_static::lazy_static! { lazy_static::lazy_static! {
static ref INITIALIZED: AsyncMutex<HashSet<(String,String)>> = AsyncMutex::new(HashSet::new()); static ref INITIALIZED: Mutex<HashSet<InitKey>> = Mutex::new(HashSet::new());
static ref STARTUP_TABLE: AsyncTagLockTable<InitKey> = AsyncTagLockTable::new();
} }
/// Initialize a Veilid node. /// Initialize a Veilid node.
@ -345,9 +213,11 @@ pub async fn api_startup(
})?; })?;
let init_key = (program_name, namespace); let init_key = (program_name, namespace);
// Only allow one startup/shutdown per program_name+namespace combination simultaneously
let _tag_guard = STARTUP_TABLE.lock_tag(init_key.clone()).await;
// See if we have an API started up already // See if we have an API started up already
let mut initialized_lock = INITIALIZED.lock().await; if INITIALIZED.lock().contains(&init_key) {
if initialized_lock.contains(&init_key) {
apibail_already_initialized!(); apibail_already_initialized!();
} }
@ -358,7 +228,8 @@ pub async fn api_startup(
// Return an API object around our context // Return an API object around our context
let veilid_api = VeilidAPI::new(context); let veilid_api = VeilidAPI::new(context);
initialized_lock.insert(init_key); // Add to the initialized set
INITIALIZED.lock().insert(init_key);
Ok(veilid_api) Ok(veilid_api)
} }
@ -403,12 +274,13 @@ pub async fn api_startup_config(
// Get the program_name and namespace we're starting up in // Get the program_name and namespace we're starting up in
let program_name = config.program_name.clone(); let program_name = config.program_name.clone();
let namespace = config.namespace.clone(); let namespace = config.namespace.clone();
let init_key = (program_name, namespace); let init_key = (program_name, namespace);
// Only allow one startup/shutdown per program_name+namespace combination simultaneously
let _tag_guard = STARTUP_TABLE.lock_tag(init_key.clone()).await;
// See if we have an API started up already // See if we have an API started up already
let mut initialized_lock = INITIALIZED.lock().await; if INITIALIZED.lock().contains(&init_key) {
if initialized_lock.contains(&init_key) {
apibail_already_initialized!(); apibail_already_initialized!();
} }
@ -418,20 +290,32 @@ pub async fn api_startup_config(
// Return an API object around our context // Return an API object around our context
let veilid_api = VeilidAPI::new(context); let veilid_api = VeilidAPI::new(context);
initialized_lock.insert(init_key); // Add to the initialized set
INITIALIZED.lock().insert(init_key);
Ok(veilid_api) Ok(veilid_api)
} }
#[instrument(level = "trace", target = "core_context", skip_all)] #[instrument(level = "trace", target = "core_context", skip_all)]
pub async fn api_shutdown(context: VeilidCoreContext) { pub(crate) async fn api_shutdown(context: VeilidCoreContext) {
let mut initialized_lock = INITIALIZED.lock().await;
let init_key = { let init_key = {
let config = context.config.get(); let registry = context.registry();
(config.program_name.clone(), config.namespace.clone()) let config = registry.config();
let cfginner = config.get();
(cfginner.program_name.clone(), cfginner.namespace.clone())
}; };
// Only allow one startup/shutdown per program_name+namespace combination simultaneously
let _tag_guard = STARTUP_TABLE.lock_tag(init_key.clone()).await;
// See if we have an API started up already
if !INITIALIZED.lock().contains(&init_key) {
return;
}
// Shutdown the context
context.shutdown().await; context.shutdown().await;
initialized_lock.remove(&init_key);
// Remove from the initialized set
INITIALIZED.lock().remove(&init_key);
} }

View File

@ -1,11 +1,11 @@
use super::*; use super::*;
const VEILID_DOMAIN_API: &[u8] = b"VEILID_API"; pub(crate) const VEILID_DOMAIN_API: &[u8] = b"VEILID_API";
pub trait CryptoSystem { pub trait CryptoSystem {
// Accessors // Accessors
fn kind(&self) -> CryptoKind; fn kind(&self) -> CryptoKind;
fn crypto(&self) -> Crypto; fn crypto(&self) -> VeilidComponentGuard<'_, Crypto>;
// Cached Operations // Cached Operations
fn cached_dh(&self, key: &PublicKey, secret: &SecretKey) -> VeilidAPIResult<SharedSecret>; fn cached_dh(&self, key: &PublicKey, secret: &SecretKey) -> VeilidAPIResult<SharedSecret>;

View File

@ -67,7 +67,7 @@ impl Envelope {
#[instrument(level = "trace", target = "envelope", skip_all)] #[instrument(level = "trace", target = "envelope", skip_all)]
pub fn from_signed_data( pub fn from_signed_data(
crypto: Crypto, crypto: &Crypto,
data: &[u8], data: &[u8],
network_key: &Option<SharedSecret>, network_key: &Option<SharedSecret>,
) -> VeilidAPIResult<Envelope> { ) -> VeilidAPIResult<Envelope> {
@ -193,7 +193,7 @@ impl Envelope {
#[instrument(level = "trace", target = "envelope", skip_all)] #[instrument(level = "trace", target = "envelope", skip_all)]
pub fn decrypt_body( pub fn decrypt_body(
&self, &self,
crypto: Crypto, crypto: &Crypto,
data: &[u8], data: &[u8],
node_id_secret: &SecretKey, node_id_secret: &SecretKey,
network_key: &Option<SharedSecret>, network_key: &Option<SharedSecret>,
@ -226,7 +226,7 @@ impl Envelope {
#[instrument(level = "trace", target = "envelope", skip_all, err)] #[instrument(level = "trace", target = "envelope", skip_all, err)]
pub fn to_encrypted_data( pub fn to_encrypted_data(
&self, &self,
crypto: Crypto, crypto: &Crypto,
body: &[u8], body: &[u8],
node_id_secret: &SecretKey, node_id_secret: &SecretKey,
network_key: &Option<SharedSecret>, network_key: &Option<SharedSecret>,

View File

@ -0,0 +1,276 @@
use super::*;
/// Guard to access a particular cryptosystem
pub struct CryptoSystemGuard<'a> {
crypto_system: Arc<dyn CryptoSystem + Send + Sync>,
_phantom: core::marker::PhantomData<&'a (dyn CryptoSystem + Send + Sync)>,
}
impl<'a> CryptoSystemGuard<'a> {
pub(super) fn new(crypto_system: Arc<dyn CryptoSystem + Send + Sync>) -> Self {
Self {
crypto_system,
_phantom: PhantomData,
}
}
pub fn as_async(self) -> AsyncCryptoSystemGuard<'a> {
AsyncCryptoSystemGuard { guard: self }
}
}
impl<'a> core::ops::Deref for CryptoSystemGuard<'a> {
type Target = dyn CryptoSystem + Send + Sync;
fn deref(&self) -> &Self::Target {
self.crypto_system.as_ref()
}
}
/// Async cryptosystem guard to help break up heavy blocking operations
pub struct AsyncCryptoSystemGuard<'a> {
guard: CryptoSystemGuard<'a>,
}
async fn yielding<R, T: FnOnce() -> R>(x: T) -> R {
let out = x();
sleep(0).await;
out
}
impl<'a> AsyncCryptoSystemGuard<'a> {
// Accessors
pub fn kind(&self) -> CryptoKind {
self.guard.kind()
}
pub fn crypto(&self) -> VeilidComponentGuard<'_, Crypto> {
self.guard.crypto()
}
// Cached Operations
pub async fn cached_dh(
&self,
key: &PublicKey,
secret: &SecretKey,
) -> VeilidAPIResult<SharedSecret> {
yielding(|| self.guard.cached_dh(key, secret)).await
}
// Generation
pub async fn random_bytes(&self, len: u32) -> Vec<u8> {
yielding(|| self.guard.random_bytes(len)).await
}
pub fn default_salt_length(&self) -> u32 {
self.guard.default_salt_length()
}
pub async fn hash_password(&self, password: &[u8], salt: &[u8]) -> VeilidAPIResult<String> {
yielding(|| self.guard.hash_password(password, salt)).await
}
pub async fn verify_password(
&self,
password: &[u8],
password_hash: &str,
) -> VeilidAPIResult<bool> {
yielding(|| self.guard.verify_password(password, password_hash)).await
}
pub async fn derive_shared_secret(
&self,
password: &[u8],
salt: &[u8],
) -> VeilidAPIResult<SharedSecret> {
yielding(|| self.guard.derive_shared_secret(password, salt)).await
}
pub async fn random_nonce(&self) -> Nonce {
yielding(|| self.guard.random_nonce()).await
}
pub async fn random_shared_secret(&self) -> SharedSecret {
yielding(|| self.guard.random_shared_secret()).await
}
pub async fn compute_dh(
&self,
key: &PublicKey,
secret: &SecretKey,
) -> VeilidAPIResult<SharedSecret> {
yielding(|| self.guard.compute_dh(key, secret)).await
}
pub async fn generate_shared_secret(
&self,
key: &PublicKey,
secret: &SecretKey,
domain: &[u8],
) -> VeilidAPIResult<SharedSecret> {
let dh = self.compute_dh(key, secret).await?;
Ok(self
.generate_hash(&[&dh.bytes, domain, VEILID_DOMAIN_API].concat())
.await)
}
pub async fn generate_keypair(&self) -> KeyPair {
yielding(|| self.guard.generate_keypair()).await
}
pub async fn generate_hash(&self, data: &[u8]) -> HashDigest {
yielding(|| self.guard.generate_hash(data)).await
}
pub async fn generate_hash_reader(
&self,
reader: &mut dyn std::io::Read,
) -> VeilidAPIResult<HashDigest> {
yielding(|| self.guard.generate_hash_reader(reader)).await
}
// Validation
pub async fn validate_keypair(&self, key: &PublicKey, secret: &SecretKey) -> bool {
yielding(|| self.guard.validate_keypair(key, secret)).await
}
pub async fn validate_hash(&self, data: &[u8], hash: &HashDigest) -> bool {
yielding(|| self.guard.validate_hash(data, hash)).await
}
pub async fn validate_hash_reader(
&self,
reader: &mut dyn std::io::Read,
hash: &HashDigest,
) -> VeilidAPIResult<bool> {
yielding(|| self.guard.validate_hash_reader(reader, hash)).await
}
// Distance Metric
pub async fn distance(&self, key1: &CryptoKey, key2: &CryptoKey) -> CryptoKeyDistance {
yielding(|| self.guard.distance(key1, key2)).await
}
// Authentication
pub async fn sign(
&self,
key: &PublicKey,
secret: &SecretKey,
data: &[u8],
) -> VeilidAPIResult<Signature> {
yielding(|| self.guard.sign(key, secret, data)).await
}
pub async fn verify(
&self,
key: &PublicKey,
data: &[u8],
signature: &Signature,
) -> VeilidAPIResult<bool> {
yielding(|| self.guard.verify(key, data, signature)).await
}
// AEAD Encrypt/Decrypt
pub fn aead_overhead(&self) -> usize {
self.guard.aead_overhead()
}
pub async fn decrypt_in_place_aead(
&self,
body: &mut Vec<u8>,
nonce: &Nonce,
shared_secret: &SharedSecret,
associated_data: Option<&[u8]>,
) -> VeilidAPIResult<()> {
yielding(|| {
self.guard
.decrypt_in_place_aead(body, nonce, shared_secret, associated_data)
})
.await
}
pub async fn decrypt_aead(
&self,
body: &[u8],
nonce: &Nonce,
shared_secret: &SharedSecret,
associated_data: Option<&[u8]>,
) -> VeilidAPIResult<Vec<u8>> {
yielding(|| {
self.guard
.decrypt_aead(body, nonce, shared_secret, associated_data)
})
.await
}
pub async fn encrypt_in_place_aead(
&self,
body: &mut Vec<u8>,
nonce: &Nonce,
shared_secret: &SharedSecret,
associated_data: Option<&[u8]>,
) -> VeilidAPIResult<()> {
yielding(|| {
self.guard
.encrypt_in_place_aead(body, nonce, shared_secret, associated_data)
})
.await
}
pub async fn encrypt_aead(
&self,
body: &[u8],
nonce: &Nonce,
shared_secret: &SharedSecret,
associated_data: Option<&[u8]>,
) -> VeilidAPIResult<Vec<u8>> {
yielding(|| {
self.guard
.encrypt_aead(body, nonce, shared_secret, associated_data)
})
.await
}
// NoAuth Encrypt/Decrypt
pub async fn crypt_in_place_no_auth(
&self,
body: &mut [u8],
nonce: &[u8; NONCE_LENGTH],
shared_secret: &SharedSecret,
) {
yielding(|| {
self.guard
.crypt_in_place_no_auth(body, nonce, shared_secret)
})
.await
}
pub async fn crypt_b2b_no_auth(
&self,
in_buf: &[u8],
out_buf: &mut [u8],
nonce: &[u8; NONCE_LENGTH],
shared_secret: &SharedSecret,
) {
yielding(|| {
self.guard
.crypt_b2b_no_auth(in_buf, out_buf, nonce, shared_secret)
})
.await
}
pub async fn crypt_no_auth_aligned_8(
&self,
body: &[u8],
nonce: &[u8; NONCE_LENGTH],
shared_secret: &SharedSecret,
) -> Vec<u8> {
yielding(|| {
self.guard
.crypt_no_auth_aligned_8(body, nonce, shared_secret)
})
.await
}
pub async fn crypt_no_auth_unaligned(
&self,
body: &[u8],
nonce: &[u8; NONCE_LENGTH],
shared_secret: &SharedSecret,
) -> Vec<u8> {
yielding(|| {
self.guard
.crypt_no_auth_unaligned(body, nonce, shared_secret)
})
.await
}
}

View File

@ -1,6 +1,7 @@
mod blake3digest512; mod blake3digest512;
mod dh_cache; mod dh_cache;
mod envelope; mod envelope;
mod guard;
mod receipt; mod receipt;
mod types; mod types;
@ -16,6 +17,7 @@ pub use blake3digest512::*;
pub use crypto_system::*; pub use crypto_system::*;
pub use envelope::*; pub use envelope::*;
pub use guard::*;
pub use receipt::*; pub use receipt::*;
pub use types::*; pub use types::*;
@ -29,9 +31,7 @@ use core::convert::TryInto;
use dh_cache::*; use dh_cache::*;
use hashlink::linked_hash_map::Entry; use hashlink::linked_hash_map::Entry;
use hashlink::LruCache; use hashlink::LruCache;
use std::marker::PhantomData;
/// Handle to a particular cryptosystem
pub type CryptoSystemVersion = Arc<dyn CryptoSystem + Send + Sync>;
cfg_if! { cfg_if! {
if #[cfg(all(feature = "enable-crypto-none", feature = "enable-crypto-vld0"))] { if #[cfg(all(feature = "enable-crypto-none", feature = "enable-crypto-vld0"))] {
@ -72,23 +72,40 @@ pub fn best_envelope_version() -> EnvelopeVersion {
struct CryptoInner { struct CryptoInner {
dh_cache: DHCache, dh_cache: DHCache,
flush_future: Option<SendPinBoxFuture<()>>, flush_future: Option<SendPinBoxFuture<()>>,
#[cfg(feature = "enable-crypto-vld0")]
crypto_vld0: Option<Arc<dyn CryptoSystem + Send + Sync>>,
#[cfg(feature = "enable-crypto-none")]
crypto_none: Option<Arc<dyn CryptoSystem + Send + Sync>>,
} }
struct CryptoUnlockedInner { impl fmt::Debug for CryptoInner {
_event_bus: EventBus, fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
config: VeilidConfig, f.debug_struct("CryptoInner")
table_store: TableStore, //.field("dh_cache", &self.dh_cache)
// .field("flush_future", &self.flush_future)
// .field("crypto_vld0", &self.crypto_vld0)
// .field("crypto_none", &self.crypto_none)
.finish()
}
} }
/// Crypto factory implementation /// Crypto factory implementation
#[derive(Clone)]
pub struct Crypto { pub struct Crypto {
unlocked_inner: Arc<CryptoUnlockedInner>, registry: VeilidComponentRegistry,
inner: Arc<Mutex<CryptoInner>>, inner: Mutex<CryptoInner>,
#[cfg(feature = "enable-crypto-vld0")]
crypto_vld0: Arc<dyn CryptoSystem + Send + Sync>,
#[cfg(feature = "enable-crypto-none")]
crypto_none: Arc<dyn CryptoSystem + Send + Sync>,
}
impl_veilid_component!(Crypto);
impl fmt::Debug for Crypto {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Crypto")
//.field("registry", &self.registry)
.field("inner", &self.inner)
// .field("crypto_vld0", &self.crypto_vld0)
// .field("crypto_none", &self.crypto_none)
.finish()
}
} }
impl Crypto { impl Crypto {
@ -96,63 +113,43 @@ impl Crypto {
CryptoInner { CryptoInner {
dh_cache: DHCache::new(DH_CACHE_SIZE), dh_cache: DHCache::new(DH_CACHE_SIZE),
flush_future: None, flush_future: None,
}
}
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self {
registry: registry.clone(),
inner: Mutex::new(Self::new_inner()),
#[cfg(feature = "enable-crypto-vld0")] #[cfg(feature = "enable-crypto-vld0")]
crypto_vld0: None, crypto_vld0: Arc::new(vld0::CryptoSystemVLD0::new(registry.clone())),
#[cfg(feature = "enable-crypto-none")] #[cfg(feature = "enable-crypto-none")]
crypto_none: None, crypto_none: Arc::new(none::CryptoSystemNONE::new(registry.clone())),
} }
} }
pub fn new(event_bus: EventBus, config: VeilidConfig, table_store: TableStore) -> Self {
let out = Self {
unlocked_inner: Arc::new(CryptoUnlockedInner {
_event_bus: event_bus,
config,
table_store,
}),
inner: Arc::new(Mutex::new(Self::new_inner())),
};
#[cfg(feature = "enable-crypto-vld0")]
{
out.inner.lock().crypto_vld0 = Some(Arc::new(vld0::CryptoSystemVLD0::new(out.clone())));
}
#[cfg(feature = "enable-crypto-none")]
{
out.inner.lock().crypto_none = Some(Arc::new(none::CryptoSystemNONE::new(out.clone())));
}
out
}
pub fn config(&self) -> VeilidConfig {
self.unlocked_inner.config.clone()
}
#[instrument(level = "trace", target = "crypto", skip_all, err)] #[instrument(level = "trace", target = "crypto", skip_all, err)]
pub async fn init(&self) -> EyreResult<()> { async fn init_async(&self) -> EyreResult<()> {
let table_store = self.unlocked_inner.table_store.clone(); // Nothing to initialize at this time
Ok(())
}
// Setup called by table store after it get initialized
#[instrument(level = "trace", target = "crypto", skip_all, err)]
pub(crate) async fn table_store_setup(&self, table_store: &TableStore) -> EyreResult<()> {
// Init node id from config // Init node id from config
if let Err(e) = self if let Err(e) = self.setup_node_ids(table_store).await {
.unlocked_inner
.config
.init_node_ids(self.clone(), table_store.clone())
.await
{
return Err(e).wrap_err("init node id failed"); return Err(e).wrap_err("init node id failed");
} }
// make local copy of node id for easy access // make local copy of node id for easy access
let mut cache_validity_key: Vec<u8> = Vec::new(); let mut cache_validity_key: Vec<u8> = Vec::new();
{ self.config().with(|c| {
let c = self.unlocked_inner.config.get();
for ck in VALID_CRYPTO_KINDS { for ck in VALID_CRYPTO_KINDS {
if let Some(nid) = c.network.routing_table.node_id.get(ck) { if let Some(nid) = c.network.routing_table.node_id.get(ck) {
cache_validity_key.append(&mut nid.value.bytes.to_vec()); cache_validity_key.append(&mut nid.value.bytes.to_vec());
} }
} }
}; });
// load caches if they are valid for this node id // load caches if they are valid for this node id
let mut db = table_store let mut db = table_store
@ -175,13 +172,17 @@ impl Crypto {
db.store(0, b"cache_validity_key", &cache_validity_key) db.store(0, b"cache_validity_key", &cache_validity_key)
.await?; .await?;
} }
Ok(())
}
#[instrument(level = "trace", target = "crypto", skip_all, err)]
async fn post_init_async(&self) -> EyreResult<()> {
// Schedule flushing // Schedule flushing
let this = self.clone(); let registry = self.registry();
let flush_future = interval("crypto flush", 60000, move || { let flush_future = interval("crypto flush", 60000, move || {
let this = this.clone(); let crypto = registry.crypto();
async move { async move {
if let Err(e) = this.flush().await { if let Err(e) = crypto.flush().await {
warn!("flush failed: {}", e); warn!("flush failed: {}", e);
} }
} }
@ -197,16 +198,12 @@ impl Crypto {
cache_to_bytes(&inner.dh_cache) cache_to_bytes(&inner.dh_cache)
}; };
let db = self let db = self.table_store().open("crypto_caches", 1).await?;
.unlocked_inner
.table_store
.open("crypto_caches", 1)
.await?;
db.store(0, b"dh_cache", &cache_bytes).await?; db.store(0, b"dh_cache", &cache_bytes).await?;
Ok(()) Ok(())
} }
pub async fn terminate(&self) { async fn pre_terminate_async(&self) {
let flush_future = self.inner.lock().flush_future.take(); let flush_future = self.inner.lock().flush_future.take();
if let Some(f) = flush_future { if let Some(f) = flush_future {
f.await; f.await;
@ -222,23 +219,36 @@ impl Crypto {
}; };
} }
async fn terminate_async(&self) {
// Nothing to terminate at this time
}
/// Factory method to get a specific crypto version /// Factory method to get a specific crypto version
pub fn get(&self, kind: CryptoKind) -> Option<CryptoSystemVersion> { pub fn get(&self, kind: CryptoKind) -> Option<CryptoSystemGuard<'_>> {
let inner = self.inner.lock();
match kind { match kind {
#[cfg(feature = "enable-crypto-vld0")] #[cfg(feature = "enable-crypto-vld0")]
CRYPTO_KIND_VLD0 => Some(inner.crypto_vld0.clone().unwrap()), CRYPTO_KIND_VLD0 => Some(CryptoSystemGuard::new(self.crypto_vld0.clone())),
#[cfg(feature = "enable-crypto-none")] #[cfg(feature = "enable-crypto-none")]
CRYPTO_KIND_NONE => Some(inner.crypto_none.clone().unwrap()), CRYPTO_KIND_NONE => Some(CryptoSystemGuard::new(self.crypto_none.clone())),
_ => None, _ => None,
} }
} }
/// Factory method to get a specific crypto version for async use
pub fn get_async(&self, kind: CryptoKind) -> Option<AsyncCryptoSystemGuard<'_>> {
self.get(kind).map(|x| x.as_async())
}
// Factory method to get the best crypto version // Factory method to get the best crypto version
pub fn best(&self) -> CryptoSystemVersion { pub fn best(&self) -> CryptoSystemGuard<'_> {
self.get(best_crypto_kind()).unwrap() self.get(best_crypto_kind()).unwrap()
} }
// Factory method to get the best crypto version for async use
pub fn best_async(&self) -> AsyncCryptoSystemGuard<'_> {
self.get_async(best_crypto_kind()).unwrap()
}
/// Signature set verification /// Signature set verification
/// Returns Some() the set of signature cryptokinds that validate and are supported /// Returns Some() the set of signature cryptokinds that validate and are supported
/// Returns None if any cryptokinds are supported and do not validate /// Returns None if any cryptokinds are supported and do not validate
@ -331,4 +341,120 @@ impl Crypto {
} }
Ok(()) Ok(())
} }
#[cfg(not(test))]
async fn setup_node_id(
&self,
vcrypto: AsyncCryptoSystemGuard<'_>,
table_store: &TableStore,
) -> VeilidAPIResult<(TypedKey, TypedSecret)> {
let config = self.config();
let ck = vcrypto.kind();
let (mut node_id, mut node_id_secret) = config.with(|c| {
(
c.network.routing_table.node_id.get(ck),
c.network.routing_table.node_id_secret.get(ck),
)
});
// See if node id was previously stored in the table store
let config_table = table_store.open("__veilid_config", 1).await?;
let table_key_node_id = format!("node_id_{}", ck);
let table_key_node_id_secret = format!("node_id_secret_{}", ck);
if node_id.is_none() {
log_crypto!(debug "pulling {} from storage", table_key_node_id);
if let Ok(Some(stored_node_id)) = config_table
.load_json::<TypedKey>(0, table_key_node_id.as_bytes())
.await
{
log_crypto!(debug "{} found in storage", table_key_node_id);
node_id = Some(stored_node_id);
} else {
log_crypto!(debug "{} not found in storage", table_key_node_id);
}
}
// See if node id secret was previously stored in the protected store
if node_id_secret.is_none() {
log_crypto!(debug "pulling {} from storage", table_key_node_id_secret);
if let Ok(Some(stored_node_id_secret)) = config_table
.load_json::<TypedSecret>(0, table_key_node_id_secret.as_bytes())
.await
{
log_crypto!(debug "{} found in storage", table_key_node_id_secret);
node_id_secret = Some(stored_node_id_secret);
} else {
log_crypto!(debug "{} not found in storage", table_key_node_id_secret);
}
}
// If we have a node id from storage, check it
let (node_id, node_id_secret) =
if let (Some(node_id), Some(node_id_secret)) = (node_id, node_id_secret) {
// Validate node id
if !vcrypto
.validate_keypair(&node_id.value, &node_id_secret.value)
.await
{
apibail_generic!(format!(
"node_id_secret_{} and node_id_key_{} don't match",
ck, ck
));
}
(node_id, node_id_secret)
} else {
// If we still don't have a valid node id, generate one
log_crypto!(debug "generating new node_id_{}", ck);
let kp = vcrypto.generate_keypair().await;
(TypedKey::new(ck, kp.key), TypedSecret::new(ck, kp.secret))
};
info!("Node Id: {}", node_id);
// Save the node id / secret in storage
config_table
.store_json(0, table_key_node_id.as_bytes(), &node_id)
.await?;
config_table
.store_json(0, table_key_node_id_secret.as_bytes(), &node_id_secret)
.await?;
Ok((node_id, node_id_secret))
}
/// Get the node id from config if one is specified.
/// Must be done -after- protected store is initialized, during table store init
#[cfg_attr(test, allow(unused_variables))]
async fn setup_node_ids(&self, table_store: &TableStore) -> VeilidAPIResult<()> {
let mut out_node_id = TypedKeyGroup::new();
let mut out_node_id_secret = TypedSecretGroup::new();
for ck in VALID_CRYPTO_KINDS {
let vcrypto = self
.get_async(ck)
.expect("Valid crypto kind is not actually valid.");
#[cfg(test)]
let (node_id, node_id_secret) = {
let kp = vcrypto.generate_keypair().await;
(TypedKey::new(ck, kp.key), TypedSecret::new(ck, kp.secret))
};
#[cfg(not(test))]
let (node_id, node_id_secret) = self.setup_node_id(vcrypto, table_store).await?;
// Save for config
out_node_id.add(node_id);
out_node_id_secret.add(node_id_secret);
}
// Commit back to config
self.config().try_with_mut(|c| {
c.network.routing_table.node_id = out_node_id;
c.network.routing_table.node_id_secret = out_node_id_secret;
Ok(())
})?;
Ok(())
}
} }

View File

@ -49,14 +49,13 @@ fn is_bytes_eq_32(a: &[u8], v: u8) -> bool {
} }
/// None CryptoSystem /// None CryptoSystem
#[derive(Clone)]
pub struct CryptoSystemNONE { pub struct CryptoSystemNONE {
crypto: Crypto, registry: VeilidComponentRegistry,
} }
impl CryptoSystemNONE { impl CryptoSystemNONE {
pub fn new(crypto: Crypto) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { crypto } Self { registry }
} }
} }
@ -66,13 +65,13 @@ impl CryptoSystem for CryptoSystemNONE {
CRYPTO_KIND_NONE CRYPTO_KIND_NONE
} }
fn crypto(&self) -> Crypto { fn crypto(&self) -> VeilidComponentGuard<'_, Crypto> {
self.crypto.clone() self.registry().lookup::<Crypto>().unwrap()
} }
// Cached Operations // Cached Operations
fn cached_dh(&self, key: &PublicKey, secret: &SecretKey) -> VeilidAPIResult<SharedSecret> { fn cached_dh(&self, key: &PublicKey, secret: &SecretKey) -> VeilidAPIResult<SharedSecret> {
self.crypto self.crypto()
.cached_dh_internal::<CryptoSystemNONE>(self, key, secret) .cached_dh_internal::<CryptoSystemNONE>(self, key, secret)
} }

View File

@ -68,7 +68,7 @@ impl Receipt {
} }
#[instrument(level = "trace", target = "receipt", skip_all, err)] #[instrument(level = "trace", target = "receipt", skip_all, err)]
pub fn from_signed_data(crypto: Crypto, data: &[u8]) -> VeilidAPIResult<Receipt> { pub fn from_signed_data(crypto: &Crypto, data: &[u8]) -> VeilidAPIResult<Receipt> {
// Ensure we are at least the length of the envelope // Ensure we are at least the length of the envelope
if data.len() < MIN_RECEIPT_SIZE { if data.len() < MIN_RECEIPT_SIZE {
apibail_parse_error!("receipt too small", data.len()); apibail_parse_error!("receipt too small", data.len());
@ -157,7 +157,7 @@ impl Receipt {
} }
#[instrument(level = "trace", target = "receipt", skip_all, err)] #[instrument(level = "trace", target = "receipt", skip_all, err)]
pub fn to_signed_data(&self, crypto: Crypto, secret: &SecretKey) -> VeilidAPIResult<Vec<u8>> { pub fn to_signed_data(&self, crypto: &Crypto, secret: &SecretKey) -> VeilidAPIResult<Vec<u8>> {
// Ensure extra data isn't too long // Ensure extra data isn't too long
let receipt_size: usize = self.extra_data.len() + MIN_RECEIPT_SIZE; let receipt_size: usize = self.extra_data.len() + MIN_RECEIPT_SIZE;
if receipt_size > MAX_RECEIPT_SIZE { if receipt_size > MAX_RECEIPT_SIZE {

View File

@ -2,20 +2,20 @@ use super::*;
static LOREM_IPSUM:&[u8] = b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. "; static LOREM_IPSUM:&[u8] = b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. ";
pub async fn test_aead(vcrypto: CryptoSystemVersion) { pub async fn test_aead(vcrypto: &AsyncCryptoSystemGuard<'_>) {
trace!("test_aead"); trace!("test_aead");
let n1 = vcrypto.random_nonce(); let n1 = vcrypto.random_nonce().await;
let n2 = loop { let n2 = loop {
let n = vcrypto.random_nonce(); let n = vcrypto.random_nonce().await;
if n != n1 { if n != n1 {
break n; break n;
} }
}; };
let ss1 = vcrypto.random_shared_secret(); let ss1 = vcrypto.random_shared_secret().await;
let ss2 = loop { let ss2 = loop {
let ss = vcrypto.random_shared_secret(); let ss = vcrypto.random_shared_secret().await;
if ss != ss1 { if ss != ss1 {
break ss; break ss;
} }
@ -27,6 +27,7 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!( assert!(
vcrypto vcrypto
.encrypt_in_place_aead(&mut body, &n1, &ss1, None) .encrypt_in_place_aead(&mut body, &n1, &ss1, None)
.await
.is_ok(), .is_ok(),
"encrypt should succeed" "encrypt should succeed"
); );
@ -41,6 +42,7 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!( assert!(
vcrypto vcrypto
.decrypt_in_place_aead(&mut body, &n1, &ss1, None) .decrypt_in_place_aead(&mut body, &n1, &ss1, None)
.await
.is_ok(), .is_ok(),
"decrypt should succeed" "decrypt should succeed"
); );
@ -49,6 +51,7 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!( assert!(
vcrypto vcrypto
.decrypt_in_place_aead(&mut body3, &n2, &ss1, None) .decrypt_in_place_aead(&mut body3, &n2, &ss1, None)
.await
.is_err(), .is_err(),
"decrypt with wrong nonce should fail" "decrypt with wrong nonce should fail"
); );
@ -57,6 +60,7 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!( assert!(
vcrypto vcrypto
.decrypt_in_place_aead(&mut body4, &n1, &ss2, None) .decrypt_in_place_aead(&mut body4, &n1, &ss2, None)
.await
.is_err(), .is_err(),
"decrypt with wrong secret should fail" "decrypt with wrong secret should fail"
); );
@ -65,37 +69,47 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!( assert!(
vcrypto vcrypto
.decrypt_in_place_aead(&mut body5, &n1, &ss2, Some(b"foobar")) .decrypt_in_place_aead(&mut body5, &n1, &ss2, Some(b"foobar"))
.await
.is_err(), .is_err(),
"decrypt with wrong associated data should fail" "decrypt with wrong associated data should fail"
); );
assert_ne!(body5, body, "failure changes data"); assert_ne!(body5, body, "failure changes data");
assert!( assert!(
vcrypto.decrypt_aead(LOREM_IPSUM, &n1, &ss1, None).is_err(), vcrypto
.decrypt_aead(LOREM_IPSUM, &n1, &ss1, None)
.await
.is_err(),
"should fail authentication" "should fail authentication"
); );
let body5 = vcrypto.encrypt_aead(LOREM_IPSUM, &n1, &ss1, None).unwrap(); let body5 = vcrypto
let body6 = vcrypto.decrypt_aead(&body5, &n1, &ss1, None).unwrap(); .encrypt_aead(LOREM_IPSUM, &n1, &ss1, None)
let body7 = vcrypto.encrypt_aead(LOREM_IPSUM, &n1, &ss1, None).unwrap(); .await
.unwrap();
let body6 = vcrypto.decrypt_aead(&body5, &n1, &ss1, None).await.unwrap();
let body7 = vcrypto
.encrypt_aead(LOREM_IPSUM, &n1, &ss1, None)
.await
.unwrap();
assert_eq!(body6, LOREM_IPSUM); assert_eq!(body6, LOREM_IPSUM);
assert_eq!(body5, body7); assert_eq!(body5, body7);
} }
pub async fn test_no_auth(vcrypto: CryptoSystemVersion) { pub async fn test_no_auth(vcrypto: &AsyncCryptoSystemGuard<'_>) {
trace!("test_no_auth"); trace!("test_no_auth");
let n1 = vcrypto.random_nonce(); let n1 = vcrypto.random_nonce().await;
let n2 = loop { let n2 = loop {
let n = vcrypto.random_nonce(); let n = vcrypto.random_nonce().await;
if n != n1 { if n != n1 {
break n; break n;
} }
}; };
let ss1 = vcrypto.random_shared_secret(); let ss1 = vcrypto.random_shared_secret().await;
let ss2 = loop { let ss2 = loop {
let ss = vcrypto.random_shared_secret(); let ss = vcrypto.random_shared_secret().await;
if ss != ss1 { if ss != ss1 {
break ss; break ss;
} }
@ -104,7 +118,7 @@ pub async fn test_no_auth(vcrypto: CryptoSystemVersion) {
let mut body = LOREM_IPSUM.to_vec(); let mut body = LOREM_IPSUM.to_vec();
let body2 = body.clone(); let body2 = body.clone();
let size_before_encrypt = body.len(); let size_before_encrypt = body.len();
vcrypto.crypt_in_place_no_auth(&mut body, &n1, &ss1); vcrypto.crypt_in_place_no_auth(&mut body, &n1, &ss1).await;
let size_after_encrypt = body.len(); let size_after_encrypt = body.len();
assert_eq!( assert_eq!(
@ -114,49 +128,69 @@ pub async fn test_no_auth(vcrypto: CryptoSystemVersion) {
let mut body3 = body.clone(); let mut body3 = body.clone();
let mut body4 = body.clone(); let mut body4 = body.clone();
vcrypto.crypt_in_place_no_auth(&mut body, &n1, &ss1); vcrypto.crypt_in_place_no_auth(&mut body, &n1, &ss1).await;
assert_eq!(body, body2, "result after decrypt should be the same"); assert_eq!(body, body2, "result after decrypt should be the same");
vcrypto.crypt_in_place_no_auth(&mut body3, &n2, &ss1); vcrypto.crypt_in_place_no_auth(&mut body3, &n2, &ss1).await;
assert_ne!(body3, body, "decrypt should not be equal with wrong nonce"); assert_ne!(body3, body, "decrypt should not be equal with wrong nonce");
vcrypto.crypt_in_place_no_auth(&mut body4, &n1, &ss2); vcrypto.crypt_in_place_no_auth(&mut body4, &n1, &ss2).await;
assert_ne!(body4, body, "decrypt should not be equal with wrong secret"); assert_ne!(body4, body, "decrypt should not be equal with wrong secret");
let body5 = vcrypto.crypt_no_auth_unaligned(LOREM_IPSUM, &n1, &ss1); let body5 = vcrypto
let body6 = vcrypto.crypt_no_auth_unaligned(&body5, &n1, &ss1); .crypt_no_auth_unaligned(LOREM_IPSUM, &n1, &ss1)
let body7 = vcrypto.crypt_no_auth_unaligned(LOREM_IPSUM, &n1, &ss1); .await;
let body6 = vcrypto.crypt_no_auth_unaligned(&body5, &n1, &ss1).await;
let body7 = vcrypto
.crypt_no_auth_unaligned(LOREM_IPSUM, &n1, &ss1)
.await;
assert_eq!(body6, LOREM_IPSUM); assert_eq!(body6, LOREM_IPSUM);
assert_eq!(body5, body7); assert_eq!(body5, body7);
let body5 = vcrypto.crypt_no_auth_aligned_8(LOREM_IPSUM, &n1, &ss1); let body5 = vcrypto
let body6 = vcrypto.crypt_no_auth_aligned_8(&body5, &n1, &ss1); .crypt_no_auth_aligned_8(LOREM_IPSUM, &n1, &ss1)
let body7 = vcrypto.crypt_no_auth_aligned_8(LOREM_IPSUM, &n1, &ss1); .await;
let body6 = vcrypto.crypt_no_auth_aligned_8(&body5, &n1, &ss1).await;
let body7 = vcrypto
.crypt_no_auth_aligned_8(LOREM_IPSUM, &n1, &ss1)
.await;
assert_eq!(body6, LOREM_IPSUM); assert_eq!(body6, LOREM_IPSUM);
assert_eq!(body5, body7); assert_eq!(body5, body7);
} }
pub async fn test_dh(vcrypto: CryptoSystemVersion) { pub async fn test_dh(vcrypto: &AsyncCryptoSystemGuard<'_>) {
trace!("test_dh"); trace!("test_dh");
let (dht_key, dht_key_secret) = vcrypto.generate_keypair().into_split(); let (dht_key, dht_key_secret) = vcrypto.generate_keypair().await.into_split();
assert!(vcrypto.validate_keypair(&dht_key, &dht_key_secret)); assert!(vcrypto.validate_keypair(&dht_key, &dht_key_secret).await);
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split(); let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
assert!(vcrypto.validate_keypair(&dht_key2, &dht_key_secret2)); assert!(vcrypto.validate_keypair(&dht_key2, &dht_key_secret2).await);
let r1 = vcrypto.compute_dh(&dht_key, &dht_key_secret2).unwrap(); let r1 = vcrypto
let r2 = vcrypto.compute_dh(&dht_key2, &dht_key_secret).unwrap(); .compute_dh(&dht_key, &dht_key_secret2)
let r3 = vcrypto.compute_dh(&dht_key, &dht_key_secret2).unwrap(); .await
let r4 = vcrypto.compute_dh(&dht_key2, &dht_key_secret).unwrap(); .unwrap();
let r2 = vcrypto
.compute_dh(&dht_key2, &dht_key_secret)
.await
.unwrap();
let r3 = vcrypto
.compute_dh(&dht_key, &dht_key_secret2)
.await
.unwrap();
let r4 = vcrypto
.compute_dh(&dht_key2, &dht_key_secret)
.await
.unwrap();
assert_eq!(r1, r2); assert_eq!(r1, r2);
assert_eq!(r3, r4); assert_eq!(r3, r4);
assert_eq!(r2, r3); assert_eq!(r2, r3);
trace!("dh: {:?}", r1); trace!("dh: {:?}", r1);
// test cache // test cache
let r5 = vcrypto.cached_dh(&dht_key, &dht_key_secret2).unwrap(); let r5 = vcrypto.cached_dh(&dht_key, &dht_key_secret2).await.unwrap();
let r6 = vcrypto.cached_dh(&dht_key2, &dht_key_secret).unwrap(); let r6 = vcrypto.cached_dh(&dht_key2, &dht_key_secret).await.unwrap();
let r7 = vcrypto.cached_dh(&dht_key, &dht_key_secret2).unwrap(); let r7 = vcrypto.cached_dh(&dht_key, &dht_key_secret2).await.unwrap();
let r8 = vcrypto.cached_dh(&dht_key2, &dht_key_secret).unwrap(); let r8 = vcrypto.cached_dh(&dht_key2, &dht_key_secret).await.unwrap();
assert_eq!(r1, r5); assert_eq!(r1, r5);
assert_eq!(r2, r6); assert_eq!(r2, r6);
assert_eq!(r3, r7); assert_eq!(r3, r7);
@ -164,63 +198,67 @@ pub async fn test_dh(vcrypto: CryptoSystemVersion) {
trace!("cached_dh: {:?}", r5); trace!("cached_dh: {:?}", r5);
} }
pub async fn test_generation(vcrypto: CryptoSystemVersion) { pub async fn test_generation(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let b1 = vcrypto.random_bytes(32); let b1 = vcrypto.random_bytes(32).await;
let b2 = vcrypto.random_bytes(32); let b2 = vcrypto.random_bytes(32).await;
assert_ne!(b1, b2); assert_ne!(b1, b2);
assert_eq!(b1.len(), 32); assert_eq!(b1.len(), 32);
assert_eq!(b2.len(), 32); assert_eq!(b2.len(), 32);
let b3 = vcrypto.random_bytes(0); let b3 = vcrypto.random_bytes(0).await;
let b4 = vcrypto.random_bytes(0); let b4 = vcrypto.random_bytes(0).await;
assert_eq!(b3, b4); assert_eq!(b3, b4);
assert_eq!(b3.len(), 0); assert_eq!(b3.len(), 0);
assert_ne!(vcrypto.default_salt_length(), 0); assert_ne!(vcrypto.default_salt_length(), 0);
let pstr1 = vcrypto.hash_password(b"abc123", b"qwerasdf").unwrap(); let pstr1 = vcrypto.hash_password(b"abc123", b"qwerasdf").await.unwrap();
let pstr2 = vcrypto.hash_password(b"abc123", b"qwerasdf").unwrap(); let pstr2 = vcrypto.hash_password(b"abc123", b"qwerasdf").await.unwrap();
assert_eq!(pstr1, pstr2); assert_eq!(pstr1, pstr2);
let pstr3 = vcrypto.hash_password(b"abc123", b"qwerasdg").unwrap(); let pstr3 = vcrypto.hash_password(b"abc123", b"qwerasdg").await.unwrap();
assert_ne!(pstr1, pstr3); assert_ne!(pstr1, pstr3);
let pstr4 = vcrypto.hash_password(b"abc124", b"qwerasdf").unwrap(); let pstr4 = vcrypto.hash_password(b"abc124", b"qwerasdf").await.unwrap();
assert_ne!(pstr1, pstr4); assert_ne!(pstr1, pstr4);
let pstr5 = vcrypto.hash_password(b"abc124", b"qwerasdg").unwrap(); let pstr5 = vcrypto.hash_password(b"abc124", b"qwerasdg").await.unwrap();
assert_ne!(pstr3, pstr5); assert_ne!(pstr3, pstr5);
vcrypto vcrypto
.hash_password(b"abc123", b"qwe") .hash_password(b"abc123", b"qwe")
.await
.expect_err("should reject short salt"); .expect_err("should reject short salt");
vcrypto vcrypto
.hash_password( .hash_password(
b"abc123", b"abc123",
b"qwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerz", b"qwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerz",
) )
.await
.expect_err("should reject long salt"); .expect_err("should reject long salt");
assert!(vcrypto.verify_password(b"abc123", &pstr1).unwrap()); assert!(vcrypto.verify_password(b"abc123", &pstr1).await.unwrap());
assert!(vcrypto.verify_password(b"abc123", &pstr2).unwrap()); assert!(vcrypto.verify_password(b"abc123", &pstr2).await.unwrap());
assert!(vcrypto.verify_password(b"abc123", &pstr3).unwrap()); assert!(vcrypto.verify_password(b"abc123", &pstr3).await.unwrap());
assert!(!vcrypto.verify_password(b"abc123", &pstr4).unwrap()); assert!(!vcrypto.verify_password(b"abc123", &pstr4).await.unwrap());
assert!(!vcrypto.verify_password(b"abc123", &pstr5).unwrap()); assert!(!vcrypto.verify_password(b"abc123", &pstr5).await.unwrap());
let ss1 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdf"); let ss1 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdf").await;
let ss2 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdf"); let ss2 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdf").await;
assert_eq!(ss1, ss2); assert_eq!(ss1, ss2);
let ss3 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdg"); let ss3 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdg").await;
assert_ne!(ss1, ss3); assert_ne!(ss1, ss3);
let ss4 = vcrypto.derive_shared_secret(b"abc124", b"qwerasdf"); let ss4 = vcrypto.derive_shared_secret(b"abc124", b"qwerasdf").await;
assert_ne!(ss1, ss4); assert_ne!(ss1, ss4);
let ss5 = vcrypto.derive_shared_secret(b"abc124", b"qwerasdg"); let ss5 = vcrypto.derive_shared_secret(b"abc124", b"qwerasdg").await;
assert_ne!(ss3, ss5); assert_ne!(ss3, ss5);
vcrypto vcrypto
.derive_shared_secret(b"abc123", b"qwe") .derive_shared_secret(b"abc123", b"qwe")
.await
.expect_err("should reject short salt"); .expect_err("should reject short salt");
vcrypto vcrypto
.derive_shared_secret( .derive_shared_secret(
b"abc123", b"abc123",
b"qwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerz", b"qwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerz",
) )
.await
.expect_err("should reject long salt"); .expect_err("should reject long salt");
} }
@ -230,11 +268,11 @@ pub async fn test_all() {
// Test versions // Test versions
for v in VALID_CRYPTO_KINDS { for v in VALID_CRYPTO_KINDS {
let vcrypto = crypto.get(v).unwrap(); let vcrypto = crypto.get_async(v).unwrap();
test_aead(vcrypto.clone()).await; test_aead(&vcrypto).await;
test_no_auth(vcrypto.clone()).await; test_no_auth(&vcrypto).await;
test_dh(vcrypto.clone()).await; test_dh(&vcrypto).await;
test_generation(vcrypto).await; test_generation(&vcrypto).await;
} }
crypto_tests_shutdown(api.clone()).await; crypto_tests_shutdown(api.clone()).await;

View File

@ -2,9 +2,10 @@ use super::*;
pub async fn test_envelope_round_trip( pub async fn test_envelope_round_trip(
envelope_version: EnvelopeVersion, envelope_version: EnvelopeVersion,
vcrypto: CryptoSystemVersion, vcrypto: &AsyncCryptoSystemGuard<'_>,
network_key: Option<SharedSecret>, network_key: Option<SharedSecret>,
) { ) {
let crypto = vcrypto.crypto();
if network_key.is_some() { if network_key.is_some() {
info!( info!(
"--- test envelope round trip {} w/network key ---", "--- test envelope round trip {} w/network key ---",
@ -16,9 +17,9 @@ pub async fn test_envelope_round_trip(
// Create envelope // Create envelope
let ts = Timestamp::from(0x12345678ABCDEF69u64); let ts = Timestamp::from(0x12345678ABCDEF69u64);
let nonce = vcrypto.random_nonce(); let nonce = vcrypto.random_nonce().await;
let (sender_id, sender_secret) = vcrypto.generate_keypair().into_split(); let (sender_id, sender_secret) = vcrypto.generate_keypair().await.into_split();
let (recipient_id, recipient_secret) = vcrypto.generate_keypair().into_split(); let (recipient_id, recipient_secret) = vcrypto.generate_keypair().await.into_split();
let envelope = Envelope::new( let envelope = Envelope::new(
envelope_version, envelope_version,
vcrypto.kind(), vcrypto.kind(),
@ -33,15 +34,15 @@ pub async fn test_envelope_round_trip(
// Serialize to bytes // Serialize to bytes
let enc_data = envelope let enc_data = envelope
.to_encrypted_data(vcrypto.crypto(), body, &sender_secret, &network_key) .to_encrypted_data(&crypto, body, &sender_secret, &network_key)
.expect("failed to encrypt data"); .expect("failed to encrypt data");
// Deserialize from bytes // Deserialize from bytes
let envelope2 = Envelope::from_signed_data(vcrypto.crypto(), &enc_data, &network_key) let envelope2 = Envelope::from_signed_data(&crypto, &enc_data, &network_key)
.expect("failed to deserialize envelope from data"); .expect("failed to deserialize envelope from data");
let body2 = envelope2 let body2 = envelope2
.decrypt_body(vcrypto.crypto(), &enc_data, &recipient_secret, &network_key) .decrypt_body(&crypto, &enc_data, &recipient_secret, &network_key)
.expect("failed to decrypt envelope body"); .expect("failed to decrypt envelope body");
// Compare envelope and body // Compare envelope and body
@ -53,43 +54,44 @@ pub async fn test_envelope_round_trip(
let mut mod_enc_data = enc_data.clone(); let mut mod_enc_data = enc_data.clone();
mod_enc_data[enc_data_len - 1] ^= 0x80u8; mod_enc_data[enc_data_len - 1] ^= 0x80u8;
assert!( assert!(
Envelope::from_signed_data(vcrypto.crypto(), &mod_enc_data, &network_key).is_err(), Envelope::from_signed_data(&crypto, &mod_enc_data, &network_key).is_err(),
"should have failed to decode envelope with modified signature" "should have failed to decode envelope with modified signature"
); );
let mut mod_enc_data2 = enc_data.clone(); let mut mod_enc_data2 = enc_data.clone();
mod_enc_data2[enc_data_len - 65] ^= 0x80u8; mod_enc_data2[enc_data_len - 65] ^= 0x80u8;
assert!( assert!(
Envelope::from_signed_data(vcrypto.crypto(), &mod_enc_data2, &network_key).is_err(), Envelope::from_signed_data(&crypto, &mod_enc_data2, &network_key).is_err(),
"should have failed to decode envelope with modified data" "should have failed to decode envelope with modified data"
); );
} }
pub async fn test_receipt_round_trip( pub async fn test_receipt_round_trip(
envelope_version: EnvelopeVersion, envelope_version: EnvelopeVersion,
vcrypto: CryptoSystemVersion, vcrypto: &AsyncCryptoSystemGuard<'_>,
) { ) {
let crypto = vcrypto.crypto();
info!("--- test receipt round trip ---"); info!("--- test receipt round trip ---");
// Create arbitrary body // Create arbitrary body
let body = b"This is an arbitrary body"; let body = b"This is an arbitrary body";
// Create receipt // Create receipt
let nonce = vcrypto.random_nonce(); let nonce = vcrypto.random_nonce().await;
let (sender_id, sender_secret) = vcrypto.generate_keypair().into_split(); let (sender_id, sender_secret) = vcrypto.generate_keypair().await.into_split();
let receipt = Receipt::try_new(envelope_version, vcrypto.kind(), nonce, sender_id, body) let receipt = Receipt::try_new(envelope_version, vcrypto.kind(), nonce, sender_id, body)
.expect("should not fail"); .expect("should not fail");
// Serialize to bytes // Serialize to bytes
let mut enc_data = receipt let mut enc_data = receipt
.to_signed_data(vcrypto.crypto(), &sender_secret) .to_signed_data(&crypto, &sender_secret)
.expect("failed to make signed data"); .expect("failed to make signed data");
// Deserialize from bytes // Deserialize from bytes
let receipt2 = Receipt::from_signed_data(vcrypto.crypto(), &enc_data) let receipt2 = Receipt::from_signed_data(&crypto, &enc_data)
.expect("failed to deserialize envelope from data"); .expect("failed to deserialize envelope from data");
// Should not validate even when a single bit is changed // Should not validate even when a single bit is changed
enc_data[5] = 0x01; enc_data[5] = 0x01;
Receipt::from_signed_data(vcrypto.crypto(), &enc_data) Receipt::from_signed_data(&crypto, &enc_data)
.expect_err("should have failed to decrypt using wrong secret"); .expect_err("should have failed to decrypt using wrong secret");
// Compare receipts // Compare receipts
@ -103,12 +105,12 @@ pub async fn test_all() {
// Test versions // Test versions
for ev in VALID_ENVELOPE_VERSIONS { for ev in VALID_ENVELOPE_VERSIONS {
for v in VALID_CRYPTO_KINDS { for v in VALID_CRYPTO_KINDS {
let vcrypto = crypto.get(v).unwrap(); let vcrypto = crypto.get_async(v).unwrap();
test_envelope_round_trip(ev, vcrypto.clone(), None).await; test_envelope_round_trip(ev, &vcrypto, None).await;
test_envelope_round_trip(ev, vcrypto.clone(), Some(vcrypto.random_shared_secret())) test_envelope_round_trip(ev, &vcrypto, Some(vcrypto.random_shared_secret().await))
.await; .await;
test_receipt_round_trip(ev, vcrypto).await; test_receipt_round_trip(ev, &vcrypto).await;
} }
} }

View File

@ -6,10 +6,10 @@ static CHEEZBURGER: &str = "I can has cheezburger";
static EMPTY_KEY: [u8; PUBLIC_KEY_LENGTH] = [0u8; PUBLIC_KEY_LENGTH]; static EMPTY_KEY: [u8; PUBLIC_KEY_LENGTH] = [0u8; PUBLIC_KEY_LENGTH];
static EMPTY_KEY_SECRET: [u8; SECRET_KEY_LENGTH] = [0u8; SECRET_KEY_LENGTH]; static EMPTY_KEY_SECRET: [u8; SECRET_KEY_LENGTH] = [0u8; SECRET_KEY_LENGTH];
pub async fn test_generate_secret(vcrypto: CryptoSystemVersion) { pub async fn test_generate_secret(vcrypto: &AsyncCryptoSystemGuard<'_>) {
// Verify keys generate // Verify keys generate
let (dht_key, dht_key_secret) = vcrypto.generate_keypair().into_split(); let (dht_key, dht_key_secret) = vcrypto.generate_keypair().await.into_split();
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split(); let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
// Verify byte patterns are different between public and secret // Verify byte patterns are different between public and secret
assert_ne!(dht_key.bytes, dht_key_secret.bytes); assert_ne!(dht_key.bytes, dht_key_secret.bytes);
@ -20,21 +20,24 @@ pub async fn test_generate_secret(vcrypto: CryptoSystemVersion) {
assert_ne!(dht_key_secret, dht_key_secret2); assert_ne!(dht_key_secret, dht_key_secret2);
} }
pub async fn test_sign_and_verify(vcrypto: CryptoSystemVersion) { pub async fn test_sign_and_verify(vcrypto: &AsyncCryptoSystemGuard<'_>) {
// Make two keys // Make two keys
let (dht_key, dht_key_secret) = vcrypto.generate_keypair().into_split(); let (dht_key, dht_key_secret) = vcrypto.generate_keypair().await.into_split();
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split(); let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
// Sign the same message twice // Sign the same message twice
let dht_sig = vcrypto let dht_sig = vcrypto
.sign(&dht_key, &dht_key_secret, LOREM_IPSUM.as_bytes()) .sign(&dht_key, &dht_key_secret, LOREM_IPSUM.as_bytes())
.await
.unwrap(); .unwrap();
trace!("dht_sig: {:?}", dht_sig); trace!("dht_sig: {:?}", dht_sig);
let dht_sig_b = vcrypto let dht_sig_b = vcrypto
.sign(&dht_key, &dht_key_secret, LOREM_IPSUM.as_bytes()) .sign(&dht_key, &dht_key_secret, LOREM_IPSUM.as_bytes())
.await
.unwrap(); .unwrap();
// Sign a second message // Sign a second message
let dht_sig_c = vcrypto let dht_sig_c = vcrypto
.sign(&dht_key, &dht_key_secret, CHEEZBURGER.as_bytes()) .sign(&dht_key, &dht_key_secret, CHEEZBURGER.as_bytes())
.await
.unwrap(); .unwrap();
trace!("dht_sig_c: {:?}", dht_sig_c); trace!("dht_sig_c: {:?}", dht_sig_c);
// Verify they are the same signature // Verify they are the same signature
@ -42,6 +45,7 @@ pub async fn test_sign_and_verify(vcrypto: CryptoSystemVersion) {
// Sign the same message with a different key // Sign the same message with a different key
let dht_sig2 = vcrypto let dht_sig2 = vcrypto
.sign(&dht_key2, &dht_key_secret2, LOREM_IPSUM.as_bytes()) .sign(&dht_key2, &dht_key_secret2, LOREM_IPSUM.as_bytes())
.await
.unwrap(); .unwrap();
// Verify a different key gives a different signature // Verify a different key gives a different signature
assert_ne!(dht_sig2, dht_sig_b); assert_ne!(dht_sig2, dht_sig_b);
@ -49,73 +53,93 @@ pub async fn test_sign_and_verify(vcrypto: CryptoSystemVersion) {
// Try using the wrong secret to sign // Try using the wrong secret to sign
let a1 = vcrypto let a1 = vcrypto
.sign(&dht_key, &dht_key_secret, LOREM_IPSUM.as_bytes()) .sign(&dht_key, &dht_key_secret, LOREM_IPSUM.as_bytes())
.await
.unwrap(); .unwrap();
let a2 = vcrypto let a2 = vcrypto
.sign(&dht_key2, &dht_key_secret2, LOREM_IPSUM.as_bytes()) .sign(&dht_key2, &dht_key_secret2, LOREM_IPSUM.as_bytes())
.await
.unwrap(); .unwrap();
let _b1 = vcrypto let _b1 = vcrypto
.sign(&dht_key, &dht_key_secret2, LOREM_IPSUM.as_bytes()) .sign(&dht_key, &dht_key_secret2, LOREM_IPSUM.as_bytes())
.await
.unwrap_err(); .unwrap_err();
let _b2 = vcrypto let _b2 = vcrypto
.sign(&dht_key2, &dht_key_secret, LOREM_IPSUM.as_bytes()) .sign(&dht_key2, &dht_key_secret, LOREM_IPSUM.as_bytes())
.await
.unwrap_err(); .unwrap_err();
assert_ne!(a1, a2); assert_ne!(a1, a2);
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &a1), vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &a1).await,
Ok(true) Ok(true)
); );
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &a2), vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &a2).await,
Ok(true) Ok(true)
); );
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &a2), vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &a2).await,
Ok(false) Ok(false)
); );
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &a1), vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &a1).await,
Ok(false) Ok(false)
); );
// Try verifications that should work // Try verifications that should work
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig), vcrypto
.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig)
.await,
Ok(true) Ok(true)
); );
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig_b), vcrypto
.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig_b)
.await,
Ok(true) Ok(true)
); );
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &dht_sig2), vcrypto
.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &dht_sig2)
.await,
Ok(true) Ok(true)
); );
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key, CHEEZBURGER.as_bytes(), &dht_sig_c), vcrypto
.verify(&dht_key, CHEEZBURGER.as_bytes(), &dht_sig_c)
.await,
Ok(true) Ok(true)
); );
// Try verifications that shouldn't work // Try verifications that shouldn't work
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &dht_sig), vcrypto
.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &dht_sig)
.await,
Ok(false) Ok(false)
); );
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig2), vcrypto
.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig2)
.await,
Ok(false) Ok(false)
); );
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key2, CHEEZBURGER.as_bytes(), &dht_sig_c), vcrypto
.verify(&dht_key2, CHEEZBURGER.as_bytes(), &dht_sig_c)
.await,
Ok(false) Ok(false)
); );
assert_eq!( assert_eq!(
vcrypto.verify(&dht_key, CHEEZBURGER.as_bytes(), &dht_sig), vcrypto
.verify(&dht_key, CHEEZBURGER.as_bytes(), &dht_sig)
.await,
Ok(false) Ok(false)
); );
} }
pub async fn test_key_conversions(vcrypto: CryptoSystemVersion) { pub async fn test_key_conversions(vcrypto: &AsyncCryptoSystemGuard<'_>) {
// Test default key // Test default key
let (dht_key, dht_key_secret) = (PublicKey::default(), SecretKey::default()); let (dht_key, dht_key_secret) = (PublicKey::default(), SecretKey::default());
assert_eq!(dht_key.bytes, EMPTY_KEY); assert_eq!(dht_key.bytes, EMPTY_KEY);
@ -131,10 +155,10 @@ pub async fn test_key_conversions(vcrypto: CryptoSystemVersion) {
assert_eq!(dht_key_secret_string, dht_key_string); assert_eq!(dht_key_secret_string, dht_key_string);
// Make different keys // Make different keys
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split(); let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
trace!("dht_key2: {:?}", dht_key2); trace!("dht_key2: {:?}", dht_key2);
trace!("dht_key_secret2: {:?}", dht_key_secret2); trace!("dht_key_secret2: {:?}", dht_key_secret2);
let (dht_key3, _dht_key_secret3) = vcrypto.generate_keypair().into_split(); let (dht_key3, _dht_key_secret3) = vcrypto.generate_keypair().await.into_split();
trace!("dht_key3: {:?}", dht_key3); trace!("dht_key3: {:?}", dht_key3);
trace!("_dht_key_secret3: {:?}", _dht_key_secret3); trace!("_dht_key_secret3: {:?}", _dht_key_secret3);
@ -185,7 +209,7 @@ pub async fn test_key_conversions(vcrypto: CryptoSystemVersion) {
.is_err()); .is_err());
} }
pub async fn test_encode_decode(vcrypto: CryptoSystemVersion) { pub async fn test_encode_decode(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let dht_key = PublicKey::try_decode("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA").unwrap(); let dht_key = PublicKey::try_decode("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA").unwrap();
let dht_key_secret = let dht_key_secret =
SecretKey::try_decode("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA").unwrap(); SecretKey::try_decode("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA").unwrap();
@ -194,7 +218,7 @@ pub async fn test_encode_decode(vcrypto: CryptoSystemVersion) {
assert_eq!(dht_key, dht_key_b); assert_eq!(dht_key, dht_key_b);
assert_eq!(dht_key_secret, dht_key_secret_b); assert_eq!(dht_key_secret, dht_key_secret_b);
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split(); let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
let e1 = dht_key.encode(); let e1 = dht_key.encode();
trace!("e1: {:?}", e1); trace!("e1: {:?}", e1);
@ -229,7 +253,7 @@ pub async fn test_encode_decode(vcrypto: CryptoSystemVersion) {
assert!(f2.is_err()); assert!(f2.is_err());
} }
pub async fn test_typed_convert(vcrypto: CryptoSystemVersion) { pub async fn test_typed_convert(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let tks1 = format!( let tks1 = format!(
"{}:7lxDEabK_qgjbe38RtBa3IZLrud84P6NhGP-pRTZzdQ", "{}:7lxDEabK_qgjbe38RtBa3IZLrud84P6NhGP-pRTZzdQ",
vcrypto.kind() vcrypto.kind()
@ -261,15 +285,15 @@ pub async fn test_typed_convert(vcrypto: CryptoSystemVersion) {
assert!(tks6x.ends_with(&tks6)); assert!(tks6x.ends_with(&tks6));
} }
async fn test_hash(vcrypto: CryptoSystemVersion) { async fn test_hash(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let mut s = BTreeSet::<PublicKey>::new(); let mut s = BTreeSet::<PublicKey>::new();
let k1 = vcrypto.generate_hash("abc".as_bytes()); let k1 = vcrypto.generate_hash("abc".as_bytes()).await;
let k2 = vcrypto.generate_hash("abcd".as_bytes()); let k2 = vcrypto.generate_hash("abcd".as_bytes()).await;
let k3 = vcrypto.generate_hash("".as_bytes()); let k3 = vcrypto.generate_hash("".as_bytes()).await;
let k4 = vcrypto.generate_hash(" ".as_bytes()); let k4 = vcrypto.generate_hash(" ".as_bytes()).await;
let k5 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes()); let k5 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes()).await;
let k6 = vcrypto.generate_hash(CHEEZBURGER.as_bytes()); let k6 = vcrypto.generate_hash(CHEEZBURGER.as_bytes()).await;
s.insert(k1); s.insert(k1);
s.insert(k2); s.insert(k2);
@ -279,12 +303,12 @@ async fn test_hash(vcrypto: CryptoSystemVersion) {
s.insert(k6); s.insert(k6);
assert_eq!(s.len(), 6); assert_eq!(s.len(), 6);
let v1 = vcrypto.generate_hash("abc".as_bytes()); let v1 = vcrypto.generate_hash("abc".as_bytes()).await;
let v2 = vcrypto.generate_hash("abcd".as_bytes()); let v2 = vcrypto.generate_hash("abcd".as_bytes()).await;
let v3 = vcrypto.generate_hash("".as_bytes()); let v3 = vcrypto.generate_hash("".as_bytes()).await;
let v4 = vcrypto.generate_hash(" ".as_bytes()); let v4 = vcrypto.generate_hash(" ".as_bytes()).await;
let v5 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes()); let v5 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes()).await;
let v6 = vcrypto.generate_hash(CHEEZBURGER.as_bytes()); let v6 = vcrypto.generate_hash(CHEEZBURGER.as_bytes()).await;
assert_eq!(k1, v1); assert_eq!(k1, v1);
assert_eq!(k2, v2); assert_eq!(k2, v2);
@ -293,24 +317,24 @@ async fn test_hash(vcrypto: CryptoSystemVersion) {
assert_eq!(k5, v5); assert_eq!(k5, v5);
assert_eq!(k6, v6); assert_eq!(k6, v6);
vcrypto.validate_hash("abc".as_bytes(), &v1); vcrypto.validate_hash("abc".as_bytes(), &v1).await;
vcrypto.validate_hash("abcd".as_bytes(), &v2); vcrypto.validate_hash("abcd".as_bytes(), &v2).await;
vcrypto.validate_hash("".as_bytes(), &v3); vcrypto.validate_hash("".as_bytes(), &v3).await;
vcrypto.validate_hash(" ".as_bytes(), &v4); vcrypto.validate_hash(" ".as_bytes(), &v4).await;
vcrypto.validate_hash(LOREM_IPSUM.as_bytes(), &v5); vcrypto.validate_hash(LOREM_IPSUM.as_bytes(), &v5).await;
vcrypto.validate_hash(CHEEZBURGER.as_bytes(), &v6); vcrypto.validate_hash(CHEEZBURGER.as_bytes(), &v6).await;
} }
async fn test_operations(vcrypto: CryptoSystemVersion) { async fn test_operations(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let k1 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes()); let k1 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes()).await;
let k2 = vcrypto.generate_hash(CHEEZBURGER.as_bytes()); let k2 = vcrypto.generate_hash(CHEEZBURGER.as_bytes()).await;
let k3 = vcrypto.generate_hash("abc".as_bytes()); let k3 = vcrypto.generate_hash("abc".as_bytes()).await;
// Get distance // Get distance
let d1 = vcrypto.distance(&k1, &k2); let d1 = vcrypto.distance(&k1, &k2).await;
let d2 = vcrypto.distance(&k2, &k1); let d2 = vcrypto.distance(&k2, &k1).await;
let d3 = vcrypto.distance(&k1, &k3); let d3 = vcrypto.distance(&k1, &k3).await;
let d4 = vcrypto.distance(&k2, &k3); let d4 = vcrypto.distance(&k2, &k3).await;
trace!("d1={:?}", d1); trace!("d1={:?}", d1);
trace!("d2={:?}", d2); trace!("d2={:?}", d2);
@ -393,15 +417,15 @@ pub async fn test_all() {
// Test versions // Test versions
for v in VALID_CRYPTO_KINDS { for v in VALID_CRYPTO_KINDS {
let vcrypto = crypto.get(v).unwrap(); let vcrypto = crypto.get_async(v).unwrap();
test_generate_secret(vcrypto.clone()).await; test_generate_secret(&vcrypto).await;
test_sign_and_verify(vcrypto.clone()).await; test_sign_and_verify(&vcrypto).await;
test_key_conversions(vcrypto.clone()).await; test_key_conversions(&vcrypto).await;
test_encode_decode(vcrypto.clone()).await; test_encode_decode(&vcrypto).await;
test_typed_convert(vcrypto.clone()).await; test_typed_convert(&vcrypto).await;
test_hash(vcrypto.clone()).await; test_hash(&vcrypto).await;
test_operations(vcrypto).await; test_operations(&vcrypto).await;
} }
crypto_tests_shutdown(api.clone()).await; crypto_tests_shutdown(api.clone()).await;

View File

@ -78,7 +78,11 @@ where
macro_rules! byte_array_type { macro_rules! byte_array_type {
($name:ident, $size:expr, $encoded_size:expr) => { ($name:ident, $size:expr, $encoded_size:expr) => {
#[derive(Clone, Copy, Hash, PartialOrd, Ord, PartialEq, Eq)] #[derive(Clone, Copy, Hash, PartialOrd, Ord, PartialEq, Eq)]
#[cfg_attr(target_arch = "wasm32", derive(Tsify), tsify(into_wasm_abi))] #[cfg_attr(
all(target_arch = "wasm32", target_os = "unknown"),
derive(Tsify),
tsify(into_wasm_abi)
)]
pub struct $name { pub struct $name {
pub bytes: [u8; $size], pub bytes: [u8; $size],
} }
@ -280,17 +284,17 @@ macro_rules! byte_array_type {
byte_array_type!(CryptoKey, CRYPTO_KEY_LENGTH, CRYPTO_KEY_LENGTH_ENCODED); byte_array_type!(CryptoKey, CRYPTO_KEY_LENGTH, CRYPTO_KEY_LENGTH_ENCODED);
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type PublicKey = CryptoKey; pub type PublicKey = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type SecretKey = CryptoKey; pub type SecretKey = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type HashDigest = CryptoKey; pub type HashDigest = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type SharedSecret = CryptoKey; pub type SharedSecret = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type RouteId = CryptoKey; pub type RouteId = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type CryptoKeyDistance = CryptoKey; pub type CryptoKeyDistance = CryptoKey;
byte_array_type!(Signature, SIGNATURE_LENGTH, SIGNATURE_LENGTH_ENCODED); byte_array_type!(Signature, SIGNATURE_LENGTH, SIGNATURE_LENGTH_ENCODED);

View File

@ -2,7 +2,7 @@ use super::*;
#[derive(Clone, Copy, Default, PartialOrd, Ord, PartialEq, Eq, Hash)] #[derive(Clone, Copy, Default, PartialOrd, Ord, PartialEq, Eq, Hash)]
#[cfg_attr( #[cfg_attr(
target_arch = "wasm32", all(target_arch = "wasm32", target_os = "unknown"),
derive(Tsify), derive(Tsify),
tsify(from_wasm_abi, into_wasm_abi) tsify(from_wasm_abi, into_wasm_abi)
)] )]

View File

@ -6,7 +6,7 @@ use core::fmt;
use core::hash::Hash; use core::hash::Hash;
/// Cryptography version fourcc code /// Cryptography version fourcc code
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type CryptoKind = FourCC; pub type CryptoKind = FourCC;
/// Sort best crypto kinds first /// Sort best crypto kinds first
@ -52,24 +52,24 @@ pub use crypto_typed::*;
pub use crypto_typed_group::*; pub use crypto_typed_group::*;
pub use keypair::*; pub use keypair::*;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedKey = CryptoTyped<PublicKey>; pub type TypedKey = CryptoTyped<PublicKey>;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSecret = CryptoTyped<SecretKey>; pub type TypedSecret = CryptoTyped<SecretKey>;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedKeyPair = CryptoTyped<KeyPair>; pub type TypedKeyPair = CryptoTyped<KeyPair>;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSignature = CryptoTyped<Signature>; pub type TypedSignature = CryptoTyped<Signature>;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSharedSecret = CryptoTyped<SharedSecret>; pub type TypedSharedSecret = CryptoTyped<SharedSecret>;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedKeyGroup = CryptoTypedGroup<PublicKey>; pub type TypedKeyGroup = CryptoTypedGroup<PublicKey>;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSecretGroup = CryptoTypedGroup<SecretKey>; pub type TypedSecretGroup = CryptoTypedGroup<SecretKey>;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedKeyPairGroup = CryptoTypedGroup<KeyPair>; pub type TypedKeyPairGroup = CryptoTypedGroup<KeyPair>;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSignatureGroup = CryptoTypedGroup<Signature>; pub type TypedSignatureGroup = CryptoTypedGroup<Signature>;
#[cfg_attr(target_arch = "wasm32", declare)] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSharedSecretGroup = CryptoTypedGroup<SharedSecret>; pub type TypedSharedSecretGroup = CryptoTypedGroup<SharedSecret>;

View File

@ -47,14 +47,13 @@ pub fn vld0_generate_keypair() -> KeyPair {
} }
/// V0 CryptoSystem /// V0 CryptoSystem
#[derive(Clone)]
pub struct CryptoSystemVLD0 { pub struct CryptoSystemVLD0 {
crypto: Crypto, registry: VeilidComponentRegistry,
} }
impl CryptoSystemVLD0 { impl CryptoSystemVLD0 {
pub fn new(crypto: Crypto) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { crypto } Self { registry }
} }
} }
@ -64,14 +63,14 @@ impl CryptoSystem for CryptoSystemVLD0 {
CRYPTO_KIND_VLD0 CRYPTO_KIND_VLD0
} }
fn crypto(&self) -> Crypto { fn crypto(&self) -> VeilidComponentGuard<'_, Crypto> {
self.crypto.clone() self.registry.lookup::<Crypto>().unwrap()
} }
// Cached Operations // Cached Operations
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
fn cached_dh(&self, key: &PublicKey, secret: &SecretKey) -> VeilidAPIResult<SharedSecret> { fn cached_dh(&self, key: &PublicKey, secret: &SecretKey) -> VeilidAPIResult<SharedSecret> {
self.crypto self.crypto()
.cached_dh_internal::<CryptoSystemVLD0>(self, key, secret) .cached_dh_internal::<CryptoSystemVLD0>(self, key, secret)
} }

View File

@ -1,12 +1,12 @@
use super::*; use super::*;
#[cfg(target_arch = "wasm32")] #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
mod wasm; mod wasm;
#[cfg(target_arch = "wasm32")] #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub use wasm::*; pub use wasm::*;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod native; mod native;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub use native::*; pub use native::*;
pub static KNOWN_PROTECTED_STORE_KEYS: [&str; 2] = ["device_encryption_key", "_test_key"]; pub static KNOWN_PROTECTED_STORE_KEYS: [&str; 2] = ["device_encryption_key", "_test_key"];

View File

@ -4,31 +4,37 @@ struct BlockStoreInner {
// //
} }
#[derive(Clone)] impl fmt::Debug for BlockStoreInner {
pub struct BlockStore { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
event_bus: EventBus, f.debug_struct("BlockStoreInner").finish()
config: VeilidConfig, }
inner: Arc<Mutex<BlockStoreInner>>,
} }
#[derive(Debug)]
pub struct BlockStore {
registry: VeilidComponentRegistry,
inner: Mutex<BlockStoreInner>,
}
impl_veilid_component!(BlockStore);
impl BlockStore { impl BlockStore {
fn new_inner() -> BlockStoreInner { fn new_inner() -> BlockStoreInner {
BlockStoreInner {} BlockStoreInner {}
} }
pub fn new(event_bus: EventBus, config: VeilidConfig) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self {
event_bus, registry,
config, inner: Mutex::new(Self::new_inner()),
inner: Arc::new(Mutex::new(Self::new_inner())),
} }
} }
pub async fn init(&self) -> EyreResult<()> { async fn init_async(&self) -> EyreResult<()> {
// Ensure permissions are correct // Ensure permissions are correct
// ensure_file_private_owner(&dbpath)?; // ensure_file_private_owner(&dbpath)?;
Ok(()) Ok(())
} }
pub async fn terminate(&self) {} async fn terminate_async(&self) {}
} }

View File

@ -6,14 +6,20 @@ use std::path::Path;
pub struct ProtectedStoreInner { pub struct ProtectedStoreInner {
keyring_manager: Option<KeyringManager>, keyring_manager: Option<KeyringManager>,
} }
impl fmt::Debug for ProtectedStoreInner {
#[derive(Clone)] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
pub struct ProtectedStore { f.debug_struct("ProtectedStoreInner").finish()
_event_bus: EventBus, }
config: VeilidConfig,
inner: Arc<Mutex<ProtectedStoreInner>>,
} }
#[derive(Debug)]
pub struct ProtectedStore {
registry: VeilidComponentRegistry,
inner: Mutex<ProtectedStoreInner>,
}
impl_veilid_component!(ProtectedStore);
impl ProtectedStore { impl ProtectedStore {
fn new_inner() -> ProtectedStoreInner { fn new_inner() -> ProtectedStoreInner {
ProtectedStoreInner { ProtectedStoreInner {
@ -21,11 +27,10 @@ impl ProtectedStore {
} }
} }
pub fn new(event_bus: EventBus, config: VeilidConfig) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self {
_event_bus: event_bus, registry,
config, inner: Mutex::new(Self::new_inner()),
inner: Arc::new(Mutex::new(Self::new_inner())),
} }
} }
@ -42,9 +47,10 @@ impl ProtectedStore {
} }
#[instrument(level = "debug", skip(self), err)] #[instrument(level = "debug", skip(self), err)]
pub async fn init(&self) -> EyreResult<()> { async fn init_async(&self) -> EyreResult<()> {
let delete = { let delete = {
let c = self.config.get(); let config = self.config();
let c = config.get();
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
if !c.protected_store.always_use_insecure_storage { if !c.protected_store.always_use_insecure_storage {
// Attempt to open the secure keyring // Attempt to open the secure keyring
@ -101,13 +107,22 @@ impl ProtectedStore {
Ok(()) Ok(())
} }
#[instrument(level = "debug", skip(self), err)]
async fn post_init_async(&self) -> EyreResult<()> {
Ok(())
}
#[instrument(level = "debug", skip(self))] #[instrument(level = "debug", skip(self))]
pub async fn terminate(&self) { async fn pre_terminate_async(&self) {}
#[instrument(level = "debug", skip(self))]
async fn terminate_async(&self) {
*self.inner.lock() = Self::new_inner(); *self.inner.lock() = Self::new_inner();
} }
fn service_name(&self) -> String { fn service_name(&self) -> String {
let c = self.config.get(); let config = self.config();
let c = config.get();
if c.namespace.is_empty() { if c.namespace.is_empty() {
"veilid_protected_store".to_owned() "veilid_protected_store".to_owned()
} else { } else {

View File

@ -4,28 +4,34 @@ struct BlockStoreInner {
// //
} }
#[derive(Clone)] impl fmt::Debug for BlockStoreInner {
pub struct BlockStore { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
event_bus: EventBus, f.debug_struct("BlockStoreInner").finish()
config: VeilidConfig, }
inner: Arc<Mutex<BlockStoreInner>>,
} }
#[derive(Debug)]
pub struct BlockStore {
registry: VeilidComponentRegistry,
inner: Mutex<BlockStoreInner>,
}
impl_veilid_component!(BlockStore);
impl BlockStore { impl BlockStore {
fn new_inner() -> BlockStoreInner { fn new_inner() -> BlockStoreInner {
BlockStoreInner {} BlockStoreInner {}
} }
pub fn new(event_bus: EventBus, config: VeilidConfig) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self {
event_bus, registry,
config, inner: Mutex::new(Self::new_inner()),
inner: Arc::new(Mutex::new(Self::new_inner())),
} }
} }
pub async fn init(&self) -> EyreResult<()> { async fn init_async(&self) -> EyreResult<()> {
Ok(()) Ok(())
} }
pub async fn terminate(&self) {} async fn terminate_async(&self) {}
} }

View File

@ -3,18 +3,16 @@ use data_encoding::BASE64URL_NOPAD;
use web_sys::*; use web_sys::*;
#[derive(Clone)] #[derive(Debug)]
pub struct ProtectedStore { pub struct ProtectedStore {
_event_bus: EventBus, registry: VeilidComponentRegistry,
config: VeilidConfig,
} }
impl_veilid_component!(ProtectedStore);
impl ProtectedStore { impl ProtectedStore {
pub fn new(event_bus: EventBus, config: VeilidConfig) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self { registry }
_event_bus: event_bus,
config,
}
} }
#[instrument(level = "trace", skip(self), err)] #[instrument(level = "trace", skip(self), err)]
@ -30,15 +28,24 @@ impl ProtectedStore {
} }
#[instrument(level = "debug", skip(self), err)] #[instrument(level = "debug", skip(self), err)]
pub async fn init(&self) -> EyreResult<()> { pub(crate) async fn init_async(&self) -> EyreResult<()> {
Ok(())
}
#[instrument(level = "debug", skip(self), err)]
pub(crate) async fn post_init_async(&self) -> EyreResult<()> {
Ok(()) Ok(())
} }
#[instrument(level = "debug", skip(self))] #[instrument(level = "debug", skip(self))]
pub async fn terminate(&self) {} pub(crate) async fn pre_terminate_async(&self) {}
#[instrument(level = "debug", skip(self))]
pub(crate) async fn terminate_async(&self) {}
fn browser_key_name(&self, key: &str) -> String { fn browser_key_name(&self, key: &str) -> String {
let c = self.config.get(); let config = self.config();
let c = config.get();
if c.namespace.is_empty() { if c.namespace.is_empty() {
format!("__veilid_protected_store_{}", key) format!("__veilid_protected_store_{}", key)
} else { } else {

View File

@ -28,7 +28,7 @@
#![recursion_limit = "256"] #![recursion_limit = "256"]
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] { if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
#[cfg(any(feature = "rt-async-std", feature = "rt-tokio"))] #[cfg(any(feature = "rt-async-std", feature = "rt-tokio"))]
compile_error!("features \"rt-async-std\" and \"rt-tokio\" can not be specified for WASM"); compile_error!("features \"rt-async-std\" and \"rt-tokio\" can not be specified for WASM");
} else { } else {
@ -45,6 +45,7 @@ cfg_if::cfg_if! {
extern crate alloc; extern crate alloc;
mod attachment_manager; mod attachment_manager;
mod component;
mod core_context; mod core_context;
mod crypto; mod crypto;
mod intf; mod intf;
@ -58,6 +59,8 @@ mod veilid_api;
mod veilid_config; mod veilid_config;
mod wasm_helpers; mod wasm_helpers;
pub(crate) use self::component::*;
pub(crate) use self::core_context::RegisteredComponents;
pub use self::core_context::{api_startup, api_startup_config, api_startup_json, UpdateCallback}; pub use self::core_context::{api_startup, api_startup_config, api_startup_json, UpdateCallback};
pub use self::logging::{ pub use self::logging::{
ApiTracingLayer, VeilidLayerFilter, DEFAULT_LOG_FACILITIES_ENABLED_LIST, ApiTracingLayer, VeilidLayerFilter, DEFAULT_LOG_FACILITIES_ENABLED_LIST,

View File

@ -2,7 +2,6 @@ use crate::core_context::*;
use crate::veilid_api::*; use crate::veilid_api::*;
use crate::*; use crate::*;
use core::fmt::Write; use core::fmt::Write;
use once_cell::sync::OnceCell;
use tracing_subscriber::*; use tracing_subscriber::*;
struct ApiTracingLayerInner { struct ApiTracingLayerInner {
@ -21,11 +20,10 @@ struct ApiTracingLayerInner {
/// with many copies of Veilid running. /// with many copies of Veilid running.
#[derive(Clone)] #[derive(Clone)]
pub struct ApiTracingLayer { pub struct ApiTracingLayer {}
inner: Arc<Mutex<Option<ApiTracingLayerInner>>>,
}
static API_LOGGER: OnceCell<ApiTracingLayer> = OnceCell::new(); static API_LOGGER_INNER: Mutex<Option<ApiTracingLayerInner>> = Mutex::new(None);
static API_LOGGER_ENABLED: AtomicBool = AtomicBool::new(false);
impl ApiTracingLayer { impl ApiTracingLayer {
/// Initialize an ApiTracingLayer singleton /// Initialize an ApiTracingLayer singleton
@ -33,11 +31,7 @@ impl ApiTracingLayer {
/// This must be inserted into your tracing subscriber before you /// This must be inserted into your tracing subscriber before you
/// call api_startup() or api_startup_json() if you are going to use api tracing. /// call api_startup() or api_startup_json() if you are going to use api tracing.
pub fn init() -> ApiTracingLayer { pub fn init() -> ApiTracingLayer {
API_LOGGER ApiTracingLayer {}
.get_or_init(|| ApiTracingLayer {
inner: Arc::new(Mutex::new(None)),
})
.clone()
} }
fn new_inner() -> ApiTracingLayerInner { fn new_inner() -> ApiTracingLayerInner {
@ -52,12 +46,7 @@ impl ApiTracingLayer {
namespace: String, namespace: String,
update_callback: UpdateCallback, update_callback: UpdateCallback,
) -> VeilidAPIResult<()> { ) -> VeilidAPIResult<()> {
let Some(api_logger) = API_LOGGER.get() else { let mut inner = API_LOGGER_INNER.lock();
// Did not init, so skip this
return Ok(());
};
let mut inner = api_logger.inner.lock();
if inner.is_none() { if inner.is_none() {
*inner = Some(Self::new_inner()); *inner = Some(Self::new_inner());
} }
@ -70,6 +59,9 @@ impl ApiTracingLayer {
.unwrap() .unwrap()
.update_callbacks .update_callbacks
.insert(key, update_callback); .insert(key, update_callback);
API_LOGGER_ENABLED.store(true, Ordering::Release);
return Ok(()); return Ok(());
} }
@ -79,28 +71,29 @@ impl ApiTracingLayer {
namespace: String, namespace: String,
) -> VeilidAPIResult<()> { ) -> VeilidAPIResult<()> {
let key = (program_name, namespace); let key = (program_name, namespace);
if let Some(api_logger) = API_LOGGER.get() {
let mut inner = api_logger.inner.lock(); let mut inner = API_LOGGER_INNER.lock();
if inner.is_none() { if inner.is_none() {
apibail_not_initialized!(); apibail_not_initialized!();
}
if inner
.as_mut()
.unwrap()
.update_callbacks
.remove(&key)
.is_none()
{
apibail_not_initialized!();
}
if inner.as_mut().unwrap().update_callbacks.is_empty() {
*inner = None;
}
} }
if inner
.as_mut()
.unwrap()
.update_callbacks
.remove(&key)
.is_none()
{
apibail_not_initialized!();
}
if inner.as_mut().unwrap().update_callbacks.is_empty() {
*inner = None;
API_LOGGER_ENABLED.store(false, Ordering::Release);
}
Ok(()) Ok(())
} }
fn emit_log(&self, inner: &mut ApiTracingLayerInner, meta: &Metadata<'_>, message: String) { fn emit_log(&self, meta: &'static Metadata<'static>, message: String) {
let level = *meta.level(); let level = *meta.level();
let target = meta.target(); let target = meta.target();
let log_level = VeilidLogLevel::from_tracing_level(level); let log_level = VeilidLogLevel::from_tracing_level(level);
@ -148,8 +141,10 @@ impl ApiTracingLayer {
backtrace, backtrace,
})); }));
for cb in inner.update_callbacks.values() { if let Some(inner) = &mut *API_LOGGER_INNER.lock() {
(cb)(log_update.clone()); for cb in inner.update_callbacks.values() {
(cb)(log_update.clone());
}
} }
} }
} }
@ -159,17 +154,23 @@ pub struct SpanDuration {
end: Timestamp, end: Timestamp,
} }
fn simplify_file(file: &str) -> String { fn simplify_file(file: &'static str) -> &'static str {
let path = std::path::Path::new(file); file.static_transform(|file| {
let path_component_count = path.iter().count(); let out = {
if path.ends_with("mod.rs") && path_component_count >= 2 { let path = std::path::Path::new(file);
let outpath: std::path::PathBuf = path.iter().skip(path_component_count - 2).collect(); let path_component_count = path.iter().count();
outpath.to_string_lossy().to_string() if path.ends_with("mod.rs") && path_component_count >= 2 {
} else if let Some(filename) = path.file_name() { let outpath: std::path::PathBuf =
filename.to_string_lossy().to_string() path.iter().skip(path_component_count - 2).collect();
} else { outpath.to_string_lossy().to_string()
file.to_string() } else if let Some(filename) = path.file_name() {
} filename.to_string_lossy().to_string()
} else {
file.to_string()
}
};
out.to_static_str()
})
} }
impl<S: Subscriber + for<'a> registry::LookupSpan<'a>> Layer<S> for ApiTracingLayer { impl<S: Subscriber + for<'a> registry::LookupSpan<'a>> Layer<S> for ApiTracingLayer {
@ -179,47 +180,51 @@ impl<S: Subscriber + for<'a> registry::LookupSpan<'a>> Layer<S> for ApiTracingLa
id: &tracing::Id, id: &tracing::Id,
ctx: layer::Context<'_, S>, ctx: layer::Context<'_, S>,
) { ) {
if let Some(_inner) = &mut *self.inner.lock() { if !API_LOGGER_ENABLED.load(Ordering::Acquire) {
let mut new_debug_record = StringRecorder::new(); // Optimization if api logger has no callbacks
attrs.record(&mut new_debug_record); return;
}
if let Some(span_ref) = ctx.span(id) { let mut new_debug_record = StringRecorder::new();
attrs.record(&mut new_debug_record);
if let Some(span_ref) = ctx.span(id) {
span_ref
.extensions_mut()
.insert::<StringRecorder>(new_debug_record);
if crate::DURATION_LOG_FACILITIES.contains(&attrs.metadata().target()) {
span_ref span_ref
.extensions_mut() .extensions_mut()
.insert::<StringRecorder>(new_debug_record); .insert::<SpanDuration>(SpanDuration {
if crate::DURATION_LOG_FACILITIES.contains(&attrs.metadata().target()) { start: Timestamp::now(),
span_ref end: Timestamp::default(),
.extensions_mut() });
.insert::<SpanDuration>(SpanDuration {
start: Timestamp::now(),
end: Timestamp::default(),
});
}
} }
} }
} }
fn on_close(&self, id: span::Id, ctx: layer::Context<'_, S>) { fn on_close(&self, id: span::Id, ctx: layer::Context<'_, S>) {
if let Some(inner) = &mut *self.inner.lock() { if !API_LOGGER_ENABLED.load(Ordering::Acquire) {
if let Some(span_ref) = ctx.span(&id) { // Optimization if api logger has no callbacks
if let Some(span_duration) = span_ref.extensions_mut().get_mut::<SpanDuration>() { return;
span_duration.end = Timestamp::now(); }
let duration = span_duration.end.saturating_sub(span_duration.start); if let Some(span_ref) = ctx.span(&id) {
let meta = span_ref.metadata(); if let Some(span_duration) = span_ref.extensions_mut().get_mut::<SpanDuration>() {
self.emit_log( span_duration.end = Timestamp::now();
inner, let duration = span_duration.end.saturating_sub(span_duration.start);
meta, let meta = span_ref.metadata();
format!( self.emit_log(
" {}{}: duration={}", meta,
span_ref format!(
.parent() " {}{}: duration={}",
.map(|p| format!("{}::", p.name())) span_ref
.unwrap_or_default(), .parent()
span_ref.name(), .map(|p| format!("{}::", p.name()))
format_opt_ts(Some(duration)) .unwrap_or_default(),
), span_ref.name(),
); format_opt_ts(Some(duration))
} ),
);
} }
} }
} }
@ -230,22 +235,26 @@ impl<S: Subscriber + for<'a> registry::LookupSpan<'a>> Layer<S> for ApiTracingLa
values: &tracing::span::Record<'_>, values: &tracing::span::Record<'_>,
ctx: layer::Context<'_, S>, ctx: layer::Context<'_, S>,
) { ) {
if let Some(_inner) = &mut *self.inner.lock() { if !API_LOGGER_ENABLED.load(Ordering::Acquire) {
if let Some(span_ref) = ctx.span(id) { // Optimization if api logger has no callbacks
if let Some(debug_record) = span_ref.extensions_mut().get_mut::<StringRecorder>() { return;
values.record(debug_record); }
} if let Some(span_ref) = ctx.span(id) {
if let Some(debug_record) = span_ref.extensions_mut().get_mut::<StringRecorder>() {
values.record(debug_record);
} }
} }
} }
fn on_event(&self, event: &tracing::Event<'_>, _ctx: layer::Context<'_, S>) { fn on_event(&self, event: &tracing::Event<'_>, _ctx: layer::Context<'_, S>) {
if let Some(inner) = &mut *self.inner.lock() { if !API_LOGGER_ENABLED.load(Ordering::Acquire) {
let mut recorder = StringRecorder::new(); // Optimization if api logger has no callbacks
event.record(&mut recorder); return;
let meta = event.metadata();
self.emit_log(inner, meta, recorder.to_string());
} }
let mut recorder = StringRecorder::new();
event.record(&mut recorder);
let meta = event.metadata();
self.emit_log(meta, recorder.to_string());
} }
} }

View File

@ -151,42 +151,6 @@ macro_rules! log_client_api {
} }
} }
#[macro_export]
macro_rules! log_network_result {
(error $text:expr) => {error!(
target: "network_result",
"{}",
$text,
)};
(error $fmt:literal, $($arg:expr),+) => {
error!(target: "network_result", $fmt, $($arg),+);
};
(warn $text:expr) => {warn!(
target: "network_result",
"{}",
$text,
)};
(warn $fmt:literal, $($arg:expr),+) => {
warn!(target:"network_result", $fmt, $($arg),+);
};
(debug $text:expr) => {debug!(
target: "network_result",
"{}",
$text,
)};
(debug $fmt:literal, $($arg:expr),+) => {
debug!(target:"network_result", $fmt, $($arg),+);
};
($text:expr) => {trace!(
target: "network_result",
"{}",
$text,
)};
($fmt:literal, $($arg:expr),+) => {
trace!(target:"network_result", $fmt, $($arg),+);
}
}
#[macro_export] #[macro_export]
macro_rules! log_rpc { macro_rules! log_rpc {
(error $text:expr) => { error!( (error $text:expr) => { error!(
@ -421,6 +385,14 @@ macro_rules! log_crypto {
(warn $fmt:literal, $($arg:expr),+) => { (warn $fmt:literal, $($arg:expr),+) => {
warn!(target:"crypto", $fmt, $($arg),+); warn!(target:"crypto", $fmt, $($arg),+);
}; };
(debug $text:expr) => { debug!(
target: "crypto",
"{}",
$text,
)};
(debug $fmt:literal, $($arg:expr),+) => {
debug!(target:"crypto", $fmt, $($arg),+);
};
($text:expr) => {trace!( ($text:expr) => {trace!(
target: "crypto", target: "crypto",
"{}", "{}",

View File

@ -23,6 +23,7 @@ pub const ADDRESS_CHECK_CACHE_SIZE: usize = 10;
// TimestampDuration::new(3_600_000_000_u64); // 60 minutes // TimestampDuration::new(3_600_000_000_u64); // 60 minutes
/// Address checker config /// Address checker config
#[derive(Debug)]
pub struct AddressCheckConfig { pub struct AddressCheckConfig {
pub detect_address_changes: bool, pub detect_address_changes: bool,
pub ip6_prefix_size: usize, pub ip6_prefix_size: usize,
@ -44,6 +45,22 @@ pub struct AddressCheck {
address_consistency_table: BTreeMap<AddressCheckCacheKey, LruCache<IpAddr, SocketAddress>>, address_consistency_table: BTreeMap<AddressCheckCacheKey, LruCache<IpAddr, SocketAddress>>,
} }
impl fmt::Debug for AddressCheck {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AddressCheck")
.field("config", &self.config)
//.field("net", &self.net)
.field("current_network_class", &self.current_network_class)
.field("current_addresses", &self.current_addresses)
.field(
"address_inconsistency_table",
&self.address_inconsistency_table,
)
.field("address_consistency_table", &self.address_consistency_table)
.finish()
}
}
impl AddressCheck { impl AddressCheck {
pub fn new(config: AddressCheckConfig, net: Network) -> Self { pub fn new(config: AddressCheckConfig, net: Network) -> Self {
Self { Self {

View File

@ -32,63 +32,27 @@ struct AddressFilterInner {
dial_info_failures: BTreeMap<DialInfo, Timestamp>, dial_info_failures: BTreeMap<DialInfo, Timestamp>,
} }
struct AddressFilterUnlockedInner { #[derive(Debug)]
pub(crate) struct AddressFilter {
registry: VeilidComponentRegistry,
inner: Mutex<AddressFilterInner>,
max_connections_per_ip4: usize, max_connections_per_ip4: usize,
max_connections_per_ip6_prefix: usize, max_connections_per_ip6_prefix: usize,
max_connections_per_ip6_prefix_size: usize, max_connections_per_ip6_prefix_size: usize,
max_connection_frequency_per_min: usize, max_connection_frequency_per_min: usize,
punishment_duration_min: usize, punishment_duration_min: usize,
dial_info_failure_duration_min: usize, dial_info_failure_duration_min: usize,
routing_table: RoutingTable,
} }
impl fmt::Debug for AddressFilterUnlockedInner { impl_veilid_component_registry_accessor!(AddressFilter);
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AddressFilterUnlockedInner")
.field("max_connections_per_ip4", &self.max_connections_per_ip4)
.field(
"max_connections_per_ip6_prefix",
&self.max_connections_per_ip6_prefix,
)
.field(
"max_connections_per_ip6_prefix_size",
&self.max_connections_per_ip6_prefix_size,
)
.field(
"max_connection_frequency_per_min",
&self.max_connection_frequency_per_min,
)
.field("punishment_duration_min", &self.punishment_duration_min)
.field(
"dial_info_failure_duration_min",
&self.dial_info_failure_duration_min,
)
.finish()
}
}
#[derive(Clone, Debug)]
pub(crate) struct AddressFilter {
unlocked_inner: Arc<AddressFilterUnlockedInner>,
inner: Arc<Mutex<AddressFilterInner>>,
}
impl AddressFilter { impl AddressFilter {
pub fn new(config: VeilidConfig, routing_table: RoutingTable) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
let config = registry.config();
let c = config.get(); let c = config.get();
Self { Self {
unlocked_inner: Arc::new(AddressFilterUnlockedInner { registry,
max_connections_per_ip4: c.network.max_connections_per_ip4 as usize, inner: Mutex::new(AddressFilterInner {
max_connections_per_ip6_prefix: c.network.max_connections_per_ip6_prefix as usize,
max_connections_per_ip6_prefix_size: c.network.max_connections_per_ip6_prefix_size
as usize,
max_connection_frequency_per_min: c.network.max_connection_frequency_per_min
as usize,
punishment_duration_min: PUNISHMENT_DURATION_MIN,
dial_info_failure_duration_min: DIAL_INFO_FAILURE_DURATION_MIN,
routing_table,
}),
inner: Arc::new(Mutex::new(AddressFilterInner {
conn_count_by_ip4: BTreeMap::new(), conn_count_by_ip4: BTreeMap::new(),
conn_count_by_ip6_prefix: BTreeMap::new(), conn_count_by_ip6_prefix: BTreeMap::new(),
conn_timestamps_by_ip4: BTreeMap::new(), conn_timestamps_by_ip4: BTreeMap::new(),
@ -97,7 +61,14 @@ impl AddressFilter {
punishments_by_ip6_prefix: BTreeMap::new(), punishments_by_ip6_prefix: BTreeMap::new(),
punishments_by_node_id: BTreeMap::new(), punishments_by_node_id: BTreeMap::new(),
dial_info_failures: BTreeMap::new(), dial_info_failures: BTreeMap::new(),
})), }),
max_connections_per_ip4: c.network.max_connections_per_ip4 as usize,
max_connections_per_ip6_prefix: c.network.max_connections_per_ip6_prefix as usize,
max_connections_per_ip6_prefix_size: c.network.max_connections_per_ip6_prefix_size
as usize,
max_connection_frequency_per_min: c.network.max_connection_frequency_per_min as usize,
punishment_duration_min: PUNISHMENT_DURATION_MIN,
dial_info_failure_duration_min: DIAL_INFO_FAILURE_DURATION_MIN,
} }
} }
@ -109,7 +80,7 @@ impl AddressFilter {
inner.dial_info_failures.clear(); inner.dial_info_failures.clear();
} }
fn purge_old_timestamps(&self, inner: &mut AddressFilterInner, cur_ts: Timestamp) { fn purge_old_timestamps_inner(&self, inner: &mut AddressFilterInner, cur_ts: Timestamp) {
// v4 // v4
{ {
let mut dead_keys = Vec::<Ipv4Addr>::new(); let mut dead_keys = Vec::<Ipv4Addr>::new();
@ -151,7 +122,7 @@ impl AddressFilter {
for (key, value) in &mut inner.punishments_by_ip4 { for (key, value) in &mut inner.punishments_by_ip4 {
// Drop punishments older than the punishment duration // Drop punishments older than the punishment duration
if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64()) if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64())
> self.unlocked_inner.punishment_duration_min as u64 * 60_000_000u64 > self.punishment_duration_min as u64 * 60_000_000u64
{ {
dead_keys.push(*key); dead_keys.push(*key);
} }
@ -167,7 +138,7 @@ impl AddressFilter {
for (key, value) in &mut inner.punishments_by_ip6_prefix { for (key, value) in &mut inner.punishments_by_ip6_prefix {
// Drop punishments older than the punishment duration // Drop punishments older than the punishment duration
if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64()) if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64())
> self.unlocked_inner.punishment_duration_min as u64 * 60_000_000u64 > self.punishment_duration_min as u64 * 60_000_000u64
{ {
dead_keys.push(*key); dead_keys.push(*key);
} }
@ -183,7 +154,7 @@ impl AddressFilter {
for (key, value) in &mut inner.punishments_by_node_id { for (key, value) in &mut inner.punishments_by_node_id {
// Drop punishments older than the punishment duration // Drop punishments older than the punishment duration
if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64()) if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64())
> self.unlocked_inner.punishment_duration_min as u64 * 60_000_000u64 > self.punishment_duration_min as u64 * 60_000_000u64
{ {
dead_keys.push(*key); dead_keys.push(*key);
} }
@ -192,7 +163,7 @@ impl AddressFilter {
warn!("Forgiving: {}", key); warn!("Forgiving: {}", key);
inner.punishments_by_node_id.remove(&key); inner.punishments_by_node_id.remove(&key);
// make the entry alive again if it's still here // make the entry alive again if it's still here
if let Ok(Some(nr)) = self.unlocked_inner.routing_table.lookup_node_ref(key) { if let Ok(Some(nr)) = self.routing_table().lookup_node_ref(key) {
nr.operate_mut(|_rti, e| e.set_punished(None)); nr.operate_mut(|_rti, e| e.set_punished(None));
} }
} }
@ -203,7 +174,7 @@ impl AddressFilter {
for (key, value) in &mut inner.dial_info_failures { for (key, value) in &mut inner.dial_info_failures {
// Drop failures older than the failure duration // Drop failures older than the failure duration
if cur_ts.as_u64().saturating_sub(value.as_u64()) if cur_ts.as_u64().saturating_sub(value.as_u64())
> self.unlocked_inner.dial_info_failure_duration_min as u64 * 60_000_000u64 > self.dial_info_failure_duration_min as u64 * 60_000_000u64
{ {
dead_keys.push(key.clone()); dead_keys.push(key.clone());
} }
@ -241,10 +212,7 @@ impl AddressFilter {
pub fn is_ip_addr_punished(&self, addr: IpAddr) -> bool { pub fn is_ip_addr_punished(&self, addr: IpAddr) -> bool {
let inner = self.inner.lock(); let inner = self.inner.lock();
let ipblock = ip_to_ipblock( let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
self.is_ip_addr_punished_inner(&inner, ipblock) self.is_ip_addr_punished_inner(&inner, ipblock)
} }
@ -273,8 +241,9 @@ impl AddressFilter {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.punishments_by_ip4.clear(); inner.punishments_by_ip4.clear();
inner.punishments_by_ip6_prefix.clear(); inner.punishments_by_ip6_prefix.clear();
self.unlocked_inner.routing_table.clear_punishments();
inner.punishments_by_node_id.clear(); inner.punishments_by_node_id.clear();
self.routing_table().clear_punishments();
} }
pub fn punish_ip_addr(&self, addr: IpAddr, reason: PunishmentReason) { pub fn punish_ip_addr(&self, addr: IpAddr, reason: PunishmentReason) {
@ -282,10 +251,7 @@ impl AddressFilter {
let timestamp = Timestamp::now(); let timestamp = Timestamp::now();
let punishment = Punishment { reason, timestamp }; let punishment = Punishment { reason, timestamp };
let ipblock = ip_to_ipblock( let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
match ipblock { match ipblock {
@ -315,7 +281,7 @@ impl AddressFilter {
} }
pub fn punish_node_id(&self, node_id: TypedKey, reason: PunishmentReason) { pub fn punish_node_id(&self, node_id: TypedKey, reason: PunishmentReason) {
if let Ok(Some(nr)) = self.unlocked_inner.routing_table.lookup_node_ref(node_id) { if let Ok(Some(nr)) = self.routing_table().lookup_node_ref(node_id) {
// make the entry dead if it's punished // make the entry dead if it's punished
nr.operate_mut(|_rti, e| e.set_punished(Some(reason))); nr.operate_mut(|_rti, e| e.set_punished(Some(reason)));
} }
@ -338,14 +304,14 @@ impl AddressFilter {
#[instrument(parent = None, level = "trace", skip_all, err)] #[instrument(parent = None, level = "trace", skip_all, err)]
pub async fn address_filter_task_routine( pub async fn address_filter_task_routine(
self, &self,
_stop_token: StopToken, _stop_token: StopToken,
_last_ts: Timestamp, _last_ts: Timestamp,
cur_ts: Timestamp, cur_ts: Timestamp,
) -> EyreResult<()> { ) -> EyreResult<()> {
// //
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
self.purge_old_timestamps(&mut inner, cur_ts); self.purge_old_timestamps_inner(&mut inner, cur_ts);
self.purge_old_punishments(&mut inner, cur_ts); self.purge_old_punishments(&mut inner, cur_ts);
Ok(()) Ok(())
@ -354,23 +320,20 @@ impl AddressFilter {
pub fn add_connection(&self, addr: IpAddr) -> Result<(), AddressFilterError> { pub fn add_connection(&self, addr: IpAddr) -> Result<(), AddressFilterError> {
let inner = &mut *self.inner.lock(); let inner = &mut *self.inner.lock();
let ipblock = ip_to_ipblock( let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
if self.is_ip_addr_punished_inner(inner, ipblock) { if self.is_ip_addr_punished_inner(inner, ipblock) {
return Err(AddressFilterError::Punished); return Err(AddressFilterError::Punished);
} }
let ts = Timestamp::now(); let ts = Timestamp::now();
self.purge_old_timestamps(inner, ts); self.purge_old_timestamps_inner(inner, ts);
match ipblock { match ipblock {
IpAddr::V4(v4) => { IpAddr::V4(v4) => {
// See if we have too many connections from this ip block // See if we have too many connections from this ip block
let cnt = inner.conn_count_by_ip4.entry(v4).or_default(); let cnt = inner.conn_count_by_ip4.entry(v4).or_default();
assert!(*cnt <= self.unlocked_inner.max_connections_per_ip4); assert!(*cnt <= self.max_connections_per_ip4);
if *cnt == self.unlocked_inner.max_connections_per_ip4 { if *cnt == self.max_connections_per_ip4 {
warn!("Address filter count exceeded: {:?}", v4); warn!("Address filter count exceeded: {:?}", v4);
return Err(AddressFilterError::CountExceeded); return Err(AddressFilterError::CountExceeded);
} }
@ -380,8 +343,8 @@ impl AddressFilter {
// keep timestamps that are less than a minute away // keep timestamps that are less than a minute away
ts.saturating_sub(*v) < TimestampDuration::new(60_000_000u64) ts.saturating_sub(*v) < TimestampDuration::new(60_000_000u64)
}); });
assert!(tstamps.len() <= self.unlocked_inner.max_connection_frequency_per_min); assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.unlocked_inner.max_connection_frequency_per_min { if tstamps.len() == self.max_connection_frequency_per_min {
warn!("Address filter rate exceeded: {:?}", v4); warn!("Address filter rate exceeded: {:?}", v4);
return Err(AddressFilterError::RateExceeded); return Err(AddressFilterError::RateExceeded);
} }
@ -393,15 +356,15 @@ impl AddressFilter {
IpAddr::V6(v6) => { IpAddr::V6(v6) => {
// See if we have too many connections from this ip block // See if we have too many connections from this ip block
let cnt = inner.conn_count_by_ip6_prefix.entry(v6).or_default(); let cnt = inner.conn_count_by_ip6_prefix.entry(v6).or_default();
assert!(*cnt <= self.unlocked_inner.max_connections_per_ip6_prefix); assert!(*cnt <= self.max_connections_per_ip6_prefix);
if *cnt == self.unlocked_inner.max_connections_per_ip6_prefix { if *cnt == self.max_connections_per_ip6_prefix {
warn!("Address filter count exceeded: {:?}", v6); warn!("Address filter count exceeded: {:?}", v6);
return Err(AddressFilterError::CountExceeded); return Err(AddressFilterError::CountExceeded);
} }
// See if this ip block has connected too frequently // See if this ip block has connected too frequently
let tstamps = inner.conn_timestamps_by_ip6_prefix.entry(v6).or_default(); let tstamps = inner.conn_timestamps_by_ip6_prefix.entry(v6).or_default();
assert!(tstamps.len() <= self.unlocked_inner.max_connection_frequency_per_min); assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.unlocked_inner.max_connection_frequency_per_min { if tstamps.len() == self.max_connection_frequency_per_min {
warn!("Address filter rate exceeded: {:?}", v6); warn!("Address filter rate exceeded: {:?}", v6);
return Err(AddressFilterError::RateExceeded); return Err(AddressFilterError::RateExceeded);
} }
@ -414,16 +377,13 @@ impl AddressFilter {
Ok(()) Ok(())
} }
pub fn remove_connection(&mut self, addr: IpAddr) -> Result<(), AddressNotInTableError> { pub fn remove_connection(&self, addr: IpAddr) -> Result<(), AddressNotInTableError> {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
let ipblock = ip_to_ipblock( let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
let ts = Timestamp::now(); let ts = Timestamp::now();
self.purge_old_timestamps(&mut inner, ts); self.purge_old_timestamps_inner(&mut inner, ts);
match ipblock { match ipblock {
IpAddr::V4(v4) => { IpAddr::V4(v4) => {

View File

@ -57,17 +57,16 @@ struct ConnectionManagerInner {
async_processor_jh: Option<MustJoinHandle<()>>, async_processor_jh: Option<MustJoinHandle<()>>,
stop_source: Option<StopSource>, stop_source: Option<StopSource>,
protected_addresses: HashMap<SocketAddress, ProtectedAddress>, protected_addresses: HashMap<SocketAddress, ProtectedAddress>,
reconnection_processor: DeferredStreamProcessor,
} }
struct ConnectionManagerArc { struct ConnectionManagerArc {
network_manager: NetworkManager,
connection_initial_timeout_ms: u32, connection_initial_timeout_ms: u32,
connection_inactivity_timeout_ms: u32, connection_inactivity_timeout_ms: u32,
connection_table: ConnectionTable, connection_table: ConnectionTable,
address_lock_table: AsyncTagLockTable<SocketAddr>, address_lock_table: AsyncTagLockTable<SocketAddr>,
startup_lock: StartupLock, startup_lock: StartupLock,
inner: Mutex<Option<ConnectionManagerInner>>, inner: Mutex<Option<ConnectionManagerInner>>,
reconnection_processor: DeferredStreamProcessor,
} }
impl core::fmt::Debug for ConnectionManagerArc { impl core::fmt::Debug for ConnectionManagerArc {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
@ -79,15 +78,17 @@ impl core::fmt::Debug for ConnectionManagerArc {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ConnectionManager { pub struct ConnectionManager {
registry: VeilidComponentRegistry,
arc: Arc<ConnectionManagerArc>, arc: Arc<ConnectionManagerArc>,
} }
impl_veilid_component_registry_accessor!(ConnectionManager);
impl ConnectionManager { impl ConnectionManager {
fn new_inner( fn new_inner(
stop_source: StopSource, stop_source: StopSource,
sender: flume::Sender<ConnectionManagerEvent>, sender: flume::Sender<ConnectionManagerEvent>,
async_processor_jh: MustJoinHandle<()>, async_processor_jh: MustJoinHandle<()>,
reconnection_processor: DeferredStreamProcessor,
) -> ConnectionManagerInner { ) -> ConnectionManagerInner {
ConnectionManagerInner { ConnectionManagerInner {
next_id: 0.into(), next_id: 0.into(),
@ -95,11 +96,10 @@ impl ConnectionManager {
sender, sender,
async_processor_jh: Some(async_processor_jh), async_processor_jh: Some(async_processor_jh),
protected_addresses: HashMap::new(), protected_addresses: HashMap::new(),
reconnection_processor,
} }
} }
fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc { fn new_arc(registry: VeilidComponentRegistry) -> ConnectionManagerArc {
let config = network_manager.config(); let config = registry.config();
let (connection_initial_timeout_ms, connection_inactivity_timeout_ms) = { let (connection_initial_timeout_ms, connection_inactivity_timeout_ms) = {
let c = config.get(); let c = config.get();
( (
@ -107,28 +107,24 @@ impl ConnectionManager {
c.network.connection_inactivity_timeout_ms, c.network.connection_inactivity_timeout_ms,
) )
}; };
let address_filter = network_manager.address_filter();
ConnectionManagerArc { ConnectionManagerArc {
network_manager, reconnection_processor: DeferredStreamProcessor::new(),
connection_initial_timeout_ms, connection_initial_timeout_ms,
connection_inactivity_timeout_ms, connection_inactivity_timeout_ms,
connection_table: ConnectionTable::new(config, address_filter), connection_table: ConnectionTable::new(registry),
address_lock_table: AsyncTagLockTable::new(), address_lock_table: AsyncTagLockTable::new(),
startup_lock: StartupLock::new(), startup_lock: StartupLock::new(),
inner: Mutex::new(None), inner: Mutex::new(None),
} }
} }
pub fn new(network_manager: NetworkManager) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self {
arc: Arc::new(Self::new_arc(network_manager)), arc: Arc::new(Self::new_arc(registry.clone())),
registry,
} }
} }
pub fn network_manager(&self) -> NetworkManager {
self.arc.network_manager.clone()
}
pub fn connection_inactivity_timeout_ms(&self) -> u32 { pub fn connection_inactivity_timeout_ms(&self) -> u32 {
self.arc.connection_inactivity_timeout_ms self.arc.connection_inactivity_timeout_ms
} }
@ -150,21 +146,17 @@ impl ConnectionManager {
self.clone().async_processor(stop_source.token(), receiver), self.clone().async_processor(stop_source.token(), receiver),
); );
// Spawn the reconnection processor
let mut reconnection_processor = DeferredStreamProcessor::new();
reconnection_processor.init().await;
// Store in the inner object // Store in the inner object
let mut inner = self.arc.inner.lock(); {
if inner.is_some() { let mut inner = self.arc.inner.lock();
panic!("shouldn't start connection manager twice without shutting it down first"); if inner.is_some() {
panic!("shouldn't start connection manager twice without shutting it down first");
}
*inner = Some(Self::new_inner(stop_source, sender, async_processor));
} }
*inner = Some(Self::new_inner(
stop_source, // Spawn the reconnection processor
sender, self.arc.reconnection_processor.init().await;
async_processor,
reconnection_processor,
));
guard.success(); guard.success();
@ -178,6 +170,10 @@ impl ConnectionManager {
return; return;
}; };
// Stop the reconnection processor
log_net!(debug "stopping reconnection processor task");
self.arc.reconnection_processor.terminate().await;
// Remove the inner from the lock // Remove the inner from the lock
let mut inner = { let mut inner = {
let mut inner_lock = self.arc.inner.lock(); let mut inner_lock = self.arc.inner.lock();
@ -188,9 +184,6 @@ impl ConnectionManager {
} }
} }
}; };
// Stop the reconnection processor
log_net!(debug "stopping reconnection processor task");
inner.reconnection_processor.terminate().await;
// Stop all the connections and the async processor // Stop all the connections and the async processor
log_net!(debug "stopping async processor task"); log_net!(debug "stopping async processor task");
drop(inner.stop_source.take()); drop(inner.stop_source.take());
@ -452,13 +445,14 @@ impl ConnectionManager {
// Attempt new connection // Attempt new connection
let mut retry_count = NEW_CONNECTION_RETRY_COUNT; let mut retry_count = NEW_CONNECTION_RETRY_COUNT;
let network_manager = self.network_manager();
let prot_conn = network_result_try!(loop { let prot_conn = network_result_try!(loop {
let result_net_res = ProtocolNetworkConnection::connect( let result_net_res = ProtocolNetworkConnection::connect(
preferred_local_address, preferred_local_address,
&dial_info, &dial_info,
self.arc.connection_initial_timeout_ms, self.arc.connection_initial_timeout_ms,
self.network_manager().address_filter(), network_manager.address_filter(),
) )
.await; .await;
match result_net_res { match result_net_res {
@ -574,7 +568,7 @@ impl ConnectionManager {
// Called by low-level network when any connection-oriented protocol connection appears // Called by low-level network when any connection-oriented protocol connection appears
// either from incoming connections. // either from incoming connections.
#[cfg_attr(target_arch = "wasm32", expect(dead_code))] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub(super) async fn on_accepted_protocol_network_connection( pub(super) async fn on_accepted_protocol_network_connection(
&self, &self,
protocol_connection: ProtocolNetworkConnection, protocol_connection: ProtocolNetworkConnection,
@ -660,7 +654,7 @@ impl ConnectionManager {
// Reconnect the protected connection immediately // Reconnect the protected connection immediately
if reconnect { if reconnect {
if let Some(dial_info) = conn.dial_info() { if let Some(dial_info) = conn.dial_info() {
self.spawn_reconnector_inner(inner, dial_info); self.spawn_reconnector(dial_info);
} else { } else {
log_net!(debug "Can't reconnect to accepted protected connection: {} -> {} for node {}", conn.connection_id(), conn.debug_print(Timestamp::now()), protect_nr); log_net!(debug "Can't reconnect to accepted protected connection: {} -> {} for node {}", conn.connection_id(), conn.debug_print(Timestamp::now()), protect_nr);
} }
@ -675,9 +669,9 @@ impl ConnectionManager {
} }
} }
fn spawn_reconnector_inner(&self, inner: &mut ConnectionManagerInner, dial_info: DialInfo) { fn spawn_reconnector(&self, dial_info: DialInfo) {
let this = self.clone(); let this = self.clone();
inner.reconnection_processor.add( self.arc.reconnection_processor.add(
Box::pin(futures_util::stream::once(async { dial_info })), Box::pin(futures_util::stream::once(async { dial_info })),
move |dial_info| { move |dial_info| {
let this = this.clone(); let this = this.clone();

View File

@ -44,17 +44,20 @@ struct ConnectionTableInner {
protocol_index_by_id: BTreeMap<NetworkConnectionId, usize>, protocol_index_by_id: BTreeMap<NetworkConnectionId, usize>,
id_by_flow: BTreeMap<Flow, NetworkConnectionId>, id_by_flow: BTreeMap<Flow, NetworkConnectionId>,
ids_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnectionId>>, ids_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnectionId>>,
address_filter: AddressFilter,
priority_flows: Vec<LruCache<Flow, ()>>, priority_flows: Vec<LruCache<Flow, ()>>,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectionTable { pub struct ConnectionTable {
inner: Arc<Mutex<ConnectionTableInner>>, registry: VeilidComponentRegistry,
inner: Mutex<ConnectionTableInner>,
} }
impl_veilid_component_registry_accessor!(ConnectionTable);
impl ConnectionTable { impl ConnectionTable {
pub fn new(config: VeilidConfig, address_filter: AddressFilter) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
let config = registry.config();
let max_connections = { let max_connections = {
let c = config.get(); let c = config.get();
vec![ vec![
@ -64,7 +67,8 @@ impl ConnectionTable {
] ]
}; };
Self { Self {
inner: Arc::new(Mutex::new(ConnectionTableInner { registry,
inner: Mutex::new(ConnectionTableInner {
conn_by_id: max_connections conn_by_id: max_connections
.iter() .iter()
.map(|_| LruCache::new_unbounded()) .map(|_| LruCache::new_unbounded())
@ -72,13 +76,12 @@ impl ConnectionTable {
protocol_index_by_id: BTreeMap::new(), protocol_index_by_id: BTreeMap::new(),
id_by_flow: BTreeMap::new(), id_by_flow: BTreeMap::new(),
ids_by_remote: BTreeMap::new(), ids_by_remote: BTreeMap::new(),
address_filter,
priority_flows: max_connections priority_flows: max_connections
.iter() .iter()
.map(|x| LruCache::new(x * PRIORITY_FLOW_PERCENTAGE / 100)) .map(|x| LruCache::new(x * PRIORITY_FLOW_PERCENTAGE / 100))
.collect(), .collect(),
max_connections, max_connections,
})), }),
} }
} }
@ -168,6 +171,7 @@ impl ConnectionTable {
/// when it is getting full while adding a new connection. /// when it is getting full while adding a new connection.
/// Factored out into its own function for clarity. /// Factored out into its own function for clarity.
fn lru_out_connection_inner( fn lru_out_connection_inner(
&self,
inner: &mut ConnectionTableInner, inner: &mut ConnectionTableInner,
protocol_index: usize, protocol_index: usize,
) -> Result<Option<NetworkConnection>, ()> { ) -> Result<Option<NetworkConnection>, ()> {
@ -198,7 +202,7 @@ impl ConnectionTable {
lruk lruk
}; };
let dead_conn = Self::remove_connection_records(inner, dead_k); let dead_conn = self.remove_connection_records_inner(inner, dead_k);
Ok(Some(dead_conn)) Ok(Some(dead_conn))
} }
@ -235,20 +239,21 @@ impl ConnectionTable {
// Filter by ip for connection limits // Filter by ip for connection limits
let ip_addr = flow.remote_address().ip_addr(); let ip_addr = flow.remote_address().ip_addr();
match inner.address_filter.add_connection(ip_addr) { if let Err(e) = self
Ok(()) => {} .network_manager()
Err(e) => { .address_filter()
// Return the connection in the error to be disposed of .add_connection(ip_addr)
return Err(ConnectionTableAddError::address_filter( {
network_connection, // Return the connection in the error to be disposed of
e, return Err(ConnectionTableAddError::address_filter(
)); network_connection,
} e,
}; ));
}
// if we have reached the maximum number of connections per protocol type // if we have reached the maximum number of connections per protocol type
// then drop the least recently used connection that is not protected or referenced // then drop the least recently used connection that is not protected or referenced
let out_conn = match Self::lru_out_connection_inner(&mut inner, protocol_index) { let out_conn = match self.lru_out_connection_inner(&mut inner, protocol_index) {
Ok(v) => v, Ok(v) => v,
Err(()) => { Err(()) => {
return Err(ConnectionTableAddError::table_full(network_connection)); return Err(ConnectionTableAddError::table_full(network_connection));
@ -437,7 +442,8 @@ impl ConnectionTable {
} }
#[instrument(level = "trace", skip(inner), ret)] #[instrument(level = "trace", skip(inner), ret)]
fn remove_connection_records( fn remove_connection_records_inner(
&self,
inner: &mut ConnectionTableInner, inner: &mut ConnectionTableInner,
id: NetworkConnectionId, id: NetworkConnectionId,
) -> NetworkConnection { ) -> NetworkConnection {
@ -462,8 +468,8 @@ impl ConnectionTable {
} }
// address_filter // address_filter
let ip_addr = remote.socket_addr().ip(); let ip_addr = remote.socket_addr().ip();
inner self.network_manager()
.address_filter .address_filter()
.remove_connection(ip_addr) .remove_connection(ip_addr)
.expect("Inconsistency in connection table"); .expect("Inconsistency in connection table");
conn conn
@ -477,7 +483,7 @@ impl ConnectionTable {
if !inner.conn_by_id[protocol_index].contains_key(&id) { if !inner.conn_by_id[protocol_index].contains_key(&id) {
return None; return None;
} }
let conn = Self::remove_connection_records(&mut inner, id); let conn = self.remove_connection_records_inner(&mut inner, id);
Some(conn) Some(conn)
} }

View File

@ -35,7 +35,7 @@ impl NetworkManager {
// Direct bootstrap request // Direct bootstrap request
#[instrument(level = "trace", target = "net", err, skip(self))] #[instrument(level = "trace", target = "net", err, skip(self))]
pub async fn boot_request(&self, dial_info: DialInfo) -> EyreResult<Vec<Arc<PeerInfo>>> { pub async fn boot_request(&self, dial_info: DialInfo) -> EyreResult<Vec<Arc<PeerInfo>>> {
let timeout_ms = self.with_config(|c| c.network.rpc.timeout_ms); let timeout_ms = self.config().with(|c| c.network.rpc.timeout_ms);
// Send boot magic to requested peer address // Send boot magic to requested peer address
let data = BOOT_MAGIC.to_vec(); let data = BOOT_MAGIC.to_vec();

View File

@ -1,8 +1,8 @@
use crate::*; use super::*;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod native; mod native;
#[cfg(target_arch = "wasm32")] #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
mod wasm; mod wasm;
mod address_check; mod address_check;
@ -36,16 +36,15 @@ use connection_handle::*;
use crypto::*; use crypto::*;
use futures_util::stream::FuturesUnordered; use futures_util::stream::FuturesUnordered;
use hashlink::LruCache; use hashlink::LruCache;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
use native::*; use native::*;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub use native::{MAX_CAPABILITIES, PUBLIC_INTERNET_CAPABILITIES}; pub use native::{MAX_CAPABILITIES, PUBLIC_INTERNET_CAPABILITIES};
use routing_table::*; use routing_table::*;
use rpc_processor::*; use rpc_processor::*;
use storage_manager::*; #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
#[cfg(target_arch = "wasm32")]
use wasm::*; use wasm::*;
#[cfg(target_arch = "wasm32")] #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub use wasm::{/* LOCAL_NETWORK_CAPABILITIES, */ MAX_CAPABILITIES, PUBLIC_INTERNET_CAPABILITIES,}; pub use wasm::{/* LOCAL_NETWORK_CAPABILITIES, */ MAX_CAPABILITIES, PUBLIC_INTERNET_CAPABILITIES,};
//////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////
@ -65,7 +64,6 @@ pub const HOLE_PUNCH_DELAY_MS: u32 = 100;
struct NetworkComponents { struct NetworkComponents {
net: Network, net: Network,
connection_manager: ConnectionManager, connection_manager: ConnectionManager,
rpc_processor: RPCProcessor,
receipt_manager: ReceiptManager, receipt_manager: ReceiptManager,
} }
@ -119,45 +117,74 @@ enum SendDataToExistingFlowResult {
#[derive(Copy, Clone, Debug, Eq, PartialEq)] #[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum StartupDisposition { pub enum StartupDisposition {
Success, Success,
#[cfg_attr(target_arch = "wasm32", expect(dead_code))] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
BindRetry, BindRetry,
} }
#[derive(Debug, Clone)]
pub struct NetworkManagerStartupContext {
pub startup_lock: Arc<StartupLock>,
}
impl NetworkManagerStartupContext {
pub fn new() -> Self {
Self {
startup_lock: Arc::new(StartupLock::new()),
}
}
}
impl Default for NetworkManagerStartupContext {
fn default() -> Self {
Self::new()
}
}
// The mutable state of the network manager // The mutable state of the network manager
#[derive(Debug)]
struct NetworkManagerInner { struct NetworkManagerInner {
stats: NetworkManagerStats, stats: NetworkManagerStats,
client_allowlist: LruCache<TypedKey, ClientAllowlistEntry>, client_allowlist: LruCache<TypedKey, ClientAllowlistEntry>,
node_contact_method_cache: LruCache<NodeContactMethodCacheKey, NodeContactMethod>, node_contact_method_cache: LruCache<NodeContactMethodCacheKey, NodeContactMethod>,
address_check: Option<AddressCheck>, address_check: Option<AddressCheck>,
peer_info_change_subscription: Option<EventBusSubscription>,
socket_address_change_subscription: Option<EventBusSubscription>,
} }
struct NetworkManagerUnlockedInner { pub(crate) struct NetworkManager {
// Handles registry: VeilidComponentRegistry,
event_bus: EventBus, inner: Mutex<NetworkManagerInner>,
config: VeilidConfig,
storage_manager: StorageManager, // Address filter
table_store: TableStore, address_filter: AddressFilter,
#[cfg(feature = "unstable-blockstore")]
block_store: BlockStore,
crypto: Crypto,
// Accessors // Accessors
routing_table: RwLock<Option<RoutingTable>>,
address_filter: RwLock<Option<AddressFilter>>,
components: RwLock<Option<NetworkComponents>>, components: RwLock<Option<NetworkComponents>>,
update_callback: RwLock<Option<UpdateCallback>>,
// Background processes // Background processes
rolling_transfers_task: TickTask<EyreReport>, rolling_transfers_task: TickTask<EyreReport>,
address_filter_task: TickTask<EyreReport>, address_filter_task: TickTask<EyreReport>,
// Network Key
// Network key
network_key: Option<SharedSecret>, network_key: Option<SharedSecret>,
// Startup Lock
startup_lock: StartupLock, // Startup context
startup_context: NetworkManagerStartupContext,
} }
#[derive(Clone)] impl_veilid_component!(NetworkManager);
pub(crate) struct NetworkManager {
inner: Arc<Mutex<NetworkManagerInner>>, impl fmt::Debug for NetworkManager {
unlocked_inner: Arc<NetworkManagerUnlockedInner>, fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NetworkManager")
//.field("registry", &self.registry)
.field("inner", &self.inner)
.field("address_filter", &self.address_filter)
// .field("components", &self.components)
// .field("rolling_transfers_task", &self.rolling_transfers_task)
// .field("address_filter_task", &self.address_filter_task)
.field("network_key", &self.network_key)
.field("startup_context", &self.startup_context)
.finish()
}
} }
impl NetworkManager { impl NetworkManager {
@ -167,52 +194,20 @@ impl NetworkManager {
client_allowlist: LruCache::new_unbounded(), client_allowlist: LruCache::new_unbounded(),
node_contact_method_cache: LruCache::new(NODE_CONTACT_METHOD_CACHE_SIZE), node_contact_method_cache: LruCache::new(NODE_CONTACT_METHOD_CACHE_SIZE),
address_check: None, address_check: None,
} peer_info_change_subscription: None,
} socket_address_change_subscription: None,
fn new_unlocked_inner(
event_bus: EventBus,
config: VeilidConfig,
storage_manager: StorageManager,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
network_key: Option<SharedSecret>,
) -> NetworkManagerUnlockedInner {
NetworkManagerUnlockedInner {
event_bus,
config: config.clone(),
storage_manager,
table_store,
#[cfg(feature = "unstable-blockstore")]
block_store,
crypto,
address_filter: RwLock::new(None),
routing_table: RwLock::new(None),
components: RwLock::new(None),
update_callback: RwLock::new(None),
rolling_transfers_task: TickTask::new(
"rolling_transfers_task",
ROLLING_TRANSFERS_INTERVAL_SECS,
),
address_filter_task: TickTask::new(
"address_filter_task",
ADDRESS_FILTER_TASK_INTERVAL_SECS,
),
network_key,
startup_lock: StartupLock::new(),
} }
} }
pub fn new( pub fn new(
event_bus: EventBus, registry: VeilidComponentRegistry,
config: VeilidConfig, startup_context: NetworkManagerStartupContext,
storage_manager: StorageManager,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
) -> Self { ) -> Self {
// Make the network key // Make the network key
let network_key = { let network_key = {
let config = registry.config();
let crypto = registry.crypto();
let c = config.get(); let c = config.get();
let network_key_password = c.network.network_key_password.clone(); let network_key_password = c.network.network_key_password.clone();
let network_key = if let Some(network_key_password) = network_key_password { let network_key = if let Some(network_key_password) = network_key_password {
@ -238,110 +233,52 @@ impl NetworkManager {
network_key network_key
}; };
let inner = Self::new_inner();
let address_filter = AddressFilter::new(registry.clone());
let this = Self { let this = Self {
inner: Arc::new(Mutex::new(Self::new_inner())), registry,
unlocked_inner: Arc::new(Self::new_unlocked_inner( inner: Mutex::new(inner),
event_bus, address_filter,
config, components: RwLock::new(None),
storage_manager, rolling_transfers_task: TickTask::new(
table_store, "rolling_transfers_task",
#[cfg(feature = "unstable-blockstore")] ROLLING_TRANSFERS_INTERVAL_SECS,
block_store, ),
crypto, address_filter_task: TickTask::new(
network_key, "address_filter_task",
)), ADDRESS_FILTER_TASK_INTERVAL_SECS,
),
network_key,
startup_context,
}; };
this.setup_tasks(); this.setup_tasks();
this this
} }
pub fn event_bus(&self) -> EventBus {
self.unlocked_inner.event_bus.clone() pub fn address_filter(&self) -> &AddressFilter {
} &self.address_filter
pub fn config(&self) -> VeilidConfig {
self.unlocked_inner.config.clone()
}
pub fn with_config<F, R>(&self, f: F) -> R
where
F: FnOnce(&VeilidConfigInner) -> R,
{
f(&self.unlocked_inner.config.get())
}
pub fn storage_manager(&self) -> StorageManager {
self.unlocked_inner.storage_manager.clone()
}
pub fn table_store(&self) -> TableStore {
self.unlocked_inner.table_store.clone()
}
#[cfg(feature = "unstable-blockstore")]
pub fn block_store(&self) -> BlockStore {
self.unlocked_inner.block_store.clone()
}
pub fn crypto(&self) -> Crypto {
self.unlocked_inner.crypto.clone()
}
pub fn address_filter(&self) -> AddressFilter {
self.unlocked_inner
.address_filter
.read()
.as_ref()
.unwrap()
.clone()
}
pub fn routing_table(&self) -> RoutingTable {
self.unlocked_inner
.routing_table
.read()
.as_ref()
.unwrap()
.clone()
} }
fn net(&self) -> Network { fn net(&self) -> Network {
self.unlocked_inner self.components.read().as_ref().unwrap().net.clone()
.components
.read()
.as_ref()
.unwrap()
.net
.clone()
} }
fn opt_net(&self) -> Option<Network> { fn opt_net(&self) -> Option<Network> {
self.unlocked_inner self.components.read().as_ref().map(|x| x.net.clone())
.components
.read()
.as_ref()
.map(|x| x.net.clone())
} }
fn receipt_manager(&self) -> ReceiptManager { fn receipt_manager(&self) -> ReceiptManager {
self.unlocked_inner self.components
.components
.read() .read()
.as_ref() .as_ref()
.unwrap() .unwrap()
.receipt_manager .receipt_manager
.clone() .clone()
} }
pub fn rpc_processor(&self) -> RPCProcessor {
self.unlocked_inner
.components
.read()
.as_ref()
.unwrap()
.rpc_processor
.clone()
}
pub fn opt_rpc_processor(&self) -> Option<RPCProcessor> {
self.unlocked_inner
.components
.read()
.as_ref()
.map(|x| x.rpc_processor.clone())
}
pub fn connection_manager(&self) -> ConnectionManager { pub fn connection_manager(&self) -> ConnectionManager {
self.unlocked_inner self.components
.components
.read() .read()
.as_ref() .as_ref()
.unwrap() .unwrap()
@ -349,103 +286,48 @@ impl NetworkManager {
.clone() .clone()
} }
pub fn opt_connection_manager(&self) -> Option<ConnectionManager> { pub fn opt_connection_manager(&self) -> Option<ConnectionManager> {
self.unlocked_inner self.components
.components
.read() .read()
.as_ref() .as_ref()
.map(|x| x.connection_manager.clone()) .map(|x| x.connection_manager.clone())
} }
pub fn update_callback(&self) -> UpdateCallback {
self.unlocked_inner
.update_callback
.read()
.as_ref()
.unwrap()
.clone()
}
#[instrument(level = "debug", skip_all, err)] #[instrument(level = "debug", skip_all, err)]
pub async fn init(&self, update_callback: UpdateCallback) -> EyreResult<()> { async fn init_async(&self) -> EyreResult<()> {
let routing_table = RoutingTable::new(self.clone());
routing_table.init().await?;
let address_filter = AddressFilter::new(self.config(), routing_table.clone());
*self.unlocked_inner.routing_table.write() = Some(routing_table.clone());
*self.unlocked_inner.address_filter.write() = Some(address_filter);
*self.unlocked_inner.update_callback.write() = Some(update_callback);
// Register event handlers
let this = self.clone();
self.event_bus().subscribe(move |evt| {
let this = this.clone();
Box::pin(async move {
this.peer_info_change_event_handler(evt);
})
});
let this = self.clone();
self.event_bus().subscribe(move |evt| {
let this = this.clone();
Box::pin(async move {
this.socket_address_change_event_handler(evt);
})
});
Ok(()) Ok(())
} }
#[instrument(level = "debug", skip_all)] async fn post_init_async(&self) -> EyreResult<()> {
pub async fn terminate(&self) { Ok(())
let routing_table = self.unlocked_inner.routing_table.write().take();
if let Some(routing_table) = routing_table {
routing_table.terminate().await;
}
*self.unlocked_inner.update_callback.write() = None;
} }
async fn pre_terminate_async(&self) {}
#[instrument(level = "debug", skip_all)]
async fn terminate_async(&self) {}
#[instrument(level = "debug", skip_all, err)] #[instrument(level = "debug", skip_all, err)]
pub async fn internal_startup(&self) -> EyreResult<StartupDisposition> { pub async fn internal_startup(&self) -> EyreResult<StartupDisposition> {
if self.unlocked_inner.components.read().is_some() { if self.components.read().is_some() {
log_net!(debug "NetworkManager::internal_startup already started"); log_net!(debug "NetworkManager::internal_startup already started");
return Ok(StartupDisposition::Success); return Ok(StartupDisposition::Success);
} }
// Clean address filter for things that should not be persistent // Clean address filter for things that should not be persistent
self.address_filter().restart(); self.address_filter.restart();
// Create network components // Create network components
let connection_manager = ConnectionManager::new(self.clone()); let connection_manager = ConnectionManager::new(self.registry());
let net = Network::new( let net = Network::new(self.registry());
self.clone(),
self.routing_table(),
connection_manager.clone(),
);
let rpc_processor = RPCProcessor::new(
self.clone(),
self.unlocked_inner
.update_callback
.read()
.as_ref()
.unwrap()
.clone(),
);
let receipt_manager = ReceiptManager::new(); let receipt_manager = ReceiptManager::new();
*self.unlocked_inner.components.write() = Some(NetworkComponents {
*self.components.write() = Some(NetworkComponents {
net: net.clone(), net: net.clone(),
connection_manager: connection_manager.clone(), connection_manager: connection_manager.clone(),
rpc_processor: rpc_processor.clone(),
receipt_manager: receipt_manager.clone(), receipt_manager: receipt_manager.clone(),
}); });
// Start network components let (detect_address_changes, ip6_prefix_size) = self.config().with(|c| {
connection_manager.startup().await?;
match net.startup().await? {
StartupDisposition::Success => {}
StartupDisposition::BindRetry => {
return Ok(StartupDisposition::BindRetry);
}
}
let (detect_address_changes, ip6_prefix_size) = self.with_config(|c| {
( (
c.network.detect_address_changes, c.network.detect_address_changes,
c.network.max_connections_per_ip6_prefix_size as usize, c.network.max_connections_per_ip6_prefix_size as usize,
@ -456,9 +338,30 @@ impl NetworkManager {
ip6_prefix_size, ip6_prefix_size,
}; };
let address_check = AddressCheck::new(address_check_config, net.clone()); let address_check = AddressCheck::new(address_check_config, net.clone());
self.inner.lock().address_check = Some(address_check);
rpc_processor.startup().await?; // Register event handlers
let peer_info_change_subscription =
impl_subscribe_event_bus!(self, Self, peer_info_change_event_handler);
let socket_address_change_subscription =
impl_subscribe_event_bus!(self, Self, socket_address_change_event_handler);
{
let mut inner = self.inner.lock();
inner.address_check = Some(address_check);
inner.peer_info_change_subscription = Some(peer_info_change_subscription);
inner.socket_address_change_subscription = Some(socket_address_change_subscription);
}
// Start network components
connection_manager.startup().await?;
match net.startup().await? {
StartupDisposition::Success => {}
StartupDisposition::BindRetry => {
return Ok(StartupDisposition::BindRetry);
}
}
receipt_manager.startup().await?; receipt_manager.startup().await?;
log_net!("NetworkManager::internal_startup end"); log_net!("NetworkManager::internal_startup end");
@ -468,15 +371,11 @@ impl NetworkManager {
#[instrument(level = "debug", skip_all, err)] #[instrument(level = "debug", skip_all, err)]
pub async fn startup(&self) -> EyreResult<StartupDisposition> { pub async fn startup(&self) -> EyreResult<StartupDisposition> {
let guard = self.unlocked_inner.startup_lock.startup()?; let guard = self.startup_context.startup_lock.startup()?;
match self.internal_startup().await { match self.internal_startup().await {
Ok(StartupDisposition::Success) => { Ok(StartupDisposition::Success) => {
guard.success(); guard.success();
// Inform api clients that things have changed
self.send_network_update();
Ok(StartupDisposition::Success) Ok(StartupDisposition::Success)
} }
Ok(StartupDisposition::BindRetry) => { Ok(StartupDisposition::BindRetry) => {
@ -492,25 +391,30 @@ impl NetworkManager {
#[instrument(level = "debug", skip_all)] #[instrument(level = "debug", skip_all)]
async fn shutdown_internal(&self) { async fn shutdown_internal(&self) {
// Cancel all tasks // Shutdown event bus subscriptions and address check
self.cancel_tasks().await; {
let mut inner = self.inner.lock();
// Shutdown address check if let Some(sub) = inner.socket_address_change_subscription.take() {
self.inner.lock().address_check = Option::<AddressCheck>::None; self.event_bus().unsubscribe(sub);
}
if let Some(sub) = inner.peer_info_change_subscription.take() {
self.event_bus().unsubscribe(sub);
}
inner.address_check = None;
}
// Shutdown network components if they started up // Shutdown network components if they started up
log_net!(debug "shutting down network components"); log_net!(debug "shutting down network components");
{ {
let components = self.unlocked_inner.components.read().clone(); let components = self.components.read().clone();
if let Some(components) = components { if let Some(components) = components {
components.net.shutdown().await; components.net.shutdown().await;
components.rpc_processor.shutdown().await;
components.receipt_manager.shutdown().await; components.receipt_manager.shutdown().await;
components.connection_manager.shutdown().await; components.connection_manager.shutdown().await;
} }
} }
*self.unlocked_inner.components.write() = None; *self.components.write() = None;
// reset the state // reset the state
log_net!(debug "resetting network manager state"); log_net!(debug "resetting network manager state");
@ -521,21 +425,22 @@ impl NetworkManager {
#[instrument(level = "debug", skip_all)] #[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
log_net!(debug "starting network manager shutdown"); // Cancel all tasks
log_net!(debug "stopping network manager tasks");
self.cancel_tasks().await;
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else { // Proceed with shutdown
log_net!(debug "network manager is already shut down"); log_net!(debug "starting network manager shutdown");
return; let guard = self
}; .startup_context
.startup_lock
.shutdown()
.await
.expect("should be started up");
self.shutdown_internal().await; self.shutdown_internal().await;
guard.success(); guard.success();
// send update
log_net!(debug "sending network state update to api clients");
self.send_network_update();
log_net!(debug "finished network manager shutdown"); log_net!(debug "finished network manager shutdown");
} }
@ -568,7 +473,9 @@ impl NetworkManager {
} }
pub fn purge_client_allowlist(&self) { pub fn purge_client_allowlist(&self) {
let timeout_ms = self.with_config(|c| c.network.client_allowlist_timeout_ms); let timeout_ms = self
.config()
.with(|c| c.network.client_allowlist_timeout_ms);
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
let cutoff_timestamp = let cutoff_timestamp =
Timestamp::now() - TimestampDuration::new((timeout_ms as u64) * 1000u64); Timestamp::now() - TimestampDuration::new((timeout_ms as u64) * 1000u64);
@ -607,14 +514,15 @@ impl NetworkManager {
extra_data: D, extra_data: D,
callback: impl ReceiptCallback, callback: impl ReceiptCallback,
) -> EyreResult<Vec<u8>> { ) -> EyreResult<Vec<u8>> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
bail!("network is not started"); bail!("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let routing_table = self.routing_table(); let routing_table = self.routing_table();
let crypto = self.crypto();
// Generate receipt and serialized form to return // Generate receipt and serialized form to return
let vcrypto = self.crypto().best(); let vcrypto = crypto.best();
let nonce = vcrypto.random_nonce(); let nonce = vcrypto.random_nonce();
let node_id = routing_table.node_id(vcrypto.kind()); let node_id = routing_table.node_id(vcrypto.kind());
@ -628,7 +536,7 @@ impl NetworkManager {
extra_data, extra_data,
)?; )?;
let out = receipt let out = receipt
.to_signed_data(self.crypto(), &node_id_secret) .to_signed_data(&crypto, &node_id_secret)
.wrap_err("failed to generate signed receipt")?; .wrap_err("failed to generate signed receipt")?;
// Record the receipt for later // Record the receipt for later
@ -645,15 +553,16 @@ impl NetworkManager {
expiration_us: TimestampDuration, expiration_us: TimestampDuration,
extra_data: D, extra_data: D,
) -> EyreResult<(Vec<u8>, EventualValueFuture<ReceiptEvent>)> { ) -> EyreResult<(Vec<u8>, EventualValueFuture<ReceiptEvent>)> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
bail!("network is not started"); bail!("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let routing_table = self.routing_table(); let routing_table = self.routing_table();
let crypto = self.crypto();
// Generate receipt and serialized form to return // Generate receipt and serialized form to return
let vcrypto = self.crypto().best(); let vcrypto = crypto.best();
let nonce = vcrypto.random_nonce(); let nonce = vcrypto.random_nonce();
let node_id = routing_table.node_id(vcrypto.kind()); let node_id = routing_table.node_id(vcrypto.kind());
@ -667,7 +576,7 @@ impl NetworkManager {
extra_data, extra_data,
)?; )?;
let out = receipt let out = receipt
.to_signed_data(self.crypto(), &node_id_secret) .to_signed_data(&crypto, &node_id_secret)
.wrap_err("failed to generate signed receipt")?; .wrap_err("failed to generate signed receipt")?;
// Record the receipt for later // Record the receipt for later
@ -685,13 +594,14 @@ impl NetworkManager {
&self, &self,
receipt_data: R, receipt_data: R,
) -> NetworkResult<()> { ) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started"); return NetworkResult::service_unavailable("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) { let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => { Err(e) => {
return NetworkResult::invalid_message(e.to_string()); return NetworkResult::invalid_message(e.to_string());
} }
@ -710,13 +620,14 @@ impl NetworkManager {
receipt_data: R, receipt_data: R,
inbound_noderef: FilteredNodeRef, inbound_noderef: FilteredNodeRef,
) -> NetworkResult<()> { ) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started"); return NetworkResult::service_unavailable("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) { let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => { Err(e) => {
return NetworkResult::invalid_message(e.to_string()); return NetworkResult::invalid_message(e.to_string());
} }
@ -734,13 +645,14 @@ impl NetworkManager {
&self, &self,
receipt_data: R, receipt_data: R,
) -> NetworkResult<()> { ) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started"); return NetworkResult::service_unavailable("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) { let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => { Err(e) => {
return NetworkResult::invalid_message(e.to_string()); return NetworkResult::invalid_message(e.to_string());
} }
@ -759,13 +671,14 @@ impl NetworkManager {
receipt_data: R, receipt_data: R,
private_route: PublicKey, private_route: PublicKey,
) -> NetworkResult<()> { ) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started"); return NetworkResult::service_unavailable("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) { let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => { Err(e) => {
return NetworkResult::invalid_message(e.to_string()); return NetworkResult::invalid_message(e.to_string());
} }
@ -784,7 +697,7 @@ impl NetworkManager {
signal_flow: Flow, signal_flow: Flow,
signal_info: SignalInfo, signal_info: SignalInfo,
) -> EyreResult<NetworkResult<()>> { ) -> EyreResult<NetworkResult<()>> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return Ok(NetworkResult::service_unavailable("network is not started")); return Ok(NetworkResult::service_unavailable("network is not started"));
}; };
@ -884,7 +797,8 @@ impl NetworkManager {
) -> EyreResult<Vec<u8>> { ) -> EyreResult<Vec<u8>> {
// DH to get encryption key // DH to get encryption key
let routing_table = self.routing_table(); let routing_table = self.routing_table();
let Some(vcrypto) = self.crypto().get(dest_node_id.kind) else { let crypto = self.crypto();
let Some(vcrypto) = crypto.get(dest_node_id.kind) else {
bail!("should not have a destination with incompatible crypto here"); bail!("should not have a destination with incompatible crypto here");
}; };
@ -905,12 +819,7 @@ impl NetworkManager {
dest_node_id.value, dest_node_id.value,
); );
envelope envelope
.to_encrypted_data( .to_encrypted_data(&crypto, body.as_ref(), &node_id_secret, &self.network_key)
self.crypto(),
body.as_ref(),
&node_id_secret,
&self.unlocked_inner.network_key,
)
.wrap_err("envelope failed to encode") .wrap_err("envelope failed to encode")
} }
@ -925,7 +834,7 @@ impl NetworkManager {
destination_node_ref: Option<NodeRef>, destination_node_ref: Option<NodeRef>,
body: B, body: B,
) -> EyreResult<NetworkResult<SendDataMethod>> { ) -> EyreResult<NetworkResult<SendDataMethod>> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return Ok(NetworkResult::no_connection_other("network is not started")); return Ok(NetworkResult::no_connection_other("network is not started"));
}; };
@ -966,7 +875,7 @@ impl NetworkManager {
dial_info: DialInfo, dial_info: DialInfo,
rcpt_data: Vec<u8>, rcpt_data: Vec<u8>,
) -> EyreResult<()> { ) -> EyreResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
log_net!(debug "not sending out-of-band receipt to {} because network is stopped", dial_info); log_net!(debug "not sending out-of-band receipt to {} because network is stopped", dial_info);
return Ok(()); return Ok(());
}; };
@ -993,7 +902,7 @@ impl NetworkManager {
// and passes it to the RPC handler // and passes it to the RPC handler
#[instrument(level = "trace", target = "net", skip_all)] #[instrument(level = "trace", target = "net", skip_all)]
async fn on_recv_envelope(&self, data: &mut [u8], flow: Flow) -> EyreResult<bool> { async fn on_recv_envelope(&self, data: &mut [u8], flow: Flow) -> EyreResult<bool> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return Ok(false); return Ok(false);
}; };
@ -1043,21 +952,20 @@ impl NetworkManager {
} }
// Decode envelope header (may fail signature validation) // Decode envelope header (may fail signature validation)
let envelope = let crypto = self.crypto();
match Envelope::from_signed_data(self.crypto(), data, &self.unlocked_inner.network_key) let envelope = match Envelope::from_signed_data(&crypto, data, &self.network_key) {
{ Ok(v) => v,
Ok(v) => v, Err(e) => {
Err(e) => { log_net!(debug "envelope failed to decode: {}", e);
log_net!(debug "envelope failed to decode: {}", e); // safe to punish here because relays also check here to ensure they arent forwarding things that don't decode
// safe to punish here because relays also check here to ensure they arent forwarding things that don't decode self.address_filter()
self.address_filter() .punish_ip_addr(remote_addr, PunishmentReason::FailedToDecodeEnvelope);
.punish_ip_addr(remote_addr, PunishmentReason::FailedToDecodeEnvelope); return Ok(false);
return Ok(false); }
} };
};
// Get timestamp range // Get timestamp range
let (tsbehind, tsahead) = self.with_config(|c| { let (tsbehind, tsahead) = self.config().with(|c| {
( (
c.network c.network
.rpc .rpc
@ -1136,7 +1044,10 @@ impl NetworkManager {
// which only performs a lightweight lookup before passing the packet back out // which only performs a lightweight lookup before passing the packet back out
// If our node has the relay capability disabled, we should not be asked to relay // If our node has the relay capability disabled, we should not be asked to relay
if self.with_config(|c| c.capabilities.disable.contains(&CAP_RELAY)) { if self
.config()
.with(|c| c.capabilities.disable.contains(&CAP_RELAY))
{
log_net!(debug "node has relay capability disabled, dropping relayed envelope from {} to {}", sender_id, recipient_id); log_net!(debug "node has relay capability disabled, dropping relayed envelope from {} to {}", sender_id, recipient_id);
return Ok(false); return Ok(false);
} }
@ -1191,12 +1102,8 @@ impl NetworkManager {
let node_id_secret = routing_table.node_id_secret_key(envelope.get_crypto_kind()); let node_id_secret = routing_table.node_id_secret_key(envelope.get_crypto_kind());
// Decrypt the envelope body // Decrypt the envelope body
let body = match envelope.decrypt_body( let crypto = self.crypto();
self.crypto(), let body = match envelope.decrypt_body(&crypto, data, &node_id_secret, &self.network_key) {
data,
&node_id_secret,
&self.unlocked_inner.network_key,
) {
Ok(v) => v, Ok(v) => v,
Err(e) => { Err(e) => {
log_net!(debug "failed to decrypt envelope body: {}", e); log_net!(debug "failed to decrypt envelope body: {}", e);

View File

@ -2,6 +2,7 @@
/// Also performs UPNP/IGD mapping if enabled and possible /// Also performs UPNP/IGD mapping if enabled and possible
use super::*; use super::*;
use futures_util::stream::FuturesUnordered; use futures_util::stream::FuturesUnordered;
use igd_manager::{IGDAddressType, IGDProtocolType};
const PORT_MAP_VALIDATE_TRY_COUNT: usize = 3; const PORT_MAP_VALIDATE_TRY_COUNT: usize = 3;
const PORT_MAP_VALIDATE_DELAY_MS: u32 = 500; const PORT_MAP_VALIDATE_DELAY_MS: u32 = 500;
@ -42,9 +43,7 @@ struct DiscoveryContextInner {
external_info: Vec<ExternalInfo>, external_info: Vec<ExternalInfo>,
} }
struct DiscoveryContextUnlockedInner { pub(super) struct DiscoveryContextUnlockedInner {
routing_table: RoutingTable,
net: Network,
config: DiscoveryContextConfig, config: DiscoveryContextConfig,
// per-protocol // per-protocol
@ -53,25 +52,30 @@ struct DiscoveryContextUnlockedInner {
#[derive(Clone)] #[derive(Clone)]
pub(super) struct DiscoveryContext { pub(super) struct DiscoveryContext {
registry: VeilidComponentRegistry,
unlocked_inner: Arc<DiscoveryContextUnlockedInner>, unlocked_inner: Arc<DiscoveryContextUnlockedInner>,
inner: Arc<Mutex<DiscoveryContextInner>>, inner: Arc<Mutex<DiscoveryContextInner>>,
} }
impl_veilid_component_registry_accessor!(DiscoveryContext);
impl core::ops::Deref for DiscoveryContext {
type Target = DiscoveryContextUnlockedInner;
fn deref(&self) -> &Self::Target {
&self.unlocked_inner
}
}
impl DiscoveryContext { impl DiscoveryContext {
pub fn new(routing_table: RoutingTable, net: Network, config: DiscoveryContextConfig) -> Self { pub fn new(registry: VeilidComponentRegistry, config: DiscoveryContextConfig) -> Self {
let intf_addrs = Self::get_local_addresses( let routing_table = registry.routing_table();
routing_table.clone(), let intf_addrs =
config.protocol_type, Self::get_local_addresses(&routing_table, config.protocol_type, config.address_type);
config.address_type,
);
Self { Self {
unlocked_inner: Arc::new(DiscoveryContextUnlockedInner { registry,
routing_table, unlocked_inner: Arc::new(DiscoveryContextUnlockedInner { config, intf_addrs }),
net,
config,
intf_addrs,
}),
inner: Arc::new(Mutex::new(DiscoveryContextInner { inner: Arc::new(Mutex::new(DiscoveryContextInner {
external_info: Vec::new(), external_info: Vec::new(),
})), })),
@ -84,7 +88,7 @@ impl DiscoveryContext {
// This pulls the already-detected local interface dial info from the routing table // This pulls the already-detected local interface dial info from the routing table
#[instrument(level = "trace", skip(routing_table), ret)] #[instrument(level = "trace", skip(routing_table), ret)]
fn get_local_addresses( fn get_local_addresses(
routing_table: RoutingTable, routing_table: &RoutingTable,
protocol_type: ProtocolType, protocol_type: ProtocolType,
address_type: AddressType, address_type: AddressType,
) -> Vec<SocketAddress> { ) -> Vec<SocketAddress> {
@ -108,7 +112,7 @@ impl DiscoveryContext {
// This is done over the normal port using RPC // This is done over the normal port using RPC
#[instrument(level = "trace", skip(self), ret)] #[instrument(level = "trace", skip(self), ret)]
async fn request_public_address(&self, node_ref: FilteredNodeRef) -> Option<SocketAddress> { async fn request_public_address(&self, node_ref: FilteredNodeRef) -> Option<SocketAddress> {
let rpc = self.unlocked_inner.routing_table.rpc_processor(); let rpc = self.rpc_processor();
let res = network_result_value_or_log!(match rpc.rpc_call_status(Destination::direct(node_ref.clone())).await { let res = network_result_value_or_log!(match rpc.rpc_call_status(Destination::direct(node_ref.clone())).await {
Ok(v) => v, Ok(v) => v,
@ -136,16 +140,14 @@ impl DiscoveryContext {
// This is done over the normal port using RPC // This is done over the normal port using RPC
#[instrument(level = "trace", skip(self), ret)] #[instrument(level = "trace", skip(self), ret)]
async fn discover_external_addresses(&self) -> bool { async fn discover_external_addresses(&self) -> bool {
let node_count = { let node_count = self
let config = self.unlocked_inner.routing_table.network_manager().config(); .config()
let c = config.get(); .with(|c| c.network.dht.max_find_node_count as usize);
c.network.dht.max_find_node_count as usize
};
let routing_domain = RoutingDomain::PublicInternet; let routing_domain = RoutingDomain::PublicInternet;
let protocol_type = self.unlocked_inner.config.protocol_type; let protocol_type = self.config.protocol_type;
let address_type = self.unlocked_inner.config.address_type; let address_type = self.config.address_type;
let port = self.unlocked_inner.config.port; let port = self.config.port;
// Build an filter that matches our protocol and address type // Build an filter that matches our protocol and address type
// and excludes relayed nodes so we can get an accurate external address // and excludes relayed nodes so we can get an accurate external address
@ -187,10 +189,11 @@ impl DiscoveryContext {
]); ]);
// Find public nodes matching this filter // Find public nodes matching this filter
let nodes = self let nodes = self.routing_table().find_fast_non_local_nodes_filtered(
.unlocked_inner routing_domain,
.routing_table node_count,
.find_fast_non_local_nodes_filtered(routing_domain, node_count, filters); filters,
);
if nodes.is_empty() { if nodes.is_empty() {
log_net!(debug log_net!(debug
"no external address detection peers of type {:?}:{:?}", "no external address detection peers of type {:?}:{:?}",
@ -212,8 +215,8 @@ impl DiscoveryContext {
async move { async move {
if let Some(address) = this.request_public_address(node.clone()).await { if let Some(address) = this.request_public_address(node.clone()).await {
let dial_info = this let dial_info = this
.unlocked_inner .network_manager()
.net .net()
.make_dial_info(address, protocol_type); .make_dial_info(address, protocol_type);
return Some(ExternalInfo { return Some(ExternalInfo {
dial_info, dial_info,
@ -297,10 +300,9 @@ impl DiscoveryContext {
dial_info: DialInfo, dial_info: DialInfo,
redirect: bool, redirect: bool,
) -> bool { ) -> bool {
let rpc_processor = self.unlocked_inner.routing_table.rpc_processor();
// ask the node to send us a dial info validation receipt // ask the node to send us a dial info validation receipt
match rpc_processor match self
.rpc_processor()
.rpc_call_validate_dial_info(node_ref.clone(), dial_info, redirect) .rpc_call_validate_dial_info(node_ref.clone(), dial_info, redirect)
.await .await
{ {
@ -314,14 +316,22 @@ impl DiscoveryContext {
#[instrument(level = "trace", skip(self), ret)] #[instrument(level = "trace", skip(self), ret)]
async fn try_upnp_port_mapping(&self) -> Option<DialInfo> { async fn try_upnp_port_mapping(&self) -> Option<DialInfo> {
let protocol_type = self.unlocked_inner.config.protocol_type; let protocol_type = self.config.protocol_type;
let address_type = self.unlocked_inner.config.address_type; let address_type = self.config.address_type;
let local_port = self.unlocked_inner.config.port; let local_port = self.config.port;
let igd_protocol_type = match protocol_type.low_level_protocol_type() {
LowLevelProtocolType::UDP => IGDProtocolType::UDP,
LowLevelProtocolType::TCP => IGDProtocolType::TCP,
};
let igd_address_type = match address_type {
AddressType::IPV6 => IGDAddressType::IPV6,
AddressType::IPV4 => IGDAddressType::IPV4,
};
let low_level_protocol_type = protocol_type.low_level_protocol_type();
let external_1 = self.inner.lock().external_info.first().unwrap().clone(); let external_1 = self.inner.lock().external_info.first().unwrap().clone();
let igd_manager = self.unlocked_inner.net.unlocked_inner.igd_manager.clone(); let igd_manager = self.network_manager().net().igd_manager.clone();
let mut tries = 0; let mut tries = 0;
loop { loop {
tries += 1; tries += 1;
@ -329,15 +339,15 @@ impl DiscoveryContext {
// Attempt a port mapping. If this doesn't succeed, it's not going to // Attempt a port mapping. If this doesn't succeed, it's not going to
let mapped_external_address = igd_manager let mapped_external_address = igd_manager
.map_any_port( .map_any_port(
low_level_protocol_type, igd_protocol_type,
address_type, igd_address_type,
local_port, local_port,
Some(external_1.address.ip_addr()), Some(external_1.address.ip_addr()),
) )
.await?; .await?;
// Make dial info from the port mapping // Make dial info from the port mapping
let external_mapped_dial_info = self.unlocked_inner.net.make_dial_info( let external_mapped_dial_info = self.network_manager().net().make_dial_info(
SocketAddress::from_socket_addr(mapped_external_address), SocketAddress::from_socket_addr(mapped_external_address),
protocol_type, protocol_type,
); );
@ -361,10 +371,7 @@ impl DiscoveryContext {
if validate_tries != PORT_MAP_VALIDATE_TRY_COUNT { if validate_tries != PORT_MAP_VALIDATE_TRY_COUNT {
log_net!(debug "UPNP port mapping succeeded but port {}/{} is still unreachable.\nretrying\n", log_net!(debug "UPNP port mapping succeeded but port {}/{} is still unreachable.\nretrying\n",
local_port, match low_level_protocol_type { local_port, igd_protocol_type);
LowLevelProtocolType::UDP => "udp",
LowLevelProtocolType::TCP => "tcp",
});
sleep(PORT_MAP_VALIDATE_DELAY_MS).await sleep(PORT_MAP_VALIDATE_DELAY_MS).await
} else { } else {
break; break;
@ -374,18 +381,15 @@ impl DiscoveryContext {
// Release the mapping if we're still unreachable // Release the mapping if we're still unreachable
let _ = igd_manager let _ = igd_manager
.unmap_port( .unmap_port(
low_level_protocol_type, igd_protocol_type,
address_type, igd_address_type,
external_1.address.port(), external_1.address.port(),
) )
.await; .await;
if tries == PORT_MAP_TRY_COUNT { if tries == PORT_MAP_TRY_COUNT {
warn!("UPNP port mapping succeeded but port {}/{} is still unreachable.\nYou may need to add a local firewall allowed port on this machine.\n", warn!("UPNP port mapping succeeded but port {}/{} is still unreachable.\nYou may need to add a local firewall allowed port on this machine.\n",
local_port, match low_level_protocol_type { local_port, igd_protocol_type
LowLevelProtocolType::UDP => "udp",
LowLevelProtocolType::TCP => "tcp",
}
); );
break; break;
} }
@ -413,7 +417,7 @@ impl DiscoveryContext {
{ {
// Add public dial info with Direct dialinfo class // Add public dial info with Direct dialinfo class
Some(DetectionResult { Some(DetectionResult {
config: this.unlocked_inner.config, config: this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail { ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: external_1.dial_info.clone(), dial_info: external_1.dial_info.clone(),
class: DialInfoClass::Direct, class: DialInfoClass::Direct,
@ -423,7 +427,7 @@ impl DiscoveryContext {
} else { } else {
// Add public dial info with Blocked dialinfo class // Add public dial info with Blocked dialinfo class
Some(DetectionResult { Some(DetectionResult {
config: this.unlocked_inner.config, config: this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail { ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: external_1.dial_info.clone(), dial_info: external_1.dial_info.clone(),
class: DialInfoClass::Blocked, class: DialInfoClass::Blocked,
@ -445,7 +449,7 @@ impl DiscoveryContext {
let inner = self.inner.lock(); let inner = self.inner.lock();
inner.external_info.clone() inner.external_info.clone()
}; };
let local_port = self.unlocked_inner.config.port; let local_port = self.config.port;
// Get the external dial info histogram for our use here // Get the external dial info histogram for our use here
let mut external_info_addr_port_hist = HashMap::<SocketAddress, usize>::new(); let mut external_info_addr_port_hist = HashMap::<SocketAddress, usize>::new();
@ -501,7 +505,7 @@ impl DiscoveryContext {
let do_symmetric_nat_fut: SendPinBoxFuture<Option<DetectionResult>> = let do_symmetric_nat_fut: SendPinBoxFuture<Option<DetectionResult>> =
Box::pin(async move { Box::pin(async move {
Some(DetectionResult { Some(DetectionResult {
config: this.unlocked_inner.config, config: this.config,
ddi: DetectedDialInfo::SymmetricNAT, ddi: DetectedDialInfo::SymmetricNAT,
external_address_types, external_address_types,
}) })
@ -535,7 +539,7 @@ impl DiscoveryContext {
{ {
// Add public dial info with Direct dialinfo class // Add public dial info with Direct dialinfo class
return Some(DetectionResult { return Some(DetectionResult {
config: c_this.unlocked_inner.config, config: c_this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail { ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: external_1_dial_info_with_local_port, dial_info: external_1_dial_info_with_local_port,
class: DialInfoClass::Direct, class: DialInfoClass::Direct,
@ -558,10 +562,7 @@ impl DiscoveryContext {
/////////// ///////////
let this = self.clone(); let this = self.clone();
let do_nat_detect_fut: SendPinBoxFuture<Option<DetectionResult>> = Box::pin(async move { let do_nat_detect_fut: SendPinBoxFuture<Option<DetectionResult>> = Box::pin(async move {
let mut retry_count = { let mut retry_count = this.config().with(|c| c.network.restricted_nat_retries);
let c = this.unlocked_inner.net.config.get();
c.network.restricted_nat_retries
};
// Loop for restricted NAT retries // Loop for restricted NAT retries
loop { loop {
@ -585,7 +586,7 @@ impl DiscoveryContext {
// Add public dial info with full cone NAT network class // Add public dial info with full cone NAT network class
return Some(DetectionResult { return Some(DetectionResult {
config: c_this.unlocked_inner.config, config: c_this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail { ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: c_external_1.dial_info, dial_info: c_external_1.dial_info,
class: DialInfoClass::FullConeNAT, class: DialInfoClass::FullConeNAT,
@ -620,7 +621,7 @@ impl DiscoveryContext {
{ {
// Got a reply from a non-default port, which means we're only address restricted // Got a reply from a non-default port, which means we're only address restricted
return Some(DetectionResult { return Some(DetectionResult {
config: c_this.unlocked_inner.config, config: c_this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail { ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: c_external_1.dial_info.clone(), dial_info: c_external_1.dial_info.clone(),
class: DialInfoClass::AddressRestrictedNAT, class: DialInfoClass::AddressRestrictedNAT,
@ -632,7 +633,7 @@ impl DiscoveryContext {
} }
// Didn't get a reply from a non-default port, which means we are also port restricted // Didn't get a reply from a non-default port, which means we are also port restricted
Some(DetectionResult { Some(DetectionResult {
config: c_this.unlocked_inner.config, config: c_this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail { ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: c_external_1.dial_info.clone(), dial_info: c_external_1.dial_info.clone(),
class: DialInfoClass::PortRestrictedNAT, class: DialInfoClass::PortRestrictedNAT,
@ -678,10 +679,7 @@ impl DiscoveryContext {
&self, &self,
unord: &mut FuturesUnordered<SendPinBoxFuture<Option<DetectionResult>>>, unord: &mut FuturesUnordered<SendPinBoxFuture<Option<DetectionResult>>>,
) { ) {
let enable_upnp = { let enable_upnp = self.config().with(|c| c.network.upnp);
let c = self.unlocked_inner.net.config.get();
c.network.upnp
};
// Do this right away because it's fast and every detection is going to need it // Do this right away because it's fast and every detection is going to need it
// Get our external addresses from two fast nodes // Get our external addresses from two fast nodes
@ -701,7 +699,7 @@ impl DiscoveryContext {
if let Some(external_mapped_dial_info) = this.try_upnp_port_mapping().await { if let Some(external_mapped_dial_info) = this.try_upnp_port_mapping().await {
// Got a port mapping, let's use it // Got a port mapping, let's use it
return Some(DetectionResult { return Some(DetectionResult {
config: this.unlocked_inner.config, config: this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail { ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: external_mapped_dial_info.clone(), dial_info: external_mapped_dial_info.clone(),
class: DialInfoClass::Mapped, class: DialInfoClass::Mapped,
@ -725,12 +723,7 @@ impl DiscoveryContext {
.lock() .lock()
.external_info .external_info
.iter() .iter()
.find_map(|ei| { .find_map(|ei| self.intf_addrs.contains(&ei.address).then_some(true))
self.unlocked_inner
.intf_addrs
.contains(&ei.address)
.then_some(true)
})
.unwrap_or_default(); .unwrap_or_default();
if local_address_in_external_info { if local_address_in_external_info {

View File

@ -5,13 +5,12 @@ use std::net::UdpSocket;
const UPNP_GATEWAY_DETECT_TIMEOUT_MS: u32 = 5_000; const UPNP_GATEWAY_DETECT_TIMEOUT_MS: u32 = 5_000;
const UPNP_MAPPING_LIFETIME_MS: u32 = 120_000; const UPNP_MAPPING_LIFETIME_MS: u32 = 120_000;
const UPNP_MAPPING_ATTEMPTS: u32 = 3; const UPNP_MAPPING_ATTEMPTS: u32 = 3;
const UPNP_MAPPING_LIFETIME_US: TimestampDuration = const UPNP_MAPPING_LIFETIME_US: u64 = UPNP_MAPPING_LIFETIME_MS as u64 * 1000u64;
TimestampDuration::new(UPNP_MAPPING_LIFETIME_MS as u64 * 1000u64);
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct PortMapKey { struct PortMapKey {
llpt: LowLevelProtocolType, protocol_type: IGDProtocolType,
at: AddressType, address_type: IGDAddressType,
local_port: u16, local_port: u16,
} }
@ -19,36 +18,67 @@ struct PortMapKey {
struct PortMapValue { struct PortMapValue {
ext_ip: IpAddr, ext_ip: IpAddr,
mapped_port: u16, mapped_port: u16,
timestamp: Timestamp, timestamp: u64,
renewal_lifetime: TimestampDuration, renewal_lifetime: u64,
renewal_attempts: u32, renewal_attempts: u32,
} }
struct IGDManagerInner { struct IGDManagerInner {
local_ip_addrs: BTreeMap<AddressType, IpAddr>, local_ip_addrs: BTreeMap<IGDAddressType, IpAddr>,
gateways: BTreeMap<IpAddr, Arc<Gateway>>, gateways: BTreeMap<IpAddr, Arc<Gateway>>,
port_maps: BTreeMap<PortMapKey, PortMapValue>, port_maps: BTreeMap<PortMapKey, PortMapValue>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct IGDManager { pub struct IGDManager {
config: VeilidConfig, program_name: String,
inner: Arc<Mutex<IGDManagerInner>>, inner: Arc<Mutex<IGDManagerInner>>,
} }
fn convert_llpt(llpt: LowLevelProtocolType) -> PortMappingProtocol { fn convert_protocol_type(igdpt: IGDProtocolType) -> PortMappingProtocol {
match llpt { match igdpt {
LowLevelProtocolType::UDP => PortMappingProtocol::UDP, IGDProtocolType::UDP => PortMappingProtocol::UDP,
LowLevelProtocolType::TCP => PortMappingProtocol::TCP, IGDProtocolType::TCP => PortMappingProtocol::TCP,
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum IGDAddressType {
IPV6,
IPV4,
}
impl fmt::Display for IGDAddressType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IGDAddressType::IPV6 => write!(f, "IPV6"),
IGDAddressType::IPV4 => write!(f, "IPV4"),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum IGDProtocolType {
UDP,
TCP,
}
impl fmt::Display for IGDProtocolType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IGDProtocolType::UDP => write!(f, "UDP"),
IGDProtocolType::TCP => write!(f, "TCP"),
}
} }
} }
impl IGDManager { impl IGDManager {
// /////////////////////////////////////////////////////////////////////
// Public Interface
pub fn new(config: VeilidConfig) -> Self { pub fn new(program_name: String) -> Self {
Self { Self {
config, program_name,
inner: Arc::new(Mutex::new(IGDManagerInner { inner: Arc::new(Mutex::new(IGDManagerInner {
local_ip_addrs: BTreeMap::new(), local_ip_addrs: BTreeMap::new(),
gateways: BTreeMap::new(), gateways: BTreeMap::new(),
@ -58,10 +88,306 @@ impl IGDManager {
} }
#[instrument(level = "trace", target = "net", skip_all)] #[instrument(level = "trace", target = "net", skip_all)]
fn get_routed_local_ip_address(address_type: AddressType) -> Option<IpAddr> { pub async fn unmap_port(
&self,
protocol_type: IGDProtocolType,
address_type: IGDAddressType,
mapped_port: u16,
) -> Option<()> {
let this = self.clone();
blocking_wrapper(
"igd unmap_port",
move || {
let mut inner = this.inner.lock();
// If we already have this port mapped, just return the existing portmap
let mut found = None;
for (pmk, pmv) in &inner.port_maps {
if pmk.protocol_type == protocol_type
&& pmk.address_type == address_type
&& pmv.mapped_port == mapped_port
{
found = Some(*pmk);
break;
}
}
let pmk = found?;
let _pmv = inner
.port_maps
.remove(&pmk)
.expect("key found but remove failed");
// Get local ip address
let local_ip = Self::find_local_ip(&mut inner, address_type)?;
// Find gateway
let gw = Self::find_gateway(&mut inner, local_ip)?;
// Unmap port
match gw.remove_port(convert_protocol_type(protocol_type), mapped_port) {
Ok(()) => (),
Err(e) => {
// Failed to map external port
log_net!(debug "upnp failed to remove external port: {}", e);
return None;
}
};
Some(())
},
None,
)
.await
}
#[instrument(level = "trace", target = "net", skip_all)]
pub async fn map_any_port(
&self,
protocol_type: IGDProtocolType,
address_type: IGDAddressType,
local_port: u16,
expected_external_address: Option<IpAddr>,
) -> Option<SocketAddr> {
let this = self.clone();
blocking_wrapper("igd map_any_port", move || {
let mut inner = this.inner.lock();
// If we already have this port mapped, just return the existing portmap
let pmkey = PortMapKey {
protocol_type,
address_type,
local_port,
};
if let Some(pmval) = inner.port_maps.get(&pmkey) {
return Some(SocketAddr::new(pmval.ext_ip, pmval.mapped_port));
}
// Get local ip address
let local_ip = Self::find_local_ip(&mut inner, address_type)?;
// Find gateway
let gw = Self::find_gateway(&mut inner, local_ip)?;
// Get external address
let ext_ip = match gw.get_external_ip() {
Ok(ip) => ip,
Err(e) => {
log_net!(debug "couldn't get external ip from igd: {}", e);
return None;
}
};
// Ensure external IP matches address type
if ext_ip.is_ipv4() && address_type != IGDAddressType::IPV4 {
log_net!(debug "mismatched ip address type from igd, wanted v4, got v6");
return None;
} else if ext_ip.is_ipv6() && address_type != IGDAddressType::IPV6 {
log_net!(debug "mismatched ip address type from igd, wanted v6, got v4");
return None;
}
if let Some(expected_external_address) = expected_external_address {
if ext_ip != expected_external_address {
log_net!(debug "gateway external address does not match calculated external address: expected={} vs gateway={}", expected_external_address, ext_ip);
return None;
}
}
// Map any port
let desc = this.get_description(protocol_type, local_port);
let mapped_port = match gw.add_any_port(convert_protocol_type(protocol_type), SocketAddr::new(local_ip, local_port), (UPNP_MAPPING_LIFETIME_MS + 999) / 1000, &desc) {
Ok(mapped_port) => mapped_port,
Err(e) => {
// Failed to map external port
log_net!(debug "upnp failed to map external port: {}", e);
return None;
}
};
// Add to mapping list to keep alive
let timestamp = get_timestamp();
inner.port_maps.insert(PortMapKey {
protocol_type,
address_type,
local_port,
}, PortMapValue {
ext_ip,
mapped_port,
timestamp,
renewal_lifetime: (UPNP_MAPPING_LIFETIME_MS / 2) as u64 * 1000u64,
renewal_attempts: 0,
});
// Succeeded, return the externally mapped port
Some(SocketAddr::new(ext_ip, mapped_port))
}, None)
.await
}
#[instrument(
level = "trace",
target = "net",
name = "IGDManager::tick",
skip_all,
err
)]
pub async fn tick(&self) -> EyreResult<bool> {
// Refresh mappings if we have them
// If an error is received, then return false to restart the local network
let mut full_renews: Vec<(PortMapKey, PortMapValue)> = Vec::new();
let mut renews: Vec<(PortMapKey, PortMapValue)> = Vec::new();
{
let inner = self.inner.lock();
let now = get_timestamp();
for (k, v) in &inner.port_maps {
let mapping_lifetime = now.saturating_sub(v.timestamp);
if mapping_lifetime >= UPNP_MAPPING_LIFETIME_US
|| v.renewal_attempts >= UPNP_MAPPING_ATTEMPTS
{
// Past expiration time or tried N times, do a full renew and fail out if we can't
full_renews.push((*k, *v));
} else if mapping_lifetime >= v.renewal_lifetime {
// Attempt a normal renewal
renews.push((*k, *v));
}
}
// See if we need to do some blocking operations
if full_renews.is_empty() && renews.is_empty() {
// Just return now since there's nothing to renew
return Ok(true);
}
}
let this = self.clone();
blocking_wrapper(
"igd tick",
move || {
let mut inner = this.inner.lock();
// Process full renewals
for (k, v) in full_renews {
// Get local ip for address type
let local_ip = match Self::get_local_ip(&mut inner, k.address_type) {
Some(ip) => ip,
None => {
return Err(eyre!("local ip missing for address type"));
}
};
// Get gateway for interface
let gw = match Self::get_gateway(&mut inner, local_ip) {
Some(gw) => gw,
None => {
return Err(eyre!("gateway missing for interface"));
}
};
// Delete the mapping if it exists, ignore any errors here
let _ = gw.remove_port(convert_protocol_type(k.protocol_type), v.mapped_port);
inner.port_maps.remove(&k);
let desc = this.get_description(k.protocol_type, k.local_port);
match gw.add_any_port(
convert_protocol_type(k.protocol_type),
SocketAddr::new(local_ip, k.local_port),
(UPNP_MAPPING_LIFETIME_MS + 999) / 1000,
&desc,
) {
Ok(mapped_port) => {
log_net!(debug "full-renewed mapped port {:?} -> {:?}", v, k);
inner.port_maps.insert(
k,
PortMapValue {
ext_ip: v.ext_ip,
mapped_port,
timestamp: get_timestamp(),
renewal_lifetime: (UPNP_MAPPING_LIFETIME_MS / 2) as u64
* 1000u64,
renewal_attempts: 0,
},
);
}
Err(e) => {
info!("failed to full-renew mapped port {:?} -> {:?}: {}", v, k, e);
// Must restart network now :(
return Ok(false);
}
};
}
// Process normal renewals
for (k, mut v) in renews {
// Get local ip for address type
let local_ip = match Self::get_local_ip(&mut inner, k.address_type) {
Some(ip) => ip,
None => {
return Err(eyre!("local ip missing for address type"));
}
};
// Get gateway for interface
let gw = match Self::get_gateway(&mut inner, local_ip) {
Some(gw) => gw,
None => {
return Err(eyre!("gateway missing for address type"));
}
};
let desc = this.get_description(k.protocol_type, k.local_port);
match gw.add_port(
convert_protocol_type(k.protocol_type),
v.mapped_port,
SocketAddr::new(local_ip, k.local_port),
(UPNP_MAPPING_LIFETIME_MS + 999) / 1000,
&desc,
) {
Ok(()) => {
log_net!("renewed mapped port {:?} -> {:?}", v, k);
inner.port_maps.insert(
k,
PortMapValue {
ext_ip: v.ext_ip,
mapped_port: v.mapped_port,
timestamp: get_timestamp(),
renewal_lifetime: (UPNP_MAPPING_LIFETIME_MS / 2) as u64
* 1000u64,
renewal_attempts: 0,
},
);
}
Err(e) => {
log_net!(debug "failed to renew mapped port {:?} -> {:?}: {}", v, k, e);
// Get closer to the maximum renewal timeline by a factor of two each time
v.renewal_lifetime =
(v.renewal_lifetime + UPNP_MAPPING_LIFETIME_US) / 2u64;
v.renewal_attempts += 1;
// Store new value to try again
inner.port_maps.insert(k, v);
}
};
}
// Normal exit, no restart
Ok(true)
},
Err(eyre!("failed to process blocking task")),
)
.instrument(tracing::trace_span!("igd tick fut"))
.await
}
/////////////////////////////////////////////////////////////////////
// Private Implementation
#[instrument(level = "trace", target = "net", skip_all)]
fn get_routed_local_ip_address(address_type: IGDAddressType) -> Option<IpAddr> {
let socket = match UdpSocket::bind(match address_type { let socket = match UdpSocket::bind(match address_type {
AddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), IGDAddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
AddressType::IPV6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), IGDAddressType::IPV6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
}) { }) {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
@ -75,8 +401,8 @@ impl IGDManager {
// using google's dns, but it wont actually send any packets to it // using google's dns, but it wont actually send any packets to it
socket socket
.connect(match address_type { .connect(match address_type {
AddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 80), IGDAddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 80),
AddressType::IPV6 => SocketAddr::new( IGDAddressType::IPV6 => SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)), IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
80, 80,
), ),
@ -91,7 +417,7 @@ impl IGDManager {
} }
#[instrument(level = "trace", target = "net", skip_all)] #[instrument(level = "trace", target = "net", skip_all)]
fn find_local_ip(inner: &mut IGDManagerInner, address_type: AddressType) -> Option<IpAddr> { fn find_local_ip(inner: &mut IGDManagerInner, address_type: IGDAddressType) -> Option<IpAddr> {
if let Some(ip) = inner.local_ip_addrs.get(&address_type) { if let Some(ip) = inner.local_ip_addrs.get(&address_type) {
return Some(*ip); return Some(*ip);
} }
@ -109,7 +435,7 @@ impl IGDManager {
} }
#[instrument(level = "trace", target = "net", skip_all)] #[instrument(level = "trace", target = "net", skip_all)]
fn get_local_ip(inner: &mut IGDManagerInner, address_type: AddressType) -> Option<IpAddr> { fn get_local_ip(inner: &mut IGDManagerInner, address_type: IGDAddressType) -> Option<IpAddr> {
if let Some(ip) = inner.local_ip_addrs.get(&address_type) { if let Some(ip) = inner.local_ip_addrs.get(&address_type) {
return Some(*ip); return Some(*ip);
} }
@ -164,304 +490,10 @@ impl IGDManager {
None None
} }
fn get_description(&self, llpt: LowLevelProtocolType, local_port: u16) -> String { fn get_description(&self, protocol_type: IGDProtocolType, local_port: u16) -> String {
format!( format!(
"{} map {} for port {}", "{} map {} for port {}",
self.config.get().program_name, self.program_name, protocol_type, local_port
convert_llpt(llpt),
local_port
) )
} }
#[instrument(level = "trace", target = "net", skip_all)]
pub async fn unmap_port(
&self,
llpt: LowLevelProtocolType,
at: AddressType,
mapped_port: u16,
) -> Option<()> {
let this = self.clone();
blocking_wrapper(
"igd unmap_port",
move || {
let mut inner = this.inner.lock();
// If we already have this port mapped, just return the existing portmap
let mut found = None;
for (pmk, pmv) in &inner.port_maps {
if pmk.llpt == llpt && pmk.at == at && pmv.mapped_port == mapped_port {
found = Some(*pmk);
break;
}
}
let pmk = found?;
let _pmv = inner
.port_maps
.remove(&pmk)
.expect("key found but remove failed");
// Get local ip address
let local_ip = Self::find_local_ip(&mut inner, at)?;
// Find gateway
let gw = Self::find_gateway(&mut inner, local_ip)?;
// Unmap port
match gw.remove_port(convert_llpt(llpt), mapped_port) {
Ok(()) => (),
Err(e) => {
// Failed to map external port
log_net!(debug "upnp failed to remove external port: {}", e);
return None;
}
};
Some(())
},
None,
)
.await
}
#[instrument(level = "trace", target = "net", skip_all)]
pub async fn map_any_port(
&self,
llpt: LowLevelProtocolType,
at: AddressType,
local_port: u16,
expected_external_address: Option<IpAddr>,
) -> Option<SocketAddr> {
let this = self.clone();
blocking_wrapper("igd map_any_port", move || {
let mut inner = this.inner.lock();
// If we already have this port mapped, just return the existing portmap
let pmkey = PortMapKey {
llpt,
at,
local_port,
};
if let Some(pmval) = inner.port_maps.get(&pmkey) {
return Some(SocketAddr::new(pmval.ext_ip, pmval.mapped_port));
}
// Get local ip address
let local_ip = Self::find_local_ip(&mut inner, at)?;
// Find gateway
let gw = Self::find_gateway(&mut inner, local_ip)?;
// Get external address
let ext_ip = match gw.get_external_ip() {
Ok(ip) => ip,
Err(e) => {
log_net!(debug "couldn't get external ip from igd: {}", e);
return None;
}
};
// Ensure external IP matches address type
if ext_ip.is_ipv4() && at != AddressType::IPV4 {
log_net!(debug "mismatched ip address type from igd, wanted v4, got v6");
return None;
} else if ext_ip.is_ipv6() && at != AddressType::IPV6 {
log_net!(debug "mismatched ip address type from igd, wanted v6, got v4");
return None;
}
if let Some(expected_external_address) = expected_external_address {
if ext_ip != expected_external_address {
log_net!(debug "gateway external address does not match calculated external address: expected={} vs gateway={}", expected_external_address, ext_ip);
return None;
}
}
// Map any port
let desc = this.get_description(llpt, local_port);
let mapped_port = match gw.add_any_port(convert_llpt(llpt), SocketAddr::new(local_ip, local_port), (UPNP_MAPPING_LIFETIME_MS + 999) / 1000, &desc) {
Ok(mapped_port) => mapped_port,
Err(e) => {
// Failed to map external port
log_net!(debug "upnp failed to map external port: {}", e);
return None;
}
};
// Add to mapping list to keep alive
let timestamp = Timestamp::now();
inner.port_maps.insert(PortMapKey {
llpt,
at,
local_port,
}, PortMapValue {
ext_ip,
mapped_port,
timestamp,
renewal_lifetime: ((UPNP_MAPPING_LIFETIME_MS / 2) as u64 * 1000u64).into(),
renewal_attempts: 0,
});
// Succeeded, return the externally mapped port
Some(SocketAddr::new(ext_ip, mapped_port))
}, None)
.await
}
#[instrument(
level = "trace",
target = "net",
name = "IGDManager::tick",
skip_all,
err
)]
pub async fn tick(&self) -> EyreResult<bool> {
// Refresh mappings if we have them
// If an error is received, then return false to restart the local network
let mut full_renews: Vec<(PortMapKey, PortMapValue)> = Vec::new();
let mut renews: Vec<(PortMapKey, PortMapValue)> = Vec::new();
{
let inner = self.inner.lock();
let now = Timestamp::now();
for (k, v) in &inner.port_maps {
let mapping_lifetime = now.saturating_sub(v.timestamp);
if mapping_lifetime >= UPNP_MAPPING_LIFETIME_US
|| v.renewal_attempts >= UPNP_MAPPING_ATTEMPTS
{
// Past expiration time or tried N times, do a full renew and fail out if we can't
full_renews.push((*k, *v));
} else if mapping_lifetime >= v.renewal_lifetime {
// Attempt a normal renewal
renews.push((*k, *v));
}
}
// See if we need to do some blocking operations
if full_renews.is_empty() && renews.is_empty() {
// Just return now since there's nothing to renew
return Ok(true);
}
}
let this = self.clone();
blocking_wrapper(
"igd tick",
move || {
let mut inner = this.inner.lock();
// Process full renewals
for (k, v) in full_renews {
// Get local ip for address type
let local_ip = match Self::get_local_ip(&mut inner, k.at) {
Some(ip) => ip,
None => {
return Err(eyre!("local ip missing for address type"));
}
};
// Get gateway for interface
let gw = match Self::get_gateway(&mut inner, local_ip) {
Some(gw) => gw,
None => {
return Err(eyre!("gateway missing for interface"));
}
};
// Delete the mapping if it exists, ignore any errors here
let _ = gw.remove_port(convert_llpt(k.llpt), v.mapped_port);
inner.port_maps.remove(&k);
let desc = this.get_description(k.llpt, k.local_port);
match gw.add_any_port(
convert_llpt(k.llpt),
SocketAddr::new(local_ip, k.local_port),
(UPNP_MAPPING_LIFETIME_MS + 999) / 1000,
&desc,
) {
Ok(mapped_port) => {
log_net!(debug "full-renewed mapped port {:?} -> {:?}", v, k);
inner.port_maps.insert(
k,
PortMapValue {
ext_ip: v.ext_ip,
mapped_port,
timestamp: Timestamp::now(),
renewal_lifetime: TimestampDuration::new(
(UPNP_MAPPING_LIFETIME_MS / 2) as u64 * 1000u64,
),
renewal_attempts: 0,
},
);
}
Err(e) => {
info!("failed to full-renew mapped port {:?} -> {:?}: {}", v, k, e);
// Must restart network now :(
return Ok(false);
}
};
}
// Process normal renewals
for (k, mut v) in renews {
// Get local ip for address type
let local_ip = match Self::get_local_ip(&mut inner, k.at) {
Some(ip) => ip,
None => {
return Err(eyre!("local ip missing for address type"));
}
};
// Get gateway for interface
let gw = match Self::get_gateway(&mut inner, local_ip) {
Some(gw) => gw,
None => {
return Err(eyre!("gateway missing for address type"));
}
};
let desc = this.get_description(k.llpt, k.local_port);
match gw.add_port(
convert_llpt(k.llpt),
v.mapped_port,
SocketAddr::new(local_ip, k.local_port),
(UPNP_MAPPING_LIFETIME_MS + 999) / 1000,
&desc,
) {
Ok(()) => {
log_net!("renewed mapped port {:?} -> {:?}", v, k);
inner.port_maps.insert(
k,
PortMapValue {
ext_ip: v.ext_ip,
mapped_port: v.mapped_port,
timestamp: Timestamp::now(),
renewal_lifetime: ((UPNP_MAPPING_LIFETIME_MS / 2) as u64
* 1000u64)
.into(),
renewal_attempts: 0,
},
);
}
Err(e) => {
log_net!(debug "failed to renew mapped port {:?} -> {:?}: {}", v, k, e);
// Get closer to the maximum renewal timeline by a factor of two each time
v.renewal_lifetime =
(v.renewal_lifetime + UPNP_MAPPING_LIFETIME_US) / 2u64;
v.renewal_attempts += 1;
// Store new value to try again
inner.port_maps.insert(k, v);
}
};
}
// Normal exit, no restart
Ok(true)
},
Err(eyre!("failed to process blocking task")),
)
.instrument(tracing::trace_span!("igd tick fut"))
.await
}
} }

View File

@ -113,16 +113,13 @@ struct NetworkInner {
network_state: Option<NetworkState>, network_state: Option<NetworkState>,
} }
struct NetworkUnlockedInner { pub(super) struct NetworkUnlockedInner {
// Startup lock // Startup lock
startup_lock: StartupLock, startup_lock: StartupLock,
// Accessors
routing_table: RoutingTable,
network_manager: NetworkManager,
connection_manager: ConnectionManager,
// Network // Network
interfaces: NetworkInterfaces, interfaces: NetworkInterfaces,
// Background processes // Background processes
update_network_class_task: TickTask<EyreReport>, update_network_class_task: TickTask<EyreReport>,
network_interfaces_task: TickTask<EyreReport>, network_interfaces_task: TickTask<EyreReport>,
@ -135,11 +132,21 @@ struct NetworkUnlockedInner {
#[derive(Clone)] #[derive(Clone)]
pub(super) struct Network { pub(super) struct Network {
config: VeilidConfig, registry: VeilidComponentRegistry,
inner: Arc<Mutex<NetworkInner>>, inner: Arc<Mutex<NetworkInner>>,
unlocked_inner: Arc<NetworkUnlockedInner>, unlocked_inner: Arc<NetworkUnlockedInner>,
} }
impl_veilid_component_registry_accessor!(Network);
impl core::ops::Deref for Network {
type Target = NetworkUnlockedInner;
fn deref(&self) -> &Self::Target {
&self.unlocked_inner
}
}
impl Network { impl Network {
fn new_inner() -> NetworkInner { fn new_inner() -> NetworkInner {
NetworkInner { NetworkInner {
@ -161,17 +168,11 @@ impl Network {
} }
} }
fn new_unlocked_inner( fn new_unlocked_inner(registry: VeilidComponentRegistry) -> NetworkUnlockedInner {
network_manager: NetworkManager, let config = registry.config();
routing_table: RoutingTable, let program_name = config.get().program_name.clone();
connection_manager: ConnectionManager,
) -> NetworkUnlockedInner {
let config = network_manager.config();
NetworkUnlockedInner { NetworkUnlockedInner {
startup_lock: StartupLock::new(), startup_lock: StartupLock::new(),
network_manager,
routing_table,
connection_manager,
interfaces: NetworkInterfaces::new(), interfaces: NetworkInterfaces::new(),
update_network_class_task: TickTask::new( update_network_class_task: TickTask::new(
"update_network_class_task", "update_network_class_task",
@ -183,23 +184,15 @@ impl Network {
), ),
upnp_task: TickTask::new("upnp_task", UPNP_TASK_TICK_PERIOD_SECS), upnp_task: TickTask::new("upnp_task", UPNP_TASK_TICK_PERIOD_SECS),
network_task_lock: AsyncMutex::new(()), network_task_lock: AsyncMutex::new(()),
igd_manager: igd_manager::IGDManager::new(config.clone()), igd_manager: igd_manager::IGDManager::new(program_name),
} }
} }
pub fn new( pub fn new(registry: VeilidComponentRegistry) -> Self {
network_manager: NetworkManager,
routing_table: RoutingTable,
connection_manager: ConnectionManager,
) -> Self {
let this = Self { let this = Self {
config: network_manager.config(),
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner( unlocked_inner: Arc::new(Self::new_unlocked_inner(registry.clone())),
network_manager, registry,
routing_table,
connection_manager,
)),
}; };
this.setup_tasks(); this.setup_tasks();
@ -207,18 +200,6 @@ impl Network {
this this
} }
fn network_manager(&self) -> NetworkManager {
self.unlocked_inner.network_manager.clone()
}
fn routing_table(&self) -> RoutingTable {
self.unlocked_inner.routing_table.clone()
}
fn connection_manager(&self) -> ConnectionManager {
self.unlocked_inner.connection_manager.clone()
}
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> { fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
let cvec = certs(&mut BufReader::new(File::open(path)?)) let cvec = certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?; .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?;
@ -248,7 +229,8 @@ impl Network {
} }
fn load_server_config(&self) -> io::Result<ServerConfig> { fn load_server_config(&self) -> io::Result<ServerConfig> {
let c = self.config.get(); let config = self.config();
let c = config.get();
// //
log_net!( log_net!(
"loading certificate from {}", "loading certificate from {}",
@ -288,7 +270,10 @@ impl Network {
if !from.ip().is_unspecified() { if !from.ip().is_unspecified() {
vec![from] vec![from]
} else { } else {
let addrs = self.last_network_state().stable_interface_addresses; let addrs = self
.last_network_state()
.unwrap()
.stable_interface_addresses;
addrs addrs
.iter() .iter()
.filter_map(|a| { .filter_map(|a| {
@ -346,16 +331,15 @@ impl Network {
dial_info: DialInfo, dial_info: DialInfo,
data: Vec<u8>, data: Vec<u8>,
) -> EyreResult<NetworkResult<()>> { ) -> EyreResult<NetworkResult<()>> {
let _guard = self.unlocked_inner.startup_lock.enter()?; let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure( self.record_dial_info_failure(
dial_info.clone(), dial_info.clone(),
async move { async move {
let data_len = data.len(); let data_len = data.len();
let connect_timeout_ms = { let connect_timeout_ms = self
let c = self.config.get(); .config()
c.network.connection_initial_timeout_ms .with(|c| c.network.connection_initial_timeout_ms);
};
if self if self
.network_manager() .network_manager()
@ -368,10 +352,12 @@ impl Network {
match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
let h = let h = RawUdpProtocolHandler::new_unspecified_bound_handler(
RawUdpProtocolHandler::new_unspecified_bound_handler(&peer_socket_addr) self.registry(),
.await &peer_socket_addr,
.wrap_err("create socket failure")?; )
.await
.wrap_err("create socket failure")?;
let _ = network_result_try!(h let _ = network_result_try!(h
.send_message(data, peer_socket_addr) .send_message(data, peer_socket_addr)
.await .await
@ -423,16 +409,15 @@ impl Network {
data: Vec<u8>, data: Vec<u8>,
timeout_ms: u32, timeout_ms: u32,
) -> EyreResult<NetworkResult<Vec<u8>>> { ) -> EyreResult<NetworkResult<Vec<u8>>> {
let _guard = self.unlocked_inner.startup_lock.enter()?; let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure( self.record_dial_info_failure(
dial_info.clone(), dial_info.clone(),
async move { async move {
let data_len = data.len(); let data_len = data.len();
let connect_timeout_ms = { let connect_timeout_ms = self
let c = self.config.get(); .config()
c.network.connection_initial_timeout_ms .with(|c| c.network.connection_initial_timeout_ms);
};
if self if self
.network_manager() .network_manager()
@ -445,10 +430,12 @@ impl Network {
match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
let h = let h = RawUdpProtocolHandler::new_unspecified_bound_handler(
RawUdpProtocolHandler::new_unspecified_bound_handler(&peer_socket_addr) self.registry(),
.await &peer_socket_addr,
.wrap_err("create socket failure")?; )
.await
.wrap_err("create socket failure")?;
network_result_try!(h network_result_try!(h
.send_message(data, peer_socket_addr) .send_message(data, peer_socket_addr)
.await .await
@ -539,7 +526,7 @@ impl Network {
flow: Flow, flow: Flow,
data: Vec<u8>, data: Vec<u8>,
) -> EyreResult<SendDataToExistingFlowResult> { ) -> EyreResult<SendDataToExistingFlowResult> {
let _guard = self.unlocked_inner.startup_lock.enter()?; let _guard = self.startup_lock.enter()?;
let data_len = data.len(); let data_len = data.len();
@ -573,7 +560,11 @@ impl Network {
// Handle connection-oriented protocols // Handle connection-oriented protocols
// Try to send to the exact existing connection if one exists // Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(flow) { if let Some(conn) = self
.network_manager()
.connection_manager()
.get_connection(flow)
{
// connection exists, send over it // connection exists, send over it
match conn.send_async(data).await { match conn.send_async(data).await {
ConnectionHandleSendResult::Sent => { ConnectionHandleSendResult::Sent => {
@ -606,7 +597,7 @@ impl Network {
dial_info: DialInfo, dial_info: DialInfo,
data: Vec<u8>, data: Vec<u8>,
) -> EyreResult<NetworkResult<UniqueFlow>> { ) -> EyreResult<NetworkResult<UniqueFlow>> {
let _guard = self.unlocked_inner.startup_lock.enter()?; let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure( self.record_dial_info_failure(
dial_info.clone(), dial_info.clone(),
@ -635,7 +626,8 @@ impl Network {
} else { } else {
// Handle connection-oriented protocols // Handle connection-oriented protocols
let conn = network_result_try!( let conn = network_result_try!(
self.connection_manager() self.network_manager()
.connection_manager()
.get_or_create_connection(dial_info.clone()) .get_or_create_connection(dial_info.clone())
.await? .await?
); );
@ -678,14 +670,9 @@ impl Network {
} }
// Start editing routing table // Start editing routing table
let mut editor_public_internet = self let routing_table = self.routing_table();
.unlocked_inner let mut editor_public_internet = routing_table.edit_public_internet_routing_domain();
.routing_table let mut editor_local_network = routing_table.edit_local_network_routing_domain();
.edit_public_internet_routing_domain();
let mut editor_local_network = self
.unlocked_inner
.routing_table
.edit_local_network_routing_domain();
// Setup network // Setup network
editor_local_network.set_local_networks(network_state.local_networks); editor_local_network.set_local_networks(network_state.local_networks);
@ -763,8 +750,8 @@ impl Network {
#[instrument(level = "debug", err, skip_all)] #[instrument(level = "debug", err, skip_all)]
pub(super) async fn register_all_dial_info( pub(super) async fn register_all_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
let Some(protocol_config) = ({ let Some(protocol_config) = ({
let inner = self.inner.lock(); let inner = self.inner.lock();
@ -798,7 +785,7 @@ impl Network {
#[instrument(level = "debug", err, skip_all)] #[instrument(level = "debug", err, skip_all)]
pub async fn startup(&self) -> EyreResult<StartupDisposition> { pub async fn startup(&self) -> EyreResult<StartupDisposition> {
let guard = self.unlocked_inner.startup_lock.startup()?; let guard = self.startup_lock.startup()?;
match self.startup_internal().await { match self.startup_internal().await {
Ok(StartupDisposition::Success) => { Ok(StartupDisposition::Success) => {
@ -824,7 +811,7 @@ impl Network {
} }
pub fn is_started(&self) -> bool { pub fn is_started(&self) -> bool {
self.unlocked_inner.startup_lock.is_started() self.startup_lock.is_started()
} }
#[instrument(level = "debug", skip_all)] #[instrument(level = "debug", skip_all)]
@ -836,12 +823,6 @@ impl Network {
async fn shutdown_internal(&self) { async fn shutdown_internal(&self) {
let routing_table = self.routing_table(); let routing_table = self.routing_table();
// Stop all tasks
log_net!(debug "stopping update network class task");
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
error!("update_network_class_task not cancelled: {}", e);
}
let mut unord = FuturesUnordered::new(); let mut unord = FuturesUnordered::new();
{ {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
@ -876,7 +857,7 @@ impl Network {
#[instrument(level = "debug", skip_all)] #[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
log_net!(debug "starting low level network shutdown"); log_net!(debug "starting low level network shutdown");
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else { let Ok(guard) = self.startup_lock.shutdown().await else {
log_net!(debug "low level network is already shut down"); log_net!(debug "low level network is already shut down");
return; return;
}; };
@ -892,7 +873,7 @@ impl Network {
&self, &self,
punishment: Option<Box<dyn FnOnce() + Send + 'static>>, punishment: Option<Box<dyn FnOnce() + Send + 'static>>,
) { ) {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up"); log_net!(debug "ignoring due to not started up");
return; return;
}; };
@ -902,7 +883,7 @@ impl Network {
} }
pub fn needs_public_dial_info_check(&self) -> bool { pub fn needs_public_dial_info_check(&self) -> bool {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up"); log_net!(debug "ignoring due to not started up");
return false; return false;
}; };

View File

@ -28,7 +28,7 @@ pub(super) struct NetworkState {
impl Network { impl Network {
fn make_stable_interface_addresses(&self) -> Vec<IpAddr> { fn make_stable_interface_addresses(&self) -> Vec<IpAddr> {
let addrs = self.unlocked_inner.interfaces.stable_addresses(); let addrs = self.interfaces.stable_addresses();
let mut addrs: Vec<IpAddr> = addrs let mut addrs: Vec<IpAddr> = addrs
.into_iter() .into_iter()
.filter(|addr| { .filter(|addr| {
@ -41,8 +41,8 @@ impl Network {
addrs addrs
} }
pub(super) fn last_network_state(&self) -> NetworkState { pub(super) fn last_network_state(&self) -> Option<NetworkState> {
self.inner.lock().network_state.clone().unwrap() self.inner.lock().network_state.clone()
} }
pub(super) fn is_stable_interface_address(&self, addr: IpAddr) -> bool { pub(super) fn is_stable_interface_address(&self, addr: IpAddr) -> bool {
@ -57,8 +57,7 @@ impl Network {
pub(super) async fn make_network_state(&self) -> EyreResult<NetworkState> { pub(super) async fn make_network_state(&self) -> EyreResult<NetworkState> {
// refresh network interfaces // refresh network interfaces
self.unlocked_inner self.interfaces
.interfaces
.refresh() .refresh()
.await .await
.wrap_err("failed to refresh network interfaces")?; .wrap_err("failed to refresh network interfaces")?;
@ -66,22 +65,20 @@ impl Network {
// build the set of networks we should consider for the 'LocalNetwork' routing domain // build the set of networks we should consider for the 'LocalNetwork' routing domain
let mut local_networks: HashSet<(IpAddr, IpAddr)> = HashSet::new(); let mut local_networks: HashSet<(IpAddr, IpAddr)> = HashSet::new();
self.unlocked_inner self.interfaces.with_interfaces(|interfaces| {
.interfaces for intf in interfaces.values() {
.with_interfaces(|interfaces| { // Skip networks that we should never encounter
for intf in interfaces.values() { if intf.is_loopback() || !intf.is_running() {
// Skip networks that we should never encounter continue;
if intf.is_loopback() || !intf.is_running() {
continue;
}
// Add network to local networks table
for addr in &intf.addrs {
let netmask = addr.if_addr().netmask();
let network_ip = ipaddr_apply_netmask(addr.if_addr().ip(), netmask);
local_networks.insert((network_ip, netmask));
}
} }
}); // Add network to local networks table
for addr in &intf.addrs {
let netmask = addr.if_addr().netmask();
let network_ip = ipaddr_apply_netmask(addr.if_addr().ip(), netmask);
local_networks.insert((network_ip, netmask));
}
}
});
let mut local_networks: Vec<(IpAddr, IpAddr)> = local_networks.into_iter().collect(); let mut local_networks: Vec<(IpAddr, IpAddr)> = local_networks.into_iter().collect();
local_networks.sort(); local_networks.sort();
@ -107,7 +104,8 @@ impl Network {
// Get protocol config // Get protocol config
let protocol_config = { let protocol_config = {
let c = self.config.get(); let config = self.config();
let c = config.get();
let mut inbound = ProtocolTypeSet::new(); let mut inbound = ProtocolTypeSet::new();
if c.network.protocol.udp.enabled { if c.network.protocol.udp.enabled {

View File

@ -1,6 +1,5 @@
use super::*; use super::*;
use async_tls::TlsAcceptor; use async_tls::TlsAcceptor;
use sockets::*;
use stop_token::future::FutureExt; use stop_token::future::FutureExt;
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
@ -122,8 +121,11 @@ impl Network {
} }
}; };
// Check to see if it is punished // Check to see if it is punished
let address_filter = self.network_manager().address_filter(); if self
if address_filter.is_ip_addr_punished(peer_addr.ip()) { .network_manager()
.address_filter()
.is_ip_addr_punished(peer_addr.ip())
{
return; return;
} }
@ -135,39 +137,12 @@ impl Network {
} }
}; };
#[cfg(all(feature = "rt-async-std", unix))] if let Err(e) = set_tcp_stream_linger(&tcp_stream, Some(core::time::Duration::from_secs(0)))
{ {
// async-std does not directly support linger on TcpStream yet
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
if let Err(e) = unsafe {
let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
s.into_raw_fd();
res
} {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
}
#[cfg(all(feature = "rt-async-std", windows))]
{
// async-std does not directly support linger on TcpStream yet
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
if let Err(e) = unsafe {
let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
s.into_raw_socket();
res
} {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
}
#[cfg(not(feature = "rt-async-std"))]
if let Err(e) = tcp_stream.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(debug "Couldn't set TCP linger: {}", e); log_net!(debug "Couldn't set TCP linger: {}", e);
return; return;
} }
if let Err(e) = tcp_stream.set_nodelay(true) { if let Err(e) = tcp_stream.set_nodelay(true) {
log_net!(debug "Couldn't set TCP nodelay: {}", e); log_net!(debug "Couldn't set TCP nodelay: {}", e);
return; return;
@ -249,49 +224,19 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
async fn spawn_socket_listener(&self, addr: SocketAddr) -> EyreResult<bool> { async fn spawn_socket_listener(&self, addr: SocketAddr) -> EyreResult<bool> {
// Get config // Get config
let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) = { let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) =
let c = self.config.get(); self.config().with(|c| {
( (
c.network.connection_initial_timeout_ms, c.network.connection_initial_timeout_ms,
c.network.tls.connection_initial_timeout_ms, c.network.tls.connection_initial_timeout_ms,
) )
}; });
// Create a socket and bind it
let Some(socket) = new_bound_default_tcp_socket(addr)
.wrap_err("failed to create default socket listener")?
else {
return Ok(false);
};
// Drop the socket
drop(socket);
// Create a shared socket and bind it once we have determined the port is free // Create a shared socket and bind it once we have determined the port is free
let Some(socket) = new_bound_shared_tcp_socket(addr) let Some(listener) = bind_async_tcp_listener(addr)? else {
.wrap_err("failed to create shared socket listener")?
else {
return Ok(false); return Ok(false);
}; };
// Listen on the socket
if socket.listen(128).is_err() {
return Ok(false);
}
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let listener = TcpListener::from(std_listener);
} else if #[cfg(feature="rt-tokio")] {
std_listener.set_nonblocking(true).expect("failed to set nonblocking");
let listener = TcpListener::from_std(std_listener).wrap_err("failed to create tokio tcp listener")?;
} else {
compile_error!("needs executor implementation");
}
}
log_net!(debug "spawn_socket_listener: binding successful to {}", addr); log_net!(debug "spawn_socket_listener: binding successful to {}", addr);
// Create protocol handler records // Create protocol handler records
@ -304,22 +249,14 @@ impl Network {
// Spawn the socket task // Spawn the socket task
let this = self.clone(); let this = self.clone();
let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token(); let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token();
let connection_manager = self.connection_manager(); let connection_manager = self.network_manager().connection_manager();
//////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////
let jh = spawn(&format!("TCP listener {}", addr), async move { let jh = spawn(&format!("TCP listener {}", addr), async move {
// moves listener object in and get incoming iterator // moves listener object in and get incoming iterator
// when this task exists, the listener will close the socket // when this task exists, the listener will close the socket
cfg_if! { let incoming_stream = async_tcp_listener_incoming(listener);
if #[cfg(feature="rt-async-std")] {
let incoming_stream = listener.incoming();
} else if #[cfg(feature="rt-tokio")] {
let incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener);
} else {
compile_error!("needs executor implementation");
}
}
let _ = incoming_stream let _ = incoming_stream
.for_each_concurrent(None, |tcp_stream| { .for_each_concurrent(None, |tcp_stream| {

View File

@ -1,15 +1,13 @@
use super::*; use super::*;
use sockets::*;
use stop_token::future::FutureExt; use stop_token::future::FutureExt;
impl Network { impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn create_udp_listener_tasks(&self) -> EyreResult<()> { pub(super) async fn create_udp_listener_tasks(&self) -> EyreResult<()> {
// Spawn socket tasks // Spawn socket tasks
let mut task_count = { let mut task_count = self
let c = self.config.get(); .config()
c.network.protocol.udp.socket_pool_size .with(|c| c.network.protocol.udp.socket_pool_size);
};
if task_count == 0 { if task_count == 0 {
task_count = get_concurrency() / 2; task_count = get_concurrency() / 2;
if task_count == 0 { if task_count == 0 {
@ -38,7 +36,6 @@ impl Network {
// Spawn a local async task for each socket // Spawn a local async task for each socket
let mut protocol_handlers_unordered = FuturesUnordered::new(); let mut protocol_handlers_unordered = FuturesUnordered::new();
let network_manager = this.network_manager();
let stop_token = { let stop_token = {
let inner = this.inner.lock(); let inner = this.inner.lock();
if inner.stop_source.is_none() { if inner.stop_source.is_none() {
@ -49,7 +46,7 @@ impl Network {
}; };
for ph in protocol_handlers { for ph in protocol_handlers {
let network_manager = network_manager.clone(); let network_manager = this.network_manager();
let stop_token = stop_token.clone(); let stop_token = stop_token.clone();
let ph_future = async move { let ph_future = async move {
let mut data = vec![0u8; 65536]; let mut data = vec![0u8; 65536];
@ -114,28 +111,14 @@ impl Network {
async fn create_udp_protocol_handler(&self, addr: SocketAddr) -> EyreResult<bool> { async fn create_udp_protocol_handler(&self, addr: SocketAddr) -> EyreResult<bool> {
log_net!(debug "create_udp_protocol_handler on {:?}", &addr); log_net!(debug "create_udp_protocol_handler on {:?}", &addr);
// Create a reusable socket // Create a single-address-family UDP socket with default options bound to an address
let Some(socket) = new_bound_default_udp_socket(addr)? else { let Some(udp_socket) = bind_async_udp_socket(addr)? else {
return Ok(false); return Ok(false);
}; };
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let udp_socket = UdpSocket::from(std_udp_socket);
} else if #[cfg(feature="rt-tokio")] {
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make inbound tokio udpsocket")?;
} else {
compile_error!("needs executor implementation");
}
}
let socket_arc = Arc::new(udp_socket); let socket_arc = Arc::new(udp_socket);
// Create protocol handler // Create protocol handler
let protocol_handler = let protocol_handler = RawUdpProtocolHandler::new(self.registry(), socket_arc);
RawUdpProtocolHandler::new(socket_arc, Some(self.network_manager().address_filter()));
// Record protocol handler // Record protocol handler
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();

View File

@ -1,4 +1,3 @@
pub mod sockets;
pub mod tcp; pub mod tcp;
pub mod udp; pub mod udp;
pub mod wrtc; pub mod wrtc;
@ -22,7 +21,7 @@ impl ProtocolNetworkConnection {
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: &DialInfo, dial_info: &DialInfo,
timeout_ms: u32, timeout_ms: u32,
address_filter: AddressFilter, address_filter: &AddressFilter,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> { ) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
if address_filter.is_ip_addr_punished(dial_info.address().ip_addr()) { if address_filter.is_ip_addr_punished(dial_info.address().ip_addr()) {
return Ok(NetworkResult::no_connection_other("punished")); return Ok(NetworkResult::no_connection_other("punished"));

View File

@ -1,191 +0,0 @@
use crate::*;
use async_io::Async;
use std::io;
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
} else if #[cfg(feature="rt-tokio")] {
pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
pub use tokio_util::compat::*;
} else {
compile_error!("needs executor implementation");
}
}
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
// cfg_if! {
// if #[cfg(windows)] {
// use winapi::shared::ws2def::{ SOL_SOCKET, SO_EXCLUSIVEADDRUSE};
// use winapi::um::winsock2::{SOCKET_ERROR, setsockopt};
// use winapi::ctypes::c_int;
// use std::os::windows::io::AsRawSocket;
// fn set_exclusiveaddruse(socket: &Socket) -> io::Result<()> {
// unsafe {
// let optval:c_int = 1;
// if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
// std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
// return Err(io::Error::last_os_error());
// }
// Ok(())
// }
// }
// }
// }
#[instrument(level = "trace", ret)]
pub fn new_shared_udp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_default_udp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_bound_default_udp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_default_udp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound default udp socket on {:?}", &local_address);
Ok(Some(socket))
}
#[instrument(level = "trace", ret)]
pub fn new_default_tcp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_shared_tcp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_bound_default_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_default_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound default tcp socket on {:?}", &local_address);
Ok(Some(socket))
}
#[instrument(level = "trace", ret)]
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_shared_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound shared tcp socket on {:?}", &local_address);
Ok(Some(socket))
}
// Non-blocking connect is tricky when you want to start with a prepared socket
// Errors should not be logged as they are valid conditions for this function
#[instrument(level = "trace", ret)]
pub async fn nonblocking_connect(
socket: Socket,
addr: SocketAddr,
timeout_ms: u32,
) -> io::Result<TimeoutOr<TcpStream>> {
// Set for non blocking connect
socket.set_nonblocking(true)?;
// Make socket2 SockAddr
let socket2_addr = socket2::SockAddr::from(addr);
// Connect to the remote address
match socket.connect(&socket2_addr) {
Ok(()) => Ok(()),
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
Err(e) => Err(e),
}?;
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
// The stream becomes writable when connected
timeout_or_try!(
timeout(timeout_ms, async_stream.writable().in_current_span())
.await
.into_timeout_or()
.into_result()?
);
// Check low level error
let async_stream = match async_stream.get_ref().take_error()? {
None => Ok(async_stream),
Some(err) => Err(err),
}?;
// Convert back to inner and then return async version
cfg_if! {
if #[cfg(feature="rt-async-std")] {
Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
} else if #[cfg(feature="rt-tokio")] {
Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
} else {
compile_error!("needs executor implementation");
}
}
}

View File

@ -1,6 +1,5 @@
use super::*; use super::*;
use futures_util::{AsyncReadExt, AsyncWriteExt}; use futures_util::{AsyncReadExt, AsyncWriteExt};
use sockets::*;
pub struct RawTcpNetworkConnection { pub struct RawTcpNetworkConnection {
flow: Flow, flow: Flow,
@ -157,32 +156,28 @@ impl RawTcpProtocolHandler {
#[instrument(level = "trace", target = "protocol", err)] #[instrument(level = "trace", target = "protocol", err)]
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
socket_addr: SocketAddr, remote_address: SocketAddr,
timeout_ms: u32, timeout_ms: u32,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> { ) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
// Make a shared socket
let socket = match local_address {
Some(a) => {
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_tcp_socket(socket2::Domain::for_address(socket_addr))?,
};
// Non-blocking connect to remote address // Non-blocking connect to remote address
let ts = network_result_try!(nonblocking_connect(socket, socket_addr, timeout_ms) let tcp_stream = network_result_try!(connect_async_tcp_stream(
.await local_address,
.folded()?); remote_address,
timeout_ms
)
.await
.folded()?);
// See what local address we ended up with and turn this into a stream // See what local address we ended up with and turn this into a stream
let actual_local_address = ts.local_addr()?; let actual_local_address = tcp_stream.local_addr()?;
#[cfg(feature = "rt-tokio")] #[cfg(feature = "rt-tokio")]
let ts = ts.compat(); let tcp_stream = tcp_stream.compat();
let ps = AsyncPeekStream::new(ts); let ps = AsyncPeekStream::new(tcp_stream);
// Wrap the stream in a network connection and return it // Wrap the stream in a network connection and return it
let flow = Flow::new( let flow = Flow::new(
PeerAddress::new( PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr), SocketAddress::from_socket_addr(remote_address),
ProtocolType::TCP, ProtocolType::TCP,
), ),
SocketAddress::from_socket_addr(actual_local_address), SocketAddress::from_socket_addr(actual_local_address),

View File

@ -1,19 +1,20 @@
use super::*; use super::*;
use sockets::*;
#[derive(Clone)] #[derive(Clone)]
pub struct RawUdpProtocolHandler { pub struct RawUdpProtocolHandler {
registry: VeilidComponentRegistry,
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
assembly_buffer: AssemblyBuffer, assembly_buffer: AssemblyBuffer,
address_filter: Option<AddressFilter>,
} }
impl_veilid_component_registry_accessor!(RawUdpProtocolHandler);
impl RawUdpProtocolHandler { impl RawUdpProtocolHandler {
pub fn new(socket: Arc<UdpSocket>, address_filter: Option<AddressFilter>) -> Self { pub fn new(registry: VeilidComponentRegistry, socket: Arc<UdpSocket>) -> Self {
Self { Self {
registry,
socket, socket,
assembly_buffer: AssemblyBuffer::new(), assembly_buffer: AssemblyBuffer::new(),
address_filter,
} }
} }
@ -24,10 +25,12 @@ impl RawUdpProtocolHandler {
let (size, remote_addr) = network_result_value_or_log!(self.socket.recv_from(data).await.into_network_result()? => continue); let (size, remote_addr) = network_result_value_or_log!(self.socket.recv_from(data).await.into_network_result()? => continue);
// Check to see if it is punished // Check to see if it is punished
if let Some(af) = self.address_filter.as_ref() { if self
if af.is_ip_addr_punished(remote_addr.ip()) { .network_manager()
continue; .address_filter()
} .is_ip_addr_punished(remote_addr.ip())
{
continue;
} }
// Insert into assembly buffer // Insert into assembly buffer
@ -91,10 +94,12 @@ impl RawUdpProtocolHandler {
} }
// Check to see if it is punished // Check to see if it is punished
if let Some(af) = self.address_filter.as_ref() { if self
if af.is_ip_addr_punished(remote_addr.ip()) { .network_manager()
return Ok(NetworkResult::no_connection_other("punished")); .address_filter()
} .is_ip_addr_punished(remote_addr.ip())
{
return Ok(NetworkResult::no_connection_other("punished"));
} }
// Fragment and send // Fragment and send
@ -137,11 +142,13 @@ impl RawUdpProtocolHandler {
#[instrument(level = "trace", target = "protocol", err)] #[instrument(level = "trace", target = "protocol", err)]
pub async fn new_unspecified_bound_handler( pub async fn new_unspecified_bound_handler(
registry: VeilidComponentRegistry,
socket_addr: &SocketAddr, socket_addr: &SocketAddr,
) -> io::Result<RawUdpProtocolHandler> { ) -> io::Result<RawUdpProtocolHandler> {
// get local wildcard address for bind // get local wildcard address for bind
let local_socket_addr = compatible_unspecified_socket_addr(socket_addr); let local_socket_addr = compatible_unspecified_socket_addr(socket_addr);
let socket = UdpSocket::bind(local_socket_addr).await?; let socket = bind_async_udp_socket(local_socket_addr)?
Ok(RawUdpProtocolHandler::new(Arc::new(socket), None)) .ok_or(io::Error::from(io::ErrorKind::AddrInUse))?;
Ok(RawUdpProtocolHandler::new(registry, Arc::new(socket)))
} }
} }

View File

@ -10,7 +10,6 @@ use async_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFr
use async_tungstenite::tungstenite::Error; use async_tungstenite::tungstenite::Error;
use async_tungstenite::{accept_hdr_async, client_async, WebSocketStream}; use async_tungstenite::{accept_hdr_async, client_async, WebSocketStream};
use futures_util::{AsyncRead, AsyncWrite, SinkExt}; use futures_util::{AsyncRead, AsyncWrite, SinkExt};
use sockets::*;
// Maximum number of websocket request headers to permit // Maximum number of websocket request headers to permit
const MAX_WS_HEADERS: usize = 24; const MAX_WS_HEADERS: usize = 24;
@ -316,21 +315,16 @@ impl WebsocketProtocolHandler {
let domain = split_url.host.clone(); let domain = split_url.host.clone();
// Resolve remote address // Resolve remote address
let remote_socket_addr = dial_info.to_socket_addr(); let remote_address = dial_info.to_socket_addr();
// Make a shared socket
let socket = match local_address {
Some(a) => {
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?,
};
// Non-blocking connect to remote address // Non-blocking connect to remote address
let tcp_stream = let tcp_stream = network_result_try!(connect_async_tcp_stream(
network_result_try!(nonblocking_connect(socket, remote_socket_addr, timeout_ms) local_address,
.await remote_address,
.folded()?); timeout_ms
)
.await
.folded()?);
// See what local address we ended up with // See what local address we ended up with
let actual_local_addr = tcp_stream.local_addr()?; let actual_local_addr = tcp_stream.local_addr()?;

View File

@ -140,14 +140,13 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn bind_udp_protocol_handlers(&self) -> EyreResult<StartupDisposition> { pub(super) async fn bind_udp_protocol_handlers(&self) -> EyreResult<StartupDisposition> {
log_net!("UDP: binding protocol handlers"); log_net!("UDP: binding protocol handlers");
let (listen_address, public_address, detect_address_changes) = { let (listen_address, public_address, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.udp.listen_address.clone(), c.network.protocol.udp.listen_address.clone(),
c.network.protocol.udp.public_address.clone(), c.network.protocol.udp.public_address.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// Get the binding parameters from the user-specified listen address // Get the binding parameters from the user-specified listen address
let bind_set = self let bind_set = self
@ -187,18 +186,17 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn register_udp_dial_info( pub(super) async fn register_udp_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
log_net!("UDP: registering dial info"); log_net!("UDP: registering dial info");
let (public_address, detect_address_changes) = { let (public_address, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.udp.public_address.clone(), c.network.protocol.udp.public_address.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
let local_dial_info_list = { let local_dial_info_list = {
let mut out = vec![]; let mut out = vec![];
@ -263,14 +261,13 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn start_ws_listeners(&self) -> EyreResult<StartupDisposition> { pub(super) async fn start_ws_listeners(&self) -> EyreResult<StartupDisposition> {
log_net!("WS: binding protocol handlers"); log_net!("WS: binding protocol handlers");
let (listen_address, url, detect_address_changes) = { let (listen_address, url, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.ws.listen_address.clone(), c.network.protocol.ws.listen_address.clone(),
c.network.protocol.ws.url.clone(), c.network.protocol.ws.url.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// Get the binding parameters from the user-specified listen address // Get the binding parameters from the user-specified listen address
let bind_set = self let bind_set = self
@ -313,18 +310,17 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn register_ws_dial_info( pub(super) async fn register_ws_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
log_net!("WS: registering dial info"); log_net!("WS: registering dial info");
let (url, path, detect_address_changes) = { let (url, path, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.ws.url.clone(), c.network.protocol.ws.url.clone(),
c.network.protocol.ws.path.clone(), c.network.protocol.ws.path.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
let mut registered_addresses: HashSet<IpAddr> = HashSet::new(); let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
@ -409,14 +405,13 @@ impl Network {
pub(super) async fn start_wss_listeners(&self) -> EyreResult<StartupDisposition> { pub(super) async fn start_wss_listeners(&self) -> EyreResult<StartupDisposition> {
log_net!("WSS: binding protocol handlers"); log_net!("WSS: binding protocol handlers");
let (listen_address, url, detect_address_changes) = { let (listen_address, url, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.wss.listen_address.clone(), c.network.protocol.wss.listen_address.clone(),
c.network.protocol.wss.url.clone(), c.network.protocol.wss.url.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// Get the binding parameters from the user-specified listen address // Get the binding parameters from the user-specified listen address
let bind_set = self let bind_set = self
@ -460,18 +455,17 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn register_wss_dial_info( pub(super) async fn register_wss_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
log_net!("WSS: registering dialinfo"); log_net!("WSS: registering dialinfo");
let (url, _detect_address_changes) = { let (url, _detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.wss.url.clone(), c.network.protocol.wss.url.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// NOTE: No interface dial info for WSS, as there is no way to connect to a local dialinfo via TLS // NOTE: No interface dial info for WSS, as there is no way to connect to a local dialinfo via TLS
// If the hostname is specified, it is the public dialinfo via the URL. If no hostname // If the hostname is specified, it is the public dialinfo via the URL. If no hostname
@ -520,14 +514,13 @@ impl Network {
pub(super) async fn start_tcp_listeners(&self) -> EyreResult<StartupDisposition> { pub(super) async fn start_tcp_listeners(&self) -> EyreResult<StartupDisposition> {
log_net!("TCP: binding protocol handlers"); log_net!("TCP: binding protocol handlers");
let (listen_address, public_address, detect_address_changes) = { let (listen_address, public_address, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.tcp.listen_address.clone(), c.network.protocol.tcp.listen_address.clone(),
c.network.protocol.tcp.public_address.clone(), c.network.protocol.tcp.public_address.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// Get the binding parameters from the user-specified listen address // Get the binding parameters from the user-specified listen address
let bind_set = self let bind_set = self
@ -570,18 +563,17 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn register_tcp_dial_info( pub(super) async fn register_tcp_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
log_net!("TCP: registering dialinfo"); log_net!("TCP: registering dialinfo");
let (public_address, detect_address_changes) = { let (public_address, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.tcp.public_address.clone(), c.network.protocol.tcp.public_address.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
let mut registered_addresses: HashSet<IpAddr> = HashSet::new(); let mut registered_addresses: HashSet<IpAddr> = HashSet::new();

View File

@ -7,46 +7,41 @@ use super::*;
impl Network { impl Network {
pub fn setup_tasks(&self) { pub fn setup_tasks(&self) {
// Set update network class tick task // Set update network class tick task
{ let this = self.clone();
let this = self.clone(); self.update_network_class_task.set_routine(move |s, l, t| {
self.unlocked_inner let this = this.clone();
.update_network_class_task Box::pin(async move {
.set_routine(move |s, l, t| { this.update_network_class_task_routine(s, Timestamp::new(l), Timestamp::new(t))
Box::pin(this.clone().update_network_class_task_routine( .await
s, })
Timestamp::new(l), });
Timestamp::new(t),
))
});
}
// Set network interfaces tick task // Set network interfaces tick task
{ let this = self.clone();
let this = self.clone(); self.network_interfaces_task.set_routine(move |s, l, t| {
self.unlocked_inner let this = this.clone();
.network_interfaces_task Box::pin(async move {
.set_routine(move |s, l, t| { this.network_interfaces_task_routine(s, Timestamp::new(l), Timestamp::new(t))
Box::pin(this.clone().network_interfaces_task_routine( .await
s, })
Timestamp::new(l), });
Timestamp::new(t),
))
});
}
// Set upnp tick task // Set upnp tick task
{ {
let this = self.clone(); let this = self.clone();
self.unlocked_inner.upnp_task.set_routine(move |s, l, t| { self.upnp_task.set_routine(move |s, l, t| {
Box::pin( let this = this.clone();
this.clone() Box::pin(async move {
.upnp_task_routine(s, Timestamp::new(l), Timestamp::new(t)), this.upnp_task_routine(s, Timestamp::new(l), Timestamp::new(t))
) .await
})
}); });
} }
} }
#[instrument(level = "trace", target = "net", name = "Network::tick", skip_all, err)] #[instrument(level = "trace", target = "net", name = "Network::tick", skip_all, err)]
pub async fn tick(&self) -> EyreResult<()> { pub async fn tick(&self) -> EyreResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up"); log_net!(debug "ignoring due to not started up");
return Ok(()); return Ok(());
}; };
@ -65,7 +60,7 @@ impl Network {
// If we need to figure out our network class, tick the task for it // If we need to figure out our network class, tick the task for it
if detect_address_changes { if detect_address_changes {
// Check our network interfaces to see if they have changed // Check our network interfaces to see if they have changed
self.unlocked_inner.network_interfaces_task.tick().await?; self.network_interfaces_task.tick().await?;
// Check our public dial info to see if it has changed // Check our public dial info to see if it has changed
let public_internet_network_class = self let public_internet_network_class = self
@ -95,16 +90,31 @@ impl Network {
} }
if has_at_least_two { if has_at_least_two {
self.unlocked_inner.update_network_class_task.tick().await?; self.update_network_class_task.tick().await?;
} }
} }
} }
// If we need to tick upnp, do it // If we need to tick upnp, do it
if upnp { if upnp {
self.unlocked_inner.upnp_task.tick().await?; self.upnp_task.tick().await?;
} }
Ok(()) Ok(())
} }
pub async fn cancel_tasks(&self) {
log_net!(debug "stopping upnp task");
if let Err(e) = self.upnp_task.stop().await {
warn!("upnp_task not stopped: {}", e);
}
log_net!(debug "stopping network interfaces task");
if let Err(e) = self.network_interfaces_task.stop().await {
warn!("network_interfaces_task not stopped: {}", e);
}
log_net!(debug "stopping update network class task");
if let Err(e) = self.update_network_class_task.stop().await {
warn!("update_network_class_task not stopped: {}", e);
}
}
} }

View File

@ -3,20 +3,28 @@ use super::*;
impl Network { impl Network {
#[instrument(level = "trace", target = "net", skip_all, err)] #[instrument(level = "trace", target = "net", skip_all, err)]
pub(super) async fn network_interfaces_task_routine( pub(super) async fn network_interfaces_task_routine(
self, &self,
_stop_token: StopToken, stop_token: StopToken,
_l: Timestamp, _l: Timestamp,
_t: Timestamp, _t: Timestamp,
) -> EyreResult<()> { ) -> EyreResult<()> {
let _guard = self.unlocked_inner.network_task_lock.lock().await; // Network lock ensures only one task operating on the low level network state
// can happen at the same time.
let _guard = match self.network_task_lock.try_lock() {
Ok(v) => v,
Err(_) => {
// If we can't get the lock right now, then
return Ok(());
}
};
self.update_network_state().await?; self.update_network_state(stop_token).await?;
Ok(()) Ok(())
} }
// See if our interface addresses have changed, if so redo public dial info if necessary // See if our interface addresses have changed, if so redo public dial info if necessary
async fn update_network_state(&self) -> EyreResult<bool> { async fn update_network_state(&self, _stop_token: StopToken) -> EyreResult<bool> {
let mut local_network_changed = false; let mut local_network_changed = false;
let mut public_internet_changed = false; let mut public_internet_changed = false;
@ -29,7 +37,7 @@ impl Network {
} }
}; };
if new_network_state != last_network_state { if last_network_state.is_none() || new_network_state != last_network_state.unwrap() {
// Save new network state // Save new network state
{ {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
@ -37,17 +45,13 @@ impl Network {
} }
// network state has changed // network state has changed
let mut editor_local_network = self let routing_table = self.routing_table();
.unlocked_inner
.routing_table let mut editor_local_network = routing_table.edit_local_network_routing_domain();
.edit_local_network_routing_domain();
editor_local_network.set_local_networks(new_network_state.local_networks); editor_local_network.set_local_networks(new_network_state.local_networks);
editor_local_network.clear_dial_info_details(None, None); editor_local_network.clear_dial_info_details(None, None);
let mut editor_public_internet = self let mut editor_public_internet = routing_table.edit_public_internet_routing_domain();
.unlocked_inner
.routing_table
.edit_public_internet_routing_domain();
// Update protocols // Update protocols
self.register_all_dial_info(&mut editor_public_internet, &mut editor_local_network) self.register_all_dial_info(&mut editor_public_internet, &mut editor_local_network)

View File

@ -8,12 +8,20 @@ type InboundProtocolMap = HashMap<(AddressType, LowLevelProtocolType, u16), Vec<
impl Network { impl Network {
#[instrument(parent = None, level = "trace", skip(self), err)] #[instrument(parent = None, level = "trace", skip(self), err)]
pub async fn update_network_class_task_routine( pub async fn update_network_class_task_routine(
self, &self,
stop_token: StopToken, stop_token: StopToken,
l: Timestamp, l: Timestamp,
t: Timestamp, t: Timestamp,
) -> EyreResult<()> { ) -> EyreResult<()> {
let _guard = self.unlocked_inner.network_task_lock.lock().await; // Network lock ensures only one task operating on the low level network state
// can happen at the same time.
let _guard = match self.network_task_lock.try_lock() {
Ok(v) => v,
Err(_) => {
// If we can't get the lock right now, then
return Ok(());
}
};
// Do the public dial info check // Do the public dial info check
let finished = self.do_public_dial_info_check(stop_token, l, t).await?; let finished = self.do_public_dial_info_check(stop_token, l, t).await?;
@ -125,8 +133,9 @@ impl Network {
}; };
// Save off existing public dial info for change detection later // Save off existing public dial info for change detection later
let existing_public_dial_info: HashSet<DialInfoDetail> = self let routing_table = self.routing_table();
.routing_table()
let existing_public_dial_info: HashSet<DialInfoDetail> = routing_table
.all_filtered_dial_info_details( .all_filtered_dial_info_details(
RoutingDomain::PublicInternet.into(), RoutingDomain::PublicInternet.into(),
&DialInfoFilter::all(), &DialInfoFilter::all(),
@ -135,7 +144,7 @@ impl Network {
.collect(); .collect();
// Set most permissive network config and start from scratch // Set most permissive network config and start from scratch
let mut editor = self.routing_table().edit_public_internet_routing_domain(); let mut editor = routing_table.edit_public_internet_routing_domain();
editor.setup_network( editor.setup_network(
protocol_config.outbound, protocol_config.outbound,
protocol_config.inbound, protocol_config.inbound,
@ -156,7 +165,7 @@ impl Network {
port, port,
}; };
context_configs.insert(dcc); context_configs.insert(dcc);
let discovery_context = DiscoveryContext::new(self.routing_table(), self.clone(), dcc); let discovery_context = DiscoveryContext::new(self.registry(), dcc);
discovery_context.discover(&mut unord).await; discovery_context.discover(&mut unord).await;
} }
@ -247,22 +256,18 @@ impl Network {
match protocol_type { match protocol_type {
ProtocolType::UDP => DialInfo::udp(addr), ProtocolType::UDP => DialInfo::udp(addr),
ProtocolType::TCP => DialInfo::tcp(addr), ProtocolType::TCP => DialInfo::tcp(addr),
ProtocolType::WS => { ProtocolType::WS => DialInfo::try_ws(
let c = self.config.get(); addr,
DialInfo::try_ws( self.config()
addr, .with(|c| format!("ws://{}/{}", addr, c.network.protocol.ws.path)),
format!("ws://{}/{}", addr, c.network.protocol.ws.path), )
) .unwrap(),
.unwrap() ProtocolType::WSS => DialInfo::try_wss(
} addr,
ProtocolType::WSS => { self.config()
let c = self.config.get(); .with(|c| format!("wss://{}/{}", addr, c.network.protocol.wss.path)),
DialInfo::try_wss( )
addr, .unwrap(),
format!("wss://{}/{}", addr, c.network.protocol.wss.path),
)
.unwrap()
}
} }
} }
} }

View File

@ -3,12 +3,12 @@ use super::*;
impl Network { impl Network {
#[instrument(parent = None, level = "trace", target = "net", skip_all, err)] #[instrument(parent = None, level = "trace", target = "net", skip_all, err)]
pub(super) async fn upnp_task_routine( pub(super) async fn upnp_task_routine(
self, &self,
_stop_token: StopToken, _stop_token: StopToken,
_l: Timestamp, _l: Timestamp,
_t: Timestamp, _t: Timestamp,
) -> EyreResult<()> { ) -> EyreResult<()> {
if !self.unlocked_inner.igd_manager.tick().await? { if !self.igd_manager.tick().await? {
info!("upnp failed, restarting local network"); info!("upnp failed, restarting local network");
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.network_needs_restart = true; inner.network_needs_restart = true;

View File

@ -4,7 +4,7 @@ use std::{io, sync::Arc};
use stop_token::prelude::*; use stop_token::prelude::*;
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] { if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
// No accept support for WASM // No accept support for WASM
} else { } else {
@ -307,8 +307,7 @@ impl NetworkConnection {
flow flow
); );
let network_manager = connection_manager.network_manager(); let registry = connection_manager.registry();
let address_filter = network_manager.address_filter();
let mut unord = FuturesUnordered::new(); let mut unord = FuturesUnordered::new();
let mut need_receiver = true; let mut need_receiver = true;
let mut need_sender = true; let mut need_sender = true;
@ -364,14 +363,17 @@ impl NetworkConnection {
// Add another message receiver future if necessary // Add another message receiver future if necessary
if need_receiver { if need_receiver {
need_receiver = false; need_receiver = false;
let registry = registry.clone();
let receiver_fut = Self::recv_internal(&protocol_connection, stats.clone()) let receiver_fut = Self::recv_internal(&protocol_connection, stats.clone())
.then(|res| async { .then(|res| async {
let registry = registry;
let network_manager = registry.network_manager();
match res { match res {
Ok(v) => { Ok(v) => {
let peer_address = protocol_connection.flow().remote(); let peer_address = protocol_connection.flow().remote();
// Check to see if it is punished // Check to see if it is punished
if address_filter.is_ip_addr_punished(peer_address.socket_addr().ip()) { if network_manager.address_filter().is_ip_addr_punished(peer_address.socket_addr().ip()) {
return RecvLoopAction::Finish; return RecvLoopAction::Finish;
} }
@ -383,7 +385,7 @@ impl NetworkConnection {
// Punish invalid framing (tcp framing or websocket framing) // Punish invalid framing (tcp framing or websocket framing)
if v.is_invalid_message() { if v.is_invalid_message() {
address_filter.punish_ip_addr(peer_address.socket_addr().ip(), PunishmentReason::InvalidFraming); network_manager.address_filter().punish_ip_addr(peer_address.socket_addr().ip(), PunishmentReason::InvalidFraming);
return RecvLoopAction::Finish; return RecvLoopAction::Finish;
} }

View File

@ -309,13 +309,7 @@ impl ReceiptManager {
Ok(()) Ok(())
} }
pub async fn shutdown(&self) { pub async fn cancel_tasks(&self) {
log_net!(debug "starting receipt manager shutdown");
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else {
log_net!(debug "receipt manager is already shut down");
return;
};
// Stop all tasks // Stop all tasks
let timeout_task = { let timeout_task = {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
@ -329,6 +323,14 @@ impl ReceiptManager {
if timeout_task.join().await.is_err() { if timeout_task.join().await.is_err() {
panic!("joining timeout task failed"); panic!("joining timeout task failed");
} }
}
pub async fn shutdown(&self) {
log_net!(debug "starting receipt manager shutdown");
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else {
log_net!(debug "receipt manager is already shut down");
return;
};
*self.inner.lock() = Self::new_inner(); *self.inner.lock() = Self::new_inner();

View File

@ -40,9 +40,10 @@ impl NetworkManager {
destination_node_ref: FilteredNodeRef, destination_node_ref: FilteredNodeRef,
data: Vec<u8>, data: Vec<u8>,
) -> SendPinBoxFuture<EyreResult<NetworkResult<SendDataMethod>>> { ) -> SendPinBoxFuture<EyreResult<NetworkResult<SendDataMethod>>> {
let this = self.clone(); let registry = self.registry();
Box::pin( Box::pin(
async move { async move {
let this = registry.network_manager();
// If we need to relay, do it // If we need to relay, do it
let (contact_method, target_node_ref, opt_relayed_contact_method) = match possibly_relayed_contact_method.clone() { let (contact_method, target_node_ref, opt_relayed_contact_method) = match possibly_relayed_contact_method.clone() {
@ -652,17 +653,14 @@ impl NetworkManager {
data: Vec<u8>, data: Vec<u8>,
) -> EyreResult<NetworkResult<UniqueFlow>> { ) -> EyreResult<NetworkResult<UniqueFlow>> {
// Detect if network is stopping so we can break out of this // Detect if network is stopping so we can break out of this
let Some(stop_token) = self.unlocked_inner.startup_lock.stop_token() else { let Some(stop_token) = self.startup_context.startup_lock.stop_token() else {
return Ok(NetworkResult::service_unavailable("network is stopping")); return Ok(NetworkResult::service_unavailable("network is stopping"));
}; };
// Build a return receipt for the signal // Build a return receipt for the signal
let receipt_timeout = TimestampDuration::new_ms( let receipt_timeout = TimestampDuration::new_ms(
self.unlocked_inner self.config()
.config .with(|c| c.network.reverse_connection_receipt_time_ms as u64),
.get()
.network
.reverse_connection_receipt_time_ms as u64,
); );
let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?; let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?;
@ -763,7 +761,7 @@ impl NetworkManager {
data: Vec<u8>, data: Vec<u8>,
) -> EyreResult<NetworkResult<UniqueFlow>> { ) -> EyreResult<NetworkResult<UniqueFlow>> {
// Detect if network is stopping so we can break out of this // Detect if network is stopping so we can break out of this
let Some(stop_token) = self.unlocked_inner.startup_lock.stop_token() else { let Some(stop_token) = self.startup_context.startup_lock.stop_token() else {
return Ok(NetworkResult::service_unavailable("network is stopping")); return Ok(NetworkResult::service_unavailable("network is stopping"));
}; };
@ -776,11 +774,8 @@ impl NetworkManager {
// Build a return receipt for the signal // Build a return receipt for the signal
let receipt_timeout = TimestampDuration::new_ms( let receipt_timeout = TimestampDuration::new_ms(
self.unlocked_inner self.config()
.config .with(|c| c.network.hole_punch_receipt_time_ms as u64),
.get()
.network
.hole_punch_receipt_time_ms as u64,
); );
let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?; let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?;

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
// Statistics per address // Statistics per address
#[derive(Clone, Default)] #[derive(Clone, Debug, Default)]
pub struct PerAddressStats { pub struct PerAddressStats {
pub last_seen_ts: Timestamp, pub last_seen_ts: Timestamp,
pub transfer_stats_accounting: TransferStatsAccounting, pub transfer_stats_accounting: TransferStatsAccounting,
@ -18,7 +18,7 @@ impl Default for PerAddressStatsKey {
} }
// Statistics about the low-level network // Statistics about the low-level network
#[derive(Clone)] #[derive(Debug, Clone)]
pub struct NetworkManagerStats { pub struct NetworkManagerStats {
pub self_stats: PerAddressStats, pub self_stats: PerAddressStats,
pub per_address_stats: LruCache<PerAddressStatsKey, PerAddressStats>, pub per_address_stats: LruCache<PerAddressStatsKey, PerAddressStats>,
@ -116,12 +116,10 @@ impl NetworkManager {
}) })
} }
pub(super) fn send_network_update(&self) { pub fn send_network_update(&self) {
let update_cb = self.unlocked_inner.update_callback.read().clone(); let update_cb = self.update_callback();
if update_cb.is_none() {
return;
}
let state = self.get_veilid_state(); let state = self.get_veilid_state();
(update_cb.unwrap())(VeilidUpdate::Network(state)); update_cb(VeilidUpdate::Network(state));
} }
} }

View File

@ -5,48 +5,39 @@ use super::*;
impl NetworkManager { impl NetworkManager {
pub fn setup_tasks(&self) { pub fn setup_tasks(&self) {
// Set rolling transfers tick task // Set rolling transfers tick task
{ impl_setup_task!(
let this = self.clone(); self,
self.unlocked_inner Self,
.rolling_transfers_task rolling_transfers_task,
.set_routine(move |s, l, t| { rolling_transfers_task_routine
Box::pin(this.clone().rolling_transfers_task_routine( );
s,
Timestamp::new(l),
Timestamp::new(t),
))
});
}
// Set address filter task // Set address filter task
{ {
let this = self.clone(); let registry = self.registry();
self.unlocked_inner self.address_filter_task.set_routine(move |s, l, t| {
.address_filter_task let registry = registry.clone();
.set_routine(move |s, l, t| { Box::pin(async move {
Box::pin(this.address_filter().address_filter_task_routine( registry
s, .network_manager()
Timestamp::new(l), .address_filter()
Timestamp::new(t), .address_filter_task_routine(s, Timestamp::new(l), Timestamp::new(t))
)) .await
}); })
});
} }
} }
#[instrument(level = "trace", name = "NetworkManager::tick", skip_all, err)] #[instrument(level = "trace", name = "NetworkManager::tick", skip_all, err)]
pub async fn tick(&self) -> EyreResult<()> { pub async fn tick(&self) -> EyreResult<()> {
let routing_table = self.routing_table();
let net = self.net(); let net = self.net();
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
// Run the rolling transfers task // Run the rolling transfers task
self.unlocked_inner.rolling_transfers_task.tick().await?; self.rolling_transfers_task.tick().await?;
// Run the address filter task // Run the address filter task
self.unlocked_inner.address_filter_task.tick().await?; self.address_filter_task.tick().await?;
// Run the routing table tick
routing_table.tick().await?;
// Run the low level network tick // Run the low level network tick
net.tick().await?; net.tick().await?;
@ -61,15 +52,21 @@ impl NetworkManager {
} }
pub async fn cancel_tasks(&self) { pub async fn cancel_tasks(&self) {
log_net!(debug "stopping receipt manager tasks");
let receipt_manager = self.receipt_manager();
receipt_manager.cancel_tasks().await;
let net = self.net();
net.cancel_tasks().await;
log_net!(debug "stopping rolling transfers task"); log_net!(debug "stopping rolling transfers task");
if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await { if let Err(e) = self.rolling_transfers_task.stop().await {
warn!("rolling_transfers_task not stopped: {}", e); warn!("rolling_transfers_task not stopped: {}", e);
} }
log_net!(debug "stopping routing table tasks"); log_net!(debug "stopping address filter task");
let routing_table = self.routing_table(); if let Err(e) = self.address_filter_task.stop().await {
routing_table.cancel_tasks().await; warn!("address_filter_task not stopped: {}", e);
}
// other tasks will get cancelled via the 'shutdown' mechanism
} }
} }

View File

@ -4,7 +4,7 @@ impl NetworkManager {
// Compute transfer statistics for the low level network // Compute transfer statistics for the low level network
#[instrument(level = "trace", skip(self), err)] #[instrument(level = "trace", skip(self), err)]
pub async fn rolling_transfers_task_routine( pub async fn rolling_transfers_task_routine(
self, &self,
_stop_token: StopToken, _stop_token: StopToken,
last_ts: Timestamp, last_ts: Timestamp,
cur_ts: Timestamp, cur_ts: Timestamp,

View File

@ -1,13 +1,12 @@
use super::*; use super::*;
use super::connection_table::*; use super::connection_table::*;
use crate::tests::common::test_veilid_config::*; use crate::tests::mock_registry;
use crate::tests::mock_routing_table;
pub async fn test_add_get_remove() { pub async fn test_add_get_remove() {
let config = get_config(); let registry = mock_registry::init("").await;
let address_filter = AddressFilter::new(config.clone(), mock_routing_table());
let table = ConnectionTable::new(config, address_filter); let table = ConnectionTable::new(registry.clone());
let a1 = Flow::new_no_local(PeerAddress::new( let a1 = Flow::new_no_local(PeerAddress::new(
SocketAddress::new(Address::IPV4(Ipv4Addr::new(192, 168, 0, 1)), 8080), SocketAddress::new(Address::IPV4(Ipv4Addr::new(192, 168, 0, 1)), 8080),
@ -122,6 +121,8 @@ pub async fn test_add_get_remove() {
a4 a4
); );
assert_eq!(table.connection_count(), 0); assert_eq!(table.connection_count(), 0);
mock_registry::terminate(registry).await;
} }
pub async fn test_all() { pub async fn test_all() {

View File

@ -30,7 +30,7 @@ pub async fn test_signed_node_info() {
// Test correct validation // Test correct validation
let keypair = vcrypto.generate_keypair(); let keypair = vcrypto.generate_keypair();
let sni = SignedDirectNodeInfo::make_signatures( let sni = SignedDirectNodeInfo::make_signatures(
crypto.clone(), &crypto,
vec![TypedKeyPair::new(ck, keypair)], vec![TypedKeyPair::new(ck, keypair)],
node_info.clone(), node_info.clone(),
) )
@ -42,7 +42,7 @@ pub async fn test_signed_node_info() {
sni.timestamp(), sni.timestamp(),
sni.signatures().to_vec(), sni.signatures().to_vec(),
); );
let tks_validated = sdni.validate(&tks, crypto.clone()).unwrap(); let tks_validated = sdni.validate(&tks, &crypto).unwrap();
assert_eq!(tks_validated.len(), oldtkslen); assert_eq!(tks_validated.len(), oldtkslen);
assert_eq!(tks_validated.len(), sni.signatures().len()); assert_eq!(tks_validated.len(), sni.signatures().len());
@ -54,7 +54,7 @@ pub async fn test_signed_node_info() {
sni.timestamp(), sni.timestamp(),
sni.signatures().to_vec(), sni.signatures().to_vec(),
); );
sdni.validate(&tks1, crypto.clone()).unwrap_err(); sdni.validate(&tks1, &crypto).unwrap_err();
// Test unsupported cryptosystem validation // Test unsupported cryptosystem validation
let fake_crypto_kind: CryptoKind = FourCC::from([0, 1, 2, 3]); let fake_crypto_kind: CryptoKind = FourCC::from([0, 1, 2, 3]);
@ -65,7 +65,7 @@ pub async fn test_signed_node_info() {
tksfake.add(TypedKey::new(ck, keypair.key)); tksfake.add(TypedKey::new(ck, keypair.key));
let sdnifake = let sdnifake =
SignedDirectNodeInfo::new(node_info.clone(), sni.timestamp(), sigsfake.clone()); SignedDirectNodeInfo::new(node_info.clone(), sni.timestamp(), sigsfake.clone());
let tksfake_validated = sdnifake.validate(&tksfake, crypto.clone()).unwrap(); let tksfake_validated = sdnifake.validate(&tksfake, &crypto).unwrap();
assert_eq!(tksfake_validated.len(), 1); assert_eq!(tksfake_validated.len(), 1);
assert_eq!(sdnifake.signatures().len(), sigsfake.len()); assert_eq!(sdnifake.signatures().len(), sigsfake.len());
@ -89,7 +89,7 @@ pub async fn test_signed_node_info() {
let oldtks2len = tks2.len(); let oldtks2len = tks2.len();
let sni2 = SignedRelayedNodeInfo::make_signatures( let sni2 = SignedRelayedNodeInfo::make_signatures(
crypto.clone(), &crypto,
vec![TypedKeyPair::new(ck, keypair2)], vec![TypedKeyPair::new(ck, keypair2)],
node_info2.clone(), node_info2.clone(),
tks.clone(), tks.clone(),
@ -103,7 +103,7 @@ pub async fn test_signed_node_info() {
sni2.timestamp(), sni2.timestamp(),
sni2.signatures().to_vec(), sni2.signatures().to_vec(),
); );
let tks2_validated = srni.validate(&tks2, crypto.clone()).unwrap(); let tks2_validated = srni.validate(&tks2, &crypto).unwrap();
assert_eq!(tks2_validated.len(), oldtks2len); assert_eq!(tks2_validated.len(), oldtks2len);
assert_eq!(tks2_validated.len(), sni2.signatures().len()); assert_eq!(tks2_validated.len(), sni2.signatures().len());
@ -119,7 +119,7 @@ pub async fn test_signed_node_info() {
sni2.timestamp(), sni2.timestamp(),
sni2.signatures().to_vec(), sni2.signatures().to_vec(),
); );
srni.validate(&tks3, crypto.clone()).unwrap_err(); srni.validate(&tks3, &crypto).unwrap_err();
// Test unsupported cryptosystem validation // Test unsupported cryptosystem validation
let fake_crypto_kind: CryptoKind = FourCC::from([0, 1, 2, 3]); let fake_crypto_kind: CryptoKind = FourCC::from([0, 1, 2, 3]);
@ -135,7 +135,7 @@ pub async fn test_signed_node_info() {
sni2.timestamp(), sni2.timestamp(),
sigsfake3.clone(), sigsfake3.clone(),
); );
let tksfake3_validated = srnifake.validate(&tksfake3, crypto.clone()).unwrap(); let tksfake3_validated = srnifake.validate(&tksfake3, &crypto).unwrap();
assert_eq!(tksfake3_validated.len(), 1); assert_eq!(tksfake3_validated.len(), 1);
assert_eq!(srnifake.signatures().len(), sigsfake3.len()); assert_eq!(srnifake.signatures().len(), sigsfake3.len());
} }

View File

@ -20,7 +20,7 @@ impl Address {
SocketAddr::V6(v6) => Address::IPV6(*v6.ip()), SocketAddr::V6(v6) => Address::IPV6(*v6.ip()),
} }
} }
#[cfg_attr(target_arch = "wasm32", expect(dead_code))] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub fn from_ip_addr(addr: IpAddr) -> Address { pub fn from_ip_addr(addr: IpAddr) -> Address {
match addr { match addr {
IpAddr::V4(v4) => Address::IPV4(v4), IpAddr::V4(v4) => Address::IPV4(v4),

View File

@ -268,7 +268,7 @@ impl DialInfo {
Self::WSS(di) => di.socket_address.port(), Self::WSS(di) => di.socket_address.port(),
} }
} }
#[cfg_attr(target_arch = "wasm32", expect(dead_code))] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub fn set_port(&mut self, port: u16) { pub fn set_port(&mut self, port: u16) {
match self { match self {
Self::UDP(di) => di.socket_address.set_port(port), Self::UDP(di) => di.socket_address.set_port(port),
@ -366,7 +366,7 @@ impl DialInfo {
// This will not be used on signed dialinfo, only for bootstrapping, so we don't need to worry about // This will not be used on signed dialinfo, only for bootstrapping, so we don't need to worry about
// the '0.0.0.0' address being propagated across the routing table // the '0.0.0.0' address being propagated across the routing table
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] { if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0,0,0,0)), port)] vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0,0,0,0)), port)]
} else { } else {
match split_url.host { match split_url.host {

View File

@ -21,7 +21,7 @@ pub(crate) enum SignalInfo {
} }
impl SignalInfo { impl SignalInfo {
pub fn validate(&self, crypto: Crypto) -> Result<(), RPCError> { pub fn validate(&self, crypto: &Crypto) -> Result<(), RPCError> {
match self { match self {
SignalInfo::HolePunch { receipt, peer_info } => { SignalInfo::HolePunch { receipt, peer_info } => {
if receipt.len() < MIN_RECEIPT_SIZE { if receipt.len() < MIN_RECEIPT_SIZE {

View File

@ -1,2 +1,2 @@
[build] [build]
target = "wasm32-unknown-unknown" all(target_arch = "wasm32", target_os = "unknown")

View File

@ -3,8 +3,6 @@ mod protocol;
use super::*; use super::*;
use crate::routing_table::*; use crate::routing_table::*;
use connection_manager::*;
use protocol::ws::WebsocketProtocolHandler;
pub use protocol::*; pub use protocol::*;
use std::io; use std::io;
@ -64,23 +62,28 @@ struct NetworkInner {
protocol_config: ProtocolConfig, protocol_config: ProtocolConfig,
} }
struct NetworkUnlockedInner { pub(super) struct NetworkUnlockedInner {
// Startup lock // Startup lock
startup_lock: StartupLock, startup_lock: StartupLock,
// Accessors
routing_table: RoutingTable,
network_manager: NetworkManager,
connection_manager: ConnectionManager,
} }
#[derive(Clone)] #[derive(Clone)]
pub(super) struct Network { pub(super) struct Network {
config: VeilidConfig, registry: VeilidComponentRegistry,
inner: Arc<Mutex<NetworkInner>>, inner: Arc<Mutex<NetworkInner>>,
unlocked_inner: Arc<NetworkUnlockedInner>, unlocked_inner: Arc<NetworkUnlockedInner>,
} }
impl_veilid_component_registry_accessor!(Network);
impl core::ops::Deref for Network {
type Target = NetworkUnlockedInner;
fn deref(&self) -> &Self::Target {
&self.unlocked_inner
}
}
impl Network { impl Network {
fn new_inner() -> NetworkInner { fn new_inner() -> NetworkInner {
NetworkInner { NetworkInner {
@ -89,45 +92,20 @@ impl Network {
} }
} }
fn new_unlocked_inner( fn new_unlocked_inner() -> NetworkUnlockedInner {
network_manager: NetworkManager,
routing_table: RoutingTable,
connection_manager: ConnectionManager,
) -> NetworkUnlockedInner {
NetworkUnlockedInner { NetworkUnlockedInner {
startup_lock: StartupLock::new(), startup_lock: StartupLock::new(),
network_manager,
routing_table,
connection_manager,
} }
} }
pub fn new( pub fn new(registry: VeilidComponentRegistry) -> Self {
network_manager: NetworkManager,
routing_table: RoutingTable,
connection_manager: ConnectionManager,
) -> Self {
Self { Self {
config: network_manager.config(), registry,
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner( unlocked_inner: Arc::new(Self::new_unlocked_inner()),
network_manager,
routing_table,
connection_manager,
)),
} }
} }
fn network_manager(&self) -> NetworkManager {
self.unlocked_inner.network_manager.clone()
}
fn routing_table(&self) -> RoutingTable {
self.unlocked_inner.routing_table.clone()
}
fn connection_manager(&self) -> ConnectionManager {
self.unlocked_inner.connection_manager.clone()
}
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
// Record DialInfo failures // Record DialInfo failures
@ -159,10 +137,9 @@ impl Network {
self.record_dial_info_failure(dial_info.clone(), async move { self.record_dial_info_failure(dial_info.clone(), async move {
let data_len = data.len(); let data_len = data.len();
let timeout_ms = { let timeout_ms = self
let c = self.config.get(); .config()
c.network.connection_initial_timeout_ms .with(|c| c.network.connection_initial_timeout_ms);
};
if self if self
.network_manager() .network_manager()
@ -180,7 +157,7 @@ impl Network {
bail!("no support for TCP protocol") bail!("no support for TCP protocol")
} }
ProtocolType::WS | ProtocolType::WSS => { ProtocolType::WS | ProtocolType::WSS => {
let pnc = network_result_try!(WebsocketProtocolHandler::connect( let pnc = network_result_try!(ws::WebsocketProtocolHandler::connect(
&dial_info, timeout_ms &dial_info, timeout_ms
) )
.await .await
@ -210,14 +187,13 @@ impl Network {
data: Vec<u8>, data: Vec<u8>,
timeout_ms: u32, timeout_ms: u32,
) -> EyreResult<NetworkResult<Vec<u8>>> { ) -> EyreResult<NetworkResult<Vec<u8>>> {
let _guard = self.unlocked_inner.startup_lock.enter()?; let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure(dial_info.clone(), async move { self.record_dial_info_failure(dial_info.clone(), async move {
let data_len = data.len(); let data_len = data.len();
let connect_timeout_ms = { let connect_timeout_ms = self
let c = self.config.get(); .config()
c.network.connection_initial_timeout_ms .with(|c| c.network.connection_initial_timeout_ms);
};
if self if self
.network_manager() .network_manager()
@ -239,7 +215,7 @@ impl Network {
ProtocolType::UDP => unreachable!(), ProtocolType::UDP => unreachable!(),
ProtocolType::TCP => unreachable!(), ProtocolType::TCP => unreachable!(),
ProtocolType::WS | ProtocolType::WSS => { ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::connect(&dial_info, connect_timeout_ms) ws::WebsocketProtocolHandler::connect(&dial_info, connect_timeout_ms)
.await .await
.wrap_err("connect failure")? .wrap_err("connect failure")?
} }
@ -271,7 +247,7 @@ impl Network {
flow: Flow, flow: Flow,
data: Vec<u8>, data: Vec<u8>,
) -> EyreResult<SendDataToExistingFlowResult> { ) -> EyreResult<SendDataToExistingFlowResult> {
let _guard = self.unlocked_inner.startup_lock.enter()?; let _guard = self.startup_lock.enter()?;
let data_len = data.len(); let data_len = data.len();
match flow.protocol_type() { match flow.protocol_type() {
@ -287,7 +263,11 @@ impl Network {
// Handle connection-oriented protocols // Handle connection-oriented protocols
// Try to send to the exact existing connection if one exists // Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(flow) { if let Some(conn) = self
.network_manager()
.connection_manager()
.get_connection(flow)
{
// connection exists, send over it // connection exists, send over it
match conn.send_async(data).await { match conn.send_async(data).await {
ConnectionHandleSendResult::Sent => { ConnectionHandleSendResult::Sent => {
@ -320,7 +300,7 @@ impl Network {
dial_info: DialInfo, dial_info: DialInfo,
data: Vec<u8>, data: Vec<u8>,
) -> EyreResult<NetworkResult<UniqueFlow>> { ) -> EyreResult<NetworkResult<UniqueFlow>> {
let _guard = self.unlocked_inner.startup_lock.enter()?; let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure(dial_info.clone(), async move { self.record_dial_info_failure(dial_info.clone(), async move {
let data_len = data.len(); let data_len = data.len();
@ -333,7 +313,8 @@ impl Network {
// Handle connection-oriented protocols // Handle connection-oriented protocols
let conn = network_result_try!( let conn = network_result_try!(
self.connection_manager() self.network_manager()
.connection_manager()
.get_or_create_connection(dial_info.clone()) .get_or_create_connection(dial_info.clone())
.await? .await?
); );
@ -361,7 +342,8 @@ impl Network {
log_net!(debug "starting network"); log_net!(debug "starting network");
// get protocol config // get protocol config
let protocol_config = { let protocol_config = {
let c = self.config.get(); let config = self.config();
let c = config.get();
let inbound = ProtocolTypeSet::new(); let inbound = ProtocolTypeSet::new();
let mut outbound = ProtocolTypeSet::new(); let mut outbound = ProtocolTypeSet::new();
@ -398,10 +380,8 @@ impl Network {
self.inner.lock().protocol_config = protocol_config.clone(); self.inner.lock().protocol_config = protocol_config.clone();
// Start editing routing table // Start editing routing table
let mut editor_public_internet = self let routing_table = self.routing_table();
.unlocked_inner let mut editor_public_internet = routing_table.edit_public_internet_routing_domain();
.routing_table
.edit_public_internet_routing_domain();
// set up the routing table's network config // set up the routing table's network config
editor_public_internet.setup_network( editor_public_internet.setup_network(
@ -421,7 +401,7 @@ impl Network {
#[instrument(level = "debug", err, skip_all)] #[instrument(level = "debug", err, skip_all)]
pub async fn startup(&self) -> EyreResult<StartupDisposition> { pub async fn startup(&self) -> EyreResult<StartupDisposition> {
let guard = self.unlocked_inner.startup_lock.startup()?; let guard = self.startup_lock.startup()?;
match self.startup_internal().await { match self.startup_internal().await {
Ok(StartupDisposition::Success) => { Ok(StartupDisposition::Success) => {
@ -445,7 +425,7 @@ impl Network {
} }
pub fn is_started(&self) -> bool { pub fn is_started(&self) -> bool {
self.unlocked_inner.startup_lock.is_started() self.startup_lock.is_started()
} }
#[instrument(level = "debug", skip_all)] #[instrument(level = "debug", skip_all)]
@ -456,7 +436,7 @@ impl Network {
#[instrument(level = "debug", skip_all)] #[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
log_net!(debug "starting low level network shutdown"); log_net!(debug "starting low level network shutdown");
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else { let Ok(guard) = self.startup_lock.shutdown().await else {
log_net!(debug "low level network is already shut down"); log_net!(debug "low level network is already shut down");
return; return;
}; };
@ -493,14 +473,14 @@ impl Network {
&self, &self,
_punishment: Option<Box<dyn FnOnce() + Send + 'static>>, _punishment: Option<Box<dyn FnOnce() + Send + 'static>>,
) { ) {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up"); log_net!(debug "ignoring due to not started up");
return; return;
}; };
} }
pub fn needs_public_dial_info_check(&self) -> bool { pub fn needs_public_dial_info_check(&self) -> bool {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up"); log_net!(debug "ignoring due to not started up");
return false; return false;
}; };
@ -511,11 +491,12 @@ impl Network {
////////////////////////////////////////// //////////////////////////////////////////
#[instrument(level = "trace", target = "net", name = "Network::tick", skip_all, err)] #[instrument(level = "trace", target = "net", name = "Network::tick", skip_all, err)]
pub async fn tick(&self) -> EyreResult<()> { pub async fn tick(&self) -> EyreResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up"); log_net!(debug "ignoring due to not started up");
return Ok(()); return Ok(());
}; };
Ok(()) Ok(())
} }
pub async fn cancel_tasks(&self) {}
} }

View File

@ -16,7 +16,7 @@ impl ProtocolNetworkConnection {
_local_address: Option<SocketAddr>, _local_address: Option<SocketAddr>,
dial_info: &DialInfo, dial_info: &DialInfo,
timeout_ms: u32, timeout_ms: u32,
address_filter: AddressFilter, address_filter: &AddressFilter,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> { ) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
if address_filter.is_ip_addr_punished(dial_info.address().ip_addr()) { if address_filter.is_ip_addr_punished(dial_info.address().ip_addr()) {
return Ok(NetworkResult::no_connection_other("punished")); return Ok(NetworkResult::no_connection_other("punished"));

View File

@ -9,29 +9,6 @@ struct WebsocketNetworkConnectionInner {
ws_stream: CloneStream<WsStream>, ws_stream: CloneStream<WsStream>,
} }
fn to_io(err: WsErr) -> io::Error {
match err {
WsErr::InvalidWsState { supplied: _ } => {
io::Error::new(io::ErrorKind::InvalidInput, err.to_string())
}
WsErr::ConnectionNotOpen => io::Error::new(io::ErrorKind::NotConnected, err.to_string()),
WsErr::InvalidUrl { supplied: _ } => {
io::Error::new(io::ErrorKind::InvalidInput, err.to_string())
}
WsErr::InvalidCloseCode { supplied: _ } => {
io::Error::new(io::ErrorKind::InvalidInput, err.to_string())
}
WsErr::ReasonStringToLong => io::Error::new(io::ErrorKind::InvalidInput, err.to_string()),
WsErr::ConnectionFailed { event: _ } => {
io::Error::new(io::ErrorKind::ConnectionRefused, err.to_string())
}
WsErr::InvalidEncoding => io::Error::new(io::ErrorKind::InvalidInput, err.to_string()),
WsErr::CantDecodeBlob => io::Error::new(io::ErrorKind::InvalidInput, err.to_string()),
WsErr::UnknownDataType => io::Error::new(io::ErrorKind::InvalidInput, err.to_string()),
_ => io::Error::new(io::ErrorKind::Other, err.to_string()),
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct WebsocketNetworkConnection { pub struct WebsocketNetworkConnection {
flow: Flow, flow: Flow,
@ -65,7 +42,7 @@ impl WebsocketNetworkConnection {
)] )]
pub async fn close(&self) -> io::Result<NetworkResult<()>> { pub async fn close(&self) -> io::Result<NetworkResult<()>> {
#[allow(unused_variables)] #[allow(unused_variables)]
let x = self.inner.ws_meta.close().await.map_err(to_io); let x = self.inner.ws_meta.close().await.map_err(ws_err_to_io_error);
#[cfg(feature = "verbose-tracing")] #[cfg(feature = "verbose-tracing")]
log_net!(debug "close result: {:?}", x); log_net!(debug "close result: {:?}", x);
Ok(NetworkResult::value(())) Ok(NetworkResult::value(()))
@ -83,7 +60,7 @@ impl WebsocketNetworkConnection {
.send(WsMessage::Binary(message)), .send(WsMessage::Binary(message)),
) )
.await .await
.map_err(to_io) .map_err(ws_err_to_io_error)
.into_network_result()?; .into_network_result()?;
#[cfg(feature = "verbose-tracing")] #[cfg(feature = "verbose-tracing")]
@ -140,7 +117,9 @@ impl WebsocketProtocolHandler {
} }
let fut = SendWrapper::new(timeout(timeout_ms, async move { let fut = SendWrapper::new(timeout(timeout_ms, async move {
WsMeta::connect(request, None).await.map_err(to_io) WsMeta::connect(request, None)
.await
.map_err(ws_err_to_io_error)
})); }));
let (wsmeta, wsio) = network_result_try!(network_result_try!(fut let (wsmeta, wsio) = network_result_try!(network_result_try!(fut

View File

@ -640,7 +640,7 @@ impl BucketEntryInner {
only_live: bool, only_live: bool,
filter: NodeRefFilter, filter: NodeRefFilter,
) -> Vec<(Flow, Timestamp)> { ) -> Vec<(Flow, Timestamp)> {
let opt_connection_manager = rti.unlocked_inner.network_manager.opt_connection_manager(); let opt_connection_manager = rti.network_manager().opt_connection_manager();
let mut out: Vec<(Flow, Timestamp)> = self let mut out: Vec<(Flow, Timestamp)> = self
.last_flows .last_flows

View File

@ -35,7 +35,6 @@ impl RoutingTable {
let valid_envelope_versions = VALID_ENVELOPE_VERSIONS.map(|x| x.to_string()).join(","); let valid_envelope_versions = VALID_ENVELOPE_VERSIONS.map(|x| x.to_string()).join(",");
let node_ids = self let node_ids = self
.unlocked_inner
.node_ids() .node_ids()
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
@ -57,7 +56,7 @@ impl RoutingTable {
pub fn debug_info_nodeid(&self) -> String { pub fn debug_info_nodeid(&self) -> String {
let mut out = String::new(); let mut out = String::new();
for nid in self.unlocked_inner.node_ids().iter() { for nid in self.node_ids().iter() {
out += &format!("{}\n", nid); out += &format!("{}\n", nid);
} }
out out
@ -66,7 +65,7 @@ impl RoutingTable {
pub fn debug_info_nodeinfo(&self) -> String { pub fn debug_info_nodeinfo(&self) -> String {
let mut out = String::new(); let mut out = String::new();
let inner = self.inner.read(); let inner = self.inner.read();
out += &format!("Node Ids: {}\n", self.unlocked_inner.node_ids()); out += &format!("Node Ids: {}\n", self.node_ids());
out += &format!( out += &format!(
"Self Transfer Stats:\n{}", "Self Transfer Stats:\n{}",
indent_all_string(&inner.self_transfer_stats) indent_all_string(&inner.self_transfer_stats)
@ -250,7 +249,7 @@ impl RoutingTable {
out += &format!("{:?}: {}: {}\n", routing_domain, crypto_kind, count); out += &format!("{:?}: {}: {}\n", routing_domain, crypto_kind, count);
} }
for ck in &VALID_CRYPTO_KINDS { for ck in &VALID_CRYPTO_KINDS {
let our_node_id = self.unlocked_inner.node_id(*ck); let our_node_id = self.node_id(*ck);
let mut filtered_total = 0; let mut filtered_total = 0;
let mut b = 0; let mut b = 0;
@ -319,7 +318,7 @@ impl RoutingTable {
) -> String { ) -> String {
let cur_ts = Timestamp::now(); let cur_ts = Timestamp::now();
let relay_node_filter = self.make_public_internet_relay_node_filter(); let relay_node_filter = self.make_public_internet_relay_node_filter();
let our_node_ids = self.unlocked_inner.node_ids(); let our_node_ids = self.node_ids();
let mut relay_count = 0usize; let mut relay_count = 0usize;
let mut relaying_count = 0usize; let mut relaying_count = 0usize;
@ -340,7 +339,7 @@ impl RoutingTable {
node_count, node_count,
filters, filters,
|_rti, entry: Option<Arc<BucketEntry>>| { |_rti, entry: Option<Arc<BucketEntry>>| {
NodeRef::new(self.clone(), entry.unwrap().clone()) NodeRef::new(self.registry(), entry.unwrap().clone())
}, },
); );
let mut out = String::new(); let mut out = String::new();
@ -376,9 +375,10 @@ impl RoutingTable {
relaying_count += 1; relaying_count += 1;
} }
let best_node_id = node.best_node_id();
out += " "; out += " ";
out += &node out += &node.operate(|_rti, e| Self::format_entry(cur_ts, best_node_id, e, &relay_tag));
.operate(|_rti, e| Self::format_entry(cur_ts, node.best_node_id(), e, &relay_tag));
out += "\n"; out += "\n";
} }

View File

@ -42,10 +42,9 @@ impl RoutingTable {
) as RoutingTableEntryFilter; ) as RoutingTableEntryFilter;
let filters = VecDeque::from([filter]); let filters = VecDeque::from([filter]);
let node_count = { let node_count = self
let c = self.config.get(); .config()
c.network.dht.max_find_node_count as usize .with(|c| c.network.dht.max_find_node_count as usize);
};
let closest_nodes = match self.find_preferred_closest_nodes( let closest_nodes = match self.find_preferred_closest_nodes(
node_count, node_count,
@ -82,11 +81,13 @@ impl RoutingTable {
// find N nodes closest to the target node in our routing table // find N nodes closest to the target node in our routing table
// ensure the nodes returned are only the ones closer to the target node than ourself // ensure the nodes returned are only the ones closer to the target node than ourself
let Some(vcrypto) = self.crypto().get(crypto_kind) else { let crypto = self.crypto();
let Some(vcrypto) = crypto.get(crypto_kind) else {
return NetworkResult::invalid_message("unsupported cryptosystem"); return NetworkResult::invalid_message("unsupported cryptosystem");
}; };
let vcrypto = &vcrypto;
let own_distance = vcrypto.distance(&own_node_id.value, &key.value); let own_distance = vcrypto.distance(&own_node_id.value, &key.value);
let vcrypto2 = vcrypto.clone();
let filter = Box::new( let filter = Box::new(
move |rti: &RoutingTableInner, opt_entry: Option<Arc<BucketEntry>>| { move |rti: &RoutingTableInner, opt_entry: Option<Arc<BucketEntry>>| {
@ -121,10 +122,9 @@ impl RoutingTable {
) as RoutingTableEntryFilter; ) as RoutingTableEntryFilter;
let filters = VecDeque::from([filter]); let filters = VecDeque::from([filter]);
let node_count = { let node_count = self
let c = self.config.get(); .config()
c.network.dht.max_find_node_count as usize .with(|c| c.network.dht.max_find_node_count as usize);
};
// //
let closest_nodes = match self.find_preferred_closest_nodes( let closest_nodes = match self.find_preferred_closest_nodes(
@ -147,7 +147,7 @@ impl RoutingTable {
// Validate peers returned are, in fact, closer to the key than the node we sent this to // Validate peers returned are, in fact, closer to the key than the node we sent this to
// This same test is used on the other side so we vet things here // This same test is used on the other side so we vet things here
let valid = match Self::verify_peers_closer(vcrypto2, own_node_id, key, &closest_nodes) { let valid = match Self::verify_peers_closer(vcrypto, own_node_id, key, &closest_nodes) {
Ok(v) => v, Ok(v) => v,
Err(e) => { Err(e) => {
panic!("missing cryptosystem in peers node ids: {}", e); panic!("missing cryptosystem in peers node ids: {}", e);
@ -166,7 +166,7 @@ impl RoutingTable {
/// Determine if set of peers is closer to key_near than key_far is to key_near /// Determine if set of peers is closer to key_near than key_far is to key_near
#[instrument(level = "trace", target = "rtab", skip_all, err)] #[instrument(level = "trace", target = "rtab", skip_all, err)]
pub fn verify_peers_closer( pub fn verify_peers_closer(
vcrypto: CryptoSystemVersion, vcrypto: &crypto::CryptoSystemGuard<'_>,
key_far: TypedKey, key_far: TypedKey,
key_near: TypedKey, key_near: TypedKey,
peers: &[Arc<PeerInfo>], peers: &[Arc<PeerInfo>],

View File

@ -91,16 +91,12 @@ pub struct RecentPeersEntry {
pub last_connection: Flow, pub last_connection: Flow,
} }
pub(crate) struct RoutingTableUnlockedInner { pub(crate) struct RoutingTable {
// Accessors registry: VeilidComponentRegistry,
event_bus: EventBus, inner: RwLock<RoutingTableInner>,
config: VeilidConfig,
network_manager: NetworkManager,
/// The current node's public DHT keys /// Route spec store
node_id: TypedKeyGroup, route_spec_store: RouteSpecStore,
/// The current node's public DHT secrets
node_id_secret: TypedSecretGroup,
/// Buckets to kick on our next kick task /// Buckets to kick on our next kick task
kick_queue: Mutex<BTreeSet<BucketIndex>>, kick_queue: Mutex<BTreeSet<BucketIndex>>,
/// Background process for computing statistics /// Background process for computing statistics
@ -131,103 +127,27 @@ pub(crate) struct RoutingTableUnlockedInner {
private_route_management_task: TickTask<EyreReport>, private_route_management_task: TickTask<EyreReport>,
} }
impl RoutingTableUnlockedInner { impl fmt::Debug for RoutingTable {
pub fn network_manager(&self) -> NetworkManager { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.network_manager.clone() f.debug_struct("RoutingTable")
} // .field("inner", &self.inner)
pub fn crypto(&self) -> Crypto { // .field("unlocked_inner", &self.unlocked_inner)
self.network_manager().crypto() .finish()
}
pub fn rpc_processor(&self) -> RPCProcessor {
self.network_manager().rpc_processor()
}
pub fn update_callback(&self) -> UpdateCallback {
self.network_manager().update_callback()
}
pub fn with_config<F, R>(&self, f: F) -> R
where
F: FnOnce(&VeilidConfigInner) -> R,
{
f(&self.config.get())
}
pub fn node_id(&self, kind: CryptoKind) -> TypedKey {
self.node_id.get(kind).unwrap()
}
pub fn node_id_secret_key(&self, kind: CryptoKind) -> SecretKey {
self.node_id_secret.get(kind).unwrap().value
}
pub fn node_ids(&self) -> TypedKeyGroup {
self.node_id.clone()
}
pub fn node_id_typed_key_pairs(&self) -> Vec<TypedKeyPair> {
let mut tkps = Vec::new();
for ck in VALID_CRYPTO_KINDS {
tkps.push(TypedKeyPair::new(
ck,
KeyPair::new(self.node_id(ck).value, self.node_id_secret_key(ck)),
));
}
tkps
}
pub fn matches_own_node_id(&self, node_ids: &[TypedKey]) -> bool {
for ni in node_ids {
if let Some(v) = self.node_id.get(ni.kind) {
if v.value == ni.value {
return true;
}
}
}
false
}
pub fn matches_own_node_id_key(&self, node_id_key: &PublicKey) -> bool {
for tk in self.node_id.iter() {
if tk.value == *node_id_key {
return true;
}
}
false
}
pub fn calculate_bucket_index(&self, node_id: &TypedKey) -> BucketIndex {
let crypto = self.crypto();
let self_node_id_key = self.node_id(node_id.kind).value;
let vcrypto = crypto.get(node_id.kind).unwrap();
(
node_id.kind,
vcrypto
.distance(&node_id.value, &self_node_id_key)
.first_nonzero_bit()
.unwrap(),
)
} }
} }
#[derive(Clone)] impl_veilid_component!(RoutingTable);
pub(crate) struct RoutingTable {
inner: Arc<RwLock<RoutingTableInner>>,
unlocked_inner: Arc<RoutingTableUnlockedInner>,
}
impl RoutingTable { impl RoutingTable {
fn new_unlocked_inner( pub fn new(registry: VeilidComponentRegistry) -> Self {
event_bus: EventBus, let config = registry.config();
config: VeilidConfig,
network_manager: NetworkManager,
) -> RoutingTableUnlockedInner {
let c = config.get(); let c = config.get();
let inner = RwLock::new(RoutingTableInner::new(registry.clone()));
RoutingTableUnlockedInner { let route_spec_store = RouteSpecStore::new(registry.clone());
event_bus, let this = Self {
config: config.clone(), registry,
network_manager, inner,
node_id: c.network.routing_table.node_id.clone(), route_spec_store,
node_id_secret: c.network.routing_table.node_id_secret.clone(),
kick_queue: Mutex::new(BTreeSet::default()), kick_queue: Mutex::new(BTreeSet::default()),
rolling_transfers_task: TickTask::new( rolling_transfers_task: TickTask::new(
"rolling_transfers_task", "rolling_transfers_task",
@ -269,16 +189,6 @@ impl RoutingTable {
"private_route_management_task", "private_route_management_task",
PRIVATE_ROUTE_MANAGEMENT_INTERVAL_SECS, PRIVATE_ROUTE_MANAGEMENT_INTERVAL_SECS,
), ),
}
}
pub fn new(network_manager: NetworkManager) -> Self {
let event_bus = network_manager.event_bus();
let config = network_manager.config();
let unlocked_inner = Arc::new(Self::new_unlocked_inner(event_bus, config, network_manager));
let inner = Arc::new(RwLock::new(RoutingTableInner::new(unlocked_inner.clone())));
let this = Self {
inner,
unlocked_inner,
}; };
this.setup_tasks(); this.setup_tasks();
@ -290,7 +200,7 @@ impl RoutingTable {
/// Initialization /// Initialization
/// Called to initialize the routing table after it is created /// Called to initialize the routing table after it is created
pub async fn init(&self) -> EyreResult<()> { async fn init_async(&self) -> EyreResult<()> {
log_rtab!(debug "starting routing table init"); log_rtab!(debug "starting routing table init");
// Set up routing buckets // Set up routing buckets
@ -309,42 +219,35 @@ impl RoutingTable {
// Set up routespecstore // Set up routespecstore
log_rtab!(debug "starting route spec store init"); log_rtab!(debug "starting route spec store init");
let route_spec_store = match RouteSpecStore::load(self.clone()).await { if let Err(e) = self.route_spec_store().load().await {
Ok(v) => v, log_rtab!(debug "Error loading route spec store: {:#?}. Resetting.", e);
Err(e) => { self.route_spec_store().reset();
log_rtab!(debug "Error loading route spec store: {:#?}. Resetting.", e);
RouteSpecStore::new(self.clone())
}
}; };
log_rtab!(debug "finished route spec store init"); log_rtab!(debug "finished route spec store init");
{
let mut inner = self.inner.write();
inner.route_spec_store = Some(route_spec_store);
}
// Inform storage manager we are up
self.network_manager
.storage_manager()
.set_routing_table(Some(self.clone()))
.await;
log_rtab!(debug "finished routing table init"); log_rtab!(debug "finished routing table init");
Ok(()) Ok(())
} }
/// Called to shut down the routing table async fn post_init_async(&self) -> EyreResult<()> {
pub async fn terminate(&self) { Ok(())
log_rtab!(debug "starting routing table terminate"); }
// Stop storage manager from using us pub(crate) async fn startup(&self) -> EyreResult<()> {
self.network_manager Ok(())
.storage_manager() }
.set_routing_table(None)
.await;
pub(crate) async fn shutdown(&self) {
// Stop tasks // Stop tasks
log_net!(debug "stopping routing table tasks");
self.cancel_tasks().await; self.cancel_tasks().await;
}
async fn pre_terminate_async(&self) {}
/// Called to shut down the routing table
async fn terminate_async(&self) {
log_rtab!(debug "starting routing table terminate");
// Load bucket entries from table db if possible // Load bucket entries from table db if possible
log_rtab!(debug "saving routing table entries"); log_rtab!(debug "saving routing table entries");
@ -365,11 +268,73 @@ impl RoutingTable {
log_rtab!(debug "shutting down routing table"); log_rtab!(debug "shutting down routing table");
let mut inner = self.inner.write(); let mut inner = self.inner.write();
*inner = RoutingTableInner::new(self.unlocked_inner.clone()); *inner = RoutingTableInner::new(self.registry());
log_rtab!(debug "finished routing table terminate"); log_rtab!(debug "finished routing table terminate");
} }
///////////////////////////////////////////////////////////////////
pub fn node_id(&self, kind: CryptoKind) -> TypedKey {
self.config()
.with(|c| c.network.routing_table.node_id.get(kind).unwrap())
}
pub fn node_id_secret_key(&self, kind: CryptoKind) -> SecretKey {
self.config()
.with(|c| c.network.routing_table.node_id_secret.get(kind).unwrap())
.value
}
pub fn node_ids(&self) -> TypedKeyGroup {
self.config()
.with(|c| c.network.routing_table.node_id.clone())
}
pub fn node_id_typed_key_pairs(&self) -> Vec<TypedKeyPair> {
let mut tkps = Vec::new();
for ck in VALID_CRYPTO_KINDS {
tkps.push(TypedKeyPair::new(
ck,
KeyPair::new(self.node_id(ck).value, self.node_id_secret_key(ck)),
));
}
tkps
}
pub fn matches_own_node_id(&self, node_ids: &[TypedKey]) -> bool {
for ni in node_ids {
if let Some(v) = self.node_ids().get(ni.kind) {
if v.value == ni.value {
return true;
}
}
}
false
}
pub fn matches_own_node_id_key(&self, node_id_key: &PublicKey) -> bool {
for tk in self.node_ids().iter() {
if tk.value == *node_id_key {
return true;
}
}
false
}
pub fn calculate_bucket_index(&self, node_id: &TypedKey) -> BucketIndex {
let crypto = self.crypto();
let self_node_id_key = self.node_id(node_id.kind).value;
let vcrypto = crypto.get(node_id.kind).unwrap();
(
node_id.kind,
vcrypto
.distance(&node_id.value, &self_node_id_key)
.first_nonzero_bit()
.unwrap(),
)
}
/// Serialize the routing table. /// Serialize the routing table.
fn serialized_buckets(&self) -> (SerializedBucketMap, SerializedBuckets) { fn serialized_buckets(&self) -> (SerializedBucketMap, SerializedBuckets) {
// Since entries are shared by multiple buckets per cryptokind // Since entries are shared by multiple buckets per cryptokind
@ -406,7 +371,7 @@ impl RoutingTable {
async fn save_buckets(&self) -> EyreResult<()> { async fn save_buckets(&self) -> EyreResult<()> {
let (serialized_bucket_map, all_entry_bytes) = self.serialized_buckets(); let (serialized_bucket_map, all_entry_bytes) = self.serialized_buckets();
let table_store = self.unlocked_inner.network_manager().table_store(); let table_store = self.table_store();
let tdb = table_store.open(ROUTING_TABLE, 1).await?; let tdb = table_store.open(ROUTING_TABLE, 1).await?;
let dbx = tdb.transact(); let dbx = tdb.transact();
if let Err(e) = dbx.store_json(0, SERIALIZED_BUCKET_MAP, &serialized_bucket_map) { if let Err(e) = dbx.store_json(0, SERIALIZED_BUCKET_MAP, &serialized_bucket_map) {
@ -420,12 +385,14 @@ impl RoutingTable {
dbx.commit().await?; dbx.commit().await?;
Ok(()) Ok(())
} }
/// Deserialize routing table from table store /// Deserialize routing table from table store
async fn load_buckets(&self) -> EyreResult<()> { async fn load_buckets(&self) -> EyreResult<()> {
// Make a cache validity key of all our node ids and our bootstrap choice // Make a cache validity key of all our node ids and our bootstrap choice
let mut cache_validity_key: Vec<u8> = Vec::new(); let mut cache_validity_key: Vec<u8> = Vec::new();
{ {
let c = self.unlocked_inner.config.get(); let config = self.config();
let c = config.get();
for ck in VALID_CRYPTO_KINDS { for ck in VALID_CRYPTO_KINDS {
if let Some(nid) = c.network.routing_table.node_id.get(ck) { if let Some(nid) = c.network.routing_table.node_id.get(ck) {
cache_validity_key.append(&mut nid.value.bytes.to_vec()); cache_validity_key.append(&mut nid.value.bytes.to_vec());
@ -446,7 +413,7 @@ impl RoutingTable {
}; };
// Deserialize bucket map and all entries from the table store // Deserialize bucket map and all entries from the table store
let table_store = self.unlocked_inner.network_manager().table_store(); let table_store = self.table_store();
let db = table_store.open(ROUTING_TABLE, 1).await?; let db = table_store.open(ROUTING_TABLE, 1).await?;
let caches_valid = match db.load(0, CACHE_VALIDITY_KEY).await? { let caches_valid = match db.load(0, CACHE_VALIDITY_KEY).await? {
@ -479,14 +446,13 @@ impl RoutingTable {
// Reconstruct all entries // Reconstruct all entries
let inner = &mut *self.inner.write(); let inner = &mut *self.inner.write();
self.populate_routing_table(inner, serialized_bucket_map, all_entry_bytes)?; Self::populate_routing_table_inner(inner, serialized_bucket_map, all_entry_bytes)?;
Ok(()) Ok(())
} }
/// Write the deserialized table store data to the routing table. /// Write the deserialized table store data to the routing table.
pub fn populate_routing_table( pub fn populate_routing_table_inner(
&self,
inner: &mut RoutingTableInner, inner: &mut RoutingTableInner,
serialized_bucket_map: SerializedBucketMap, serialized_bucket_map: SerializedBucketMap,
all_entry_bytes: SerializedBuckets, all_entry_bytes: SerializedBuckets,
@ -542,8 +508,8 @@ impl RoutingTable {
self.inner.read().routing_domain_for_address(address) self.inner.read().routing_domain_for_address(address)
} }
pub fn route_spec_store(&self) -> RouteSpecStore { pub fn route_spec_store(&self) -> &RouteSpecStore {
self.inner.read().route_spec_store.as_ref().unwrap().clone() &self.route_spec_store
} }
pub fn relay_node(&self, domain: RoutingDomain) -> Option<FilteredNodeRef> { pub fn relay_node(&self, domain: RoutingDomain) -> Option<FilteredNodeRef> {
@ -600,12 +566,12 @@ impl RoutingTable {
/// Edit the PublicInternet RoutingDomain /// Edit the PublicInternet RoutingDomain
pub fn edit_public_internet_routing_domain(&self) -> RoutingDomainEditorPublicInternet { pub fn edit_public_internet_routing_domain(&self) -> RoutingDomainEditorPublicInternet {
RoutingDomainEditorPublicInternet::new(self.clone()) RoutingDomainEditorPublicInternet::new(self)
} }
/// Edit the LocalNetwork RoutingDomain /// Edit the LocalNetwork RoutingDomain
pub fn edit_local_network_routing_domain(&self) -> RoutingDomainEditorLocalNetwork { pub fn edit_local_network_routing_domain(&self) -> RoutingDomainEditorLocalNetwork {
RoutingDomainEditorLocalNetwork::new(self.clone()) RoutingDomainEditorLocalNetwork::new(self)
} }
/// Return a copy of our node's peerinfo (may not yet be published) /// Return a copy of our node's peerinfo (may not yet be published)
@ -619,7 +585,7 @@ impl RoutingTable {
} }
/// Return the domain's currently registered network class /// Return the domain's currently registered network class
#[cfg_attr(target_arch = "wasm32", expect(dead_code))] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub fn get_network_class(&self, routing_domain: RoutingDomain) -> Option<NetworkClass> { pub fn get_network_class(&self, routing_domain: RoutingDomain) -> Option<NetworkClass> {
self.inner.read().get_network_class(routing_domain) self.inner.read().get_network_class(routing_domain)
} }
@ -656,7 +622,7 @@ impl RoutingTable {
) -> Vec<FilteredNodeRef> { ) -> Vec<FilteredNodeRef> {
self.inner self.inner
.read() .read()
.get_nodes_needing_ping(self.clone(), routing_domain, cur_ts) .get_nodes_needing_ping(routing_domain, cur_ts)
} }
fn queue_bucket_kicks(&self, node_ids: TypedKeyGroup) { fn queue_bucket_kicks(&self, node_ids: TypedKeyGroup) {
@ -667,21 +633,19 @@ impl RoutingTable {
} }
// Put it in the kick queue // Put it in the kick queue
let x = self.unlocked_inner.calculate_bucket_index(node_id); let x = self.calculate_bucket_index(node_id);
self.unlocked_inner.kick_queue.lock().insert(x); self.kick_queue.lock().insert(x);
} }
} }
/// Resolve an existing routing table entry using any crypto kind and return a reference to it /// Resolve an existing routing table entry using any crypto kind and return a reference to it
pub fn lookup_any_node_ref(&self, node_id_key: PublicKey) -> EyreResult<Option<NodeRef>> { pub fn lookup_any_node_ref(&self, node_id_key: PublicKey) -> EyreResult<Option<NodeRef>> {
self.inner self.inner.read().lookup_any_node_ref(node_id_key)
.read()
.lookup_any_node_ref(self.clone(), node_id_key)
} }
/// Resolve an existing routing table entry and return a reference to it /// Resolve an existing routing table entry and return a reference to it
pub fn lookup_node_ref(&self, node_id: TypedKey) -> EyreResult<Option<NodeRef>> { pub fn lookup_node_ref(&self, node_id: TypedKey) -> EyreResult<Option<NodeRef>> {
self.inner.read().lookup_node_ref(self.clone(), node_id) self.inner.read().lookup_node_ref(node_id)
} }
/// Resolve an existing routing table entry and return a filtered reference to it /// Resolve an existing routing table entry and return a filtered reference to it
@ -692,12 +656,9 @@ impl RoutingTable {
routing_domain_set: RoutingDomainSet, routing_domain_set: RoutingDomainSet,
dial_info_filter: DialInfoFilter, dial_info_filter: DialInfoFilter,
) -> EyreResult<Option<FilteredNodeRef>> { ) -> EyreResult<Option<FilteredNodeRef>> {
self.inner.read().lookup_and_filter_noderef( self.inner
self.clone(), .read()
node_id, .lookup_and_filter_noderef(node_id, routing_domain_set, dial_info_filter)
routing_domain_set,
dial_info_filter,
)
} }
/// Shortcut function to add a node to our routing table if it doesn't exist /// Shortcut function to add a node to our routing table if it doesn't exist
@ -711,7 +672,7 @@ impl RoutingTable {
) -> EyreResult<FilteredNodeRef> { ) -> EyreResult<FilteredNodeRef> {
self.inner self.inner
.write() .write()
.register_node_with_peer_info(self.clone(), peer_info, allow_invalid) .register_node_with_peer_info(peer_info, allow_invalid)
} }
/// Shortcut function to add a node to our routing table if it doesn't exist /// Shortcut function to add a node to our routing table if it doesn't exist
@ -726,7 +687,7 @@ impl RoutingTable {
) -> EyreResult<FilteredNodeRef> { ) -> EyreResult<FilteredNodeRef> {
self.inner self.inner
.write() .write()
.register_node_with_id(self.clone(), routing_domain, node_id, timestamp) .register_node_with_id(routing_domain, node_id, timestamp)
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -824,7 +785,7 @@ impl RoutingTable {
} }
/// Makes a filter that finds nodes with a matching inbound dialinfo /// Makes a filter that finds nodes with a matching inbound dialinfo
#[cfg_attr(target_arch = "wasm32", expect(dead_code))] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub fn make_inbound_dial_info_entry_filter<'a>( pub fn make_inbound_dial_info_entry_filter<'a>(
routing_domain: RoutingDomain, routing_domain: RoutingDomain,
dial_info_filter: DialInfoFilter, dial_info_filter: DialInfoFilter,
@ -885,7 +846,7 @@ impl RoutingTable {
filters: VecDeque<RoutingTableEntryFilter>, filters: VecDeque<RoutingTableEntryFilter>,
) -> Vec<NodeRef> { ) -> Vec<NodeRef> {
self.inner.read().find_fast_non_local_nodes_filtered( self.inner.read().find_fast_non_local_nodes_filtered(
self.clone(), self.registry(),
routing_domain, routing_domain,
node_count, node_count,
filters, filters,
@ -971,7 +932,7 @@ impl RoutingTable {
protocol_types_len * 2 * max_per_type, protocol_types_len * 2 * max_per_type,
filters, filters,
|_rti, entry: Option<Arc<BucketEntry>>| { |_rti, entry: Option<Arc<BucketEntry>>| {
NodeRef::new(self.clone(), entry.unwrap().clone()) NodeRef::new(self.registry(), entry.unwrap().clone())
}, },
) )
} }
@ -1073,7 +1034,6 @@ impl RoutingTable {
let res = network_result_try!( let res = network_result_try!(
rpc_processor rpc_processor
.clone()
.rpc_call_find_node( .rpc_call_find_node(
Destination::direct(node_ref.default_filtered()), Destination::direct(node_ref.default_filtered()),
node_id, node_id,
@ -1162,11 +1122,3 @@ impl RoutingTable {
} }
} }
} }
impl core::ops::Deref for RoutingTable {
type Target = RoutingTableUnlockedInner;
fn deref(&self) -> &Self::Target {
&self.unlocked_inner
}
}

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
pub(crate) struct FilteredNodeRef { pub(crate) struct FilteredNodeRef {
routing_table: RoutingTable, registry: VeilidComponentRegistry,
entry: Arc<BucketEntry>, entry: Arc<BucketEntry>,
filter: NodeRefFilter, filter: NodeRefFilter,
sequencing: Sequencing, sequencing: Sequencing,
@ -9,9 +9,11 @@ pub(crate) struct FilteredNodeRef {
track_id: usize, track_id: usize,
} }
impl_veilid_component_registry_accessor!(FilteredNodeRef);
impl FilteredNodeRef { impl FilteredNodeRef {
pub fn new( pub fn new(
routing_table: RoutingTable, registry: VeilidComponentRegistry,
entry: Arc<BucketEntry>, entry: Arc<BucketEntry>,
filter: NodeRefFilter, filter: NodeRefFilter,
sequencing: Sequencing, sequencing: Sequencing,
@ -19,7 +21,7 @@ impl FilteredNodeRef {
entry.ref_count.fetch_add(1u32, Ordering::AcqRel); entry.ref_count.fetch_add(1u32, Ordering::AcqRel);
Self { Self {
routing_table, registry,
entry, entry,
filter, filter,
sequencing, sequencing,
@ -29,7 +31,7 @@ impl FilteredNodeRef {
} }
pub fn unfiltered(&self) -> NodeRef { pub fn unfiltered(&self) -> NodeRef {
NodeRef::new(self.routing_table.clone(), self.entry.clone()) NodeRef::new(self.registry(), self.entry.clone())
} }
pub fn filtered_clone(&self, filter: NodeRefFilter) -> FilteredNodeRef { pub fn filtered_clone(&self, filter: NodeRefFilter) -> FilteredNodeRef {
@ -40,7 +42,7 @@ impl FilteredNodeRef {
pub fn sequencing_clone(&self, sequencing: Sequencing) -> FilteredNodeRef { pub fn sequencing_clone(&self, sequencing: Sequencing) -> FilteredNodeRef {
FilteredNodeRef::new( FilteredNodeRef::new(
self.routing_table.clone(), self.registry.clone(),
self.entry.clone(), self.entry.clone(),
self.filter(), self.filter(),
sequencing, sequencing,
@ -70,9 +72,6 @@ impl FilteredNodeRef {
} }
impl NodeRefAccessorsTrait for FilteredNodeRef { impl NodeRefAccessorsTrait for FilteredNodeRef {
fn routing_table(&self) -> RoutingTable {
self.routing_table.clone()
}
fn entry(&self) -> Arc<BucketEntry> { fn entry(&self) -> Arc<BucketEntry> {
self.entry.clone() self.entry.clone()
} }
@ -105,7 +104,8 @@ impl NodeRefOperateTrait for FilteredNodeRef {
where where
F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T, F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T,
{ {
let inner = &*self.routing_table.inner.read(); let routing_table = self.registry.routing_table();
let inner = &*routing_table.inner.read();
self.entry.with(inner, f) self.entry.with(inner, f)
} }
@ -113,7 +113,8 @@ impl NodeRefOperateTrait for FilteredNodeRef {
where where
F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T, F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T,
{ {
let inner = &mut *self.routing_table.inner.write(); let routing_table = self.registry.routing_table();
let inner = &mut *routing_table.inner.write();
self.entry.with_mut(inner, f) self.entry.with_mut(inner, f)
} }
} }
@ -125,7 +126,7 @@ impl Clone for FilteredNodeRef {
self.entry.ref_count.fetch_add(1u32, Ordering::AcqRel); self.entry.ref_count.fetch_add(1u32, Ordering::AcqRel);
Self { Self {
routing_table: self.routing_table.clone(), registry: self.registry.clone(),
entry: self.entry.clone(), entry: self.entry.clone(),
filter: self.filter, filter: self.filter,
sequencing: self.sequencing, sequencing: self.sequencing,
@ -162,7 +163,7 @@ impl Drop for FilteredNodeRef {
// get node ids with inner unlocked because nothing could be referencing this entry now // get node ids with inner unlocked because nothing could be referencing this entry now
// and we don't know when it will get dropped, possibly inside a lock // and we don't know when it will get dropped, possibly inside a lock
let node_ids = self.entry.with_inner(|e| e.node_ids()); let node_ids = self.entry.with_inner(|e| e.node_ids());
self.routing_table.queue_bucket_kicks(node_ids); self.routing_table().queue_bucket_kicks(node_ids);
} }
} }
} }

View File

@ -16,18 +16,20 @@ pub(crate) use traits::*;
// Default NodeRef // Default NodeRef
pub(crate) struct NodeRef { pub(crate) struct NodeRef {
routing_table: RoutingTable, registry: VeilidComponentRegistry,
entry: Arc<BucketEntry>, entry: Arc<BucketEntry>,
#[cfg(feature = "tracking")] #[cfg(feature = "tracking")]
track_id: usize, track_id: usize,
} }
impl_veilid_component_registry_accessor!(NodeRef);
impl NodeRef { impl NodeRef {
pub fn new(routing_table: RoutingTable, entry: Arc<BucketEntry>) -> Self { pub fn new(registry: VeilidComponentRegistry, entry: Arc<BucketEntry>) -> Self {
entry.ref_count.fetch_add(1u32, Ordering::AcqRel); entry.ref_count.fetch_add(1u32, Ordering::AcqRel);
Self { Self {
routing_table, registry,
entry, entry,
#[cfg(feature = "tracking")] #[cfg(feature = "tracking")]
track_id: entry.track(), track_id: entry.track(),
@ -36,7 +38,7 @@ impl NodeRef {
pub fn default_filtered(&self) -> FilteredNodeRef { pub fn default_filtered(&self) -> FilteredNodeRef {
FilteredNodeRef::new( FilteredNodeRef::new(
self.routing_table.clone(), self.registry.clone(),
self.entry.clone(), self.entry.clone(),
NodeRefFilter::new(), NodeRefFilter::new(),
Sequencing::default(), Sequencing::default(),
@ -45,7 +47,7 @@ impl NodeRef {
pub fn sequencing_filtered(&self, sequencing: Sequencing) -> FilteredNodeRef { pub fn sequencing_filtered(&self, sequencing: Sequencing) -> FilteredNodeRef {
FilteredNodeRef::new( FilteredNodeRef::new(
self.routing_table.clone(), self.registry.clone(),
self.entry.clone(), self.entry.clone(),
NodeRefFilter::new(), NodeRefFilter::new(),
sequencing, sequencing,
@ -57,7 +59,7 @@ impl NodeRef {
routing_domain_set: R, routing_domain_set: R,
) -> FilteredNodeRef { ) -> FilteredNodeRef {
FilteredNodeRef::new( FilteredNodeRef::new(
self.routing_table.clone(), self.registry.clone(),
self.entry.clone(), self.entry.clone(),
NodeRefFilter::new().with_routing_domain_set(routing_domain_set.into()), NodeRefFilter::new().with_routing_domain_set(routing_domain_set.into()),
Sequencing::default(), Sequencing::default(),
@ -66,7 +68,7 @@ impl NodeRef {
pub fn custom_filtered(&self, filter: NodeRefFilter) -> FilteredNodeRef { pub fn custom_filtered(&self, filter: NodeRefFilter) -> FilteredNodeRef {
FilteredNodeRef::new( FilteredNodeRef::new(
self.routing_table.clone(), self.registry.clone(),
self.entry.clone(), self.entry.clone(),
filter, filter,
Sequencing::default(), Sequencing::default(),
@ -76,7 +78,7 @@ impl NodeRef {
#[expect(dead_code)] #[expect(dead_code)]
pub fn dial_info_filtered(&self, filter: DialInfoFilter) -> FilteredNodeRef { pub fn dial_info_filtered(&self, filter: DialInfoFilter) -> FilteredNodeRef {
FilteredNodeRef::new( FilteredNodeRef::new(
self.routing_table.clone(), self.registry.clone(),
self.entry.clone(), self.entry.clone(),
NodeRefFilter::new().with_dial_info_filter(filter), NodeRefFilter::new().with_dial_info_filter(filter),
Sequencing::default(), Sequencing::default(),
@ -92,9 +94,6 @@ impl NodeRef {
} }
impl NodeRefAccessorsTrait for NodeRef { impl NodeRefAccessorsTrait for NodeRef {
fn routing_table(&self) -> RoutingTable {
self.routing_table.clone()
}
fn entry(&self) -> Arc<BucketEntry> { fn entry(&self) -> Arc<BucketEntry> {
self.entry.clone() self.entry.clone()
} }
@ -125,7 +124,8 @@ impl NodeRefOperateTrait for NodeRef {
where where
F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T, F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T,
{ {
let inner = &*self.routing_table.inner.read(); let routing_table = self.routing_table();
let inner = &*routing_table.inner.read();
self.entry.with(inner, f) self.entry.with(inner, f)
} }
@ -133,7 +133,8 @@ impl NodeRefOperateTrait for NodeRef {
where where
F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T, F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T,
{ {
let inner = &mut *self.routing_table.inner.write(); let routing_table = self.routing_table();
let inner = &mut *routing_table.inner.write();
self.entry.with_mut(inner, f) self.entry.with_mut(inner, f)
} }
} }
@ -145,7 +146,7 @@ impl Clone for NodeRef {
self.entry.ref_count.fetch_add(1u32, Ordering::AcqRel); self.entry.ref_count.fetch_add(1u32, Ordering::AcqRel);
Self { Self {
routing_table: self.routing_table.clone(), registry: self.registry.clone(),
entry: self.entry.clone(), entry: self.entry.clone(),
#[cfg(feature = "tracking")] #[cfg(feature = "tracking")]
track_id: self.entry.write().track(), track_id: self.entry.write().track(),
@ -178,7 +179,7 @@ impl Drop for NodeRef {
// get node ids with inner unlocked because nothing could be referencing this entry now // get node ids with inner unlocked because nothing could be referencing this entry now
// and we don't know when it will get dropped, possibly inside a lock // and we don't know when it will get dropped, possibly inside a lock
let node_ids = self.entry.with_inner(|e| e.node_ids()); let node_ids = self.entry.with_inner(|e| e.node_ids());
self.routing_table.queue_bucket_kicks(node_ids); self.routing_table().queue_bucket_kicks(node_ids);
} }
} }
} }

View File

@ -15,6 +15,21 @@ pub(crate) struct NodeRefLock<
nr: N, nr: N,
} }
impl<
'a,
N: NodeRefAccessorsTrait
+ NodeRefOperateTrait
+ VeilidComponentRegistryAccessor
+ fmt::Debug
+ fmt::Display
+ Clone,
> VeilidComponentRegistryAccessor for NodeRefLock<'a, N>
{
fn registry(&self) -> VeilidComponentRegistry {
self.nr.registry()
}
}
impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone> impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone>
NodeRefLock<'a, N> NodeRefLock<'a, N>
{ {
@ -33,9 +48,6 @@ impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Disp
impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone> impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone>
NodeRefAccessorsTrait for NodeRefLock<'a, N> NodeRefAccessorsTrait for NodeRefLock<'a, N>
{ {
fn routing_table(&self) -> RoutingTable {
self.nr.routing_table()
}
fn entry(&self) -> Arc<BucketEntry> { fn entry(&self) -> Arc<BucketEntry> {
self.nr.entry() self.nr.entry()
} }

View File

@ -15,6 +15,21 @@ pub(crate) struct NodeRefLockMut<
nr: N, nr: N,
} }
impl<
'a,
N: NodeRefAccessorsTrait
+ NodeRefOperateTrait
+ VeilidComponentRegistryAccessor
+ fmt::Debug
+ fmt::Display
+ Clone,
> VeilidComponentRegistryAccessor for NodeRefLockMut<'a, N>
{
fn registry(&self) -> VeilidComponentRegistry {
self.nr.registry()
}
}
impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone> impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone>
NodeRefLockMut<'a, N> NodeRefLockMut<'a, N>
{ {
@ -34,9 +49,6 @@ impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Disp
impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone> impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone>
NodeRefAccessorsTrait for NodeRefLockMut<'a, N> NodeRefAccessorsTrait for NodeRefLockMut<'a, N>
{ {
fn routing_table(&self) -> RoutingTable {
self.nr.routing_table()
}
fn entry(&self) -> Arc<BucketEntry> { fn entry(&self) -> Arc<BucketEntry> {
self.nr.entry() self.nr.entry()
} }

View File

@ -2,7 +2,6 @@ use super::*;
// Field accessors // Field accessors
pub(crate) trait NodeRefAccessorsTrait { pub(crate) trait NodeRefAccessorsTrait {
fn routing_table(&self) -> RoutingTable;
fn entry(&self) -> Arc<BucketEntry>; fn entry(&self) -> Arc<BucketEntry>;
fn sequencing(&self) -> Sequencing; fn sequencing(&self) -> Sequencing;
fn routing_domain_set(&self) -> RoutingDomainSet; fn routing_domain_set(&self) -> RoutingDomainSet;
@ -125,12 +124,12 @@ pub(crate) trait NodeRefCommonTrait: NodeRefAccessorsTrait + NodeRefOperateTrait
}; };
// If relay is ourselves, then return None, because we can't relay through ourselves // If relay is ourselves, then return None, because we can't relay through ourselves
// and to contact this node we should have had an existing inbound connection // and to contact this node we should have had an existing inbound connection
if rti.unlocked_inner.matches_own_node_id(rpi.node_ids()) { if rti.routing_table().matches_own_node_id(rpi.node_ids()) {
bail!("Can't relay though ourselves"); bail!("Can't relay though ourselves");
} }
// Register relay node and return noderef // Register relay node and return noderef
let nr = rti.register_node_with_peer_info(self.routing_table(), rpi, false)?; let nr = rti.register_node_with_peer_info(rpi, false)?;
Ok(Some(nr)) Ok(Some(nr))
}) })
} }
@ -253,7 +252,7 @@ pub(crate) trait NodeRefCommonTrait: NodeRefAccessorsTrait + NodeRefOperateTrait
else { else {
return false; return false;
}; };
let our_node_ids = rti.unlocked_inner.node_ids(); let our_node_ids = rti.routing_table().node_ids();
our_node_ids.contains_any(&relay_ids) our_node_ids.contains_any(&relay_ids)
}) })
} }

View File

@ -31,7 +31,7 @@ pub(crate) enum RouteNode {
} }
impl RouteNode { impl RouteNode {
pub fn validate(&self, crypto: Crypto) -> VeilidAPIResult<()> { pub fn validate(&self, crypto: &Crypto) -> VeilidAPIResult<()> {
match self { match self {
RouteNode::NodeId(_) => Ok(()), RouteNode::NodeId(_) => Ok(()),
RouteNode::PeerInfo(pi) => pi.validate(crypto), RouteNode::PeerInfo(pi) => pi.validate(crypto),
@ -40,7 +40,7 @@ impl RouteNode {
pub fn node_ref( pub fn node_ref(
&self, &self,
routing_table: RoutingTable, routing_table: &RoutingTable,
crypto_kind: CryptoKind, crypto_kind: CryptoKind,
) -> Option<NodeRef> { ) -> Option<NodeRef> {
match self { match self {
@ -91,7 +91,7 @@ pub(crate) struct RouteHop {
pub next_hop: Option<RouteHopData>, pub next_hop: Option<RouteHopData>,
} }
impl RouteHop { impl RouteHop {
pub fn validate(&self, crypto: Crypto) -> VeilidAPIResult<()> { pub fn validate(&self, crypto: &Crypto) -> VeilidAPIResult<()> {
self.node.validate(crypto) self.node.validate(crypto)
} }
} }
@ -108,7 +108,7 @@ pub(crate) enum PrivateRouteHops {
} }
impl PrivateRouteHops { impl PrivateRouteHops {
pub fn validate(&self, crypto: Crypto) -> VeilidAPIResult<()> { pub fn validate(&self, crypto: &Crypto) -> VeilidAPIResult<()> {
match self { match self {
PrivateRouteHops::FirstHop(rh) => rh.validate(crypto), PrivateRouteHops::FirstHop(rh) => rh.validate(crypto),
PrivateRouteHops::Data(_) => Ok(()), PrivateRouteHops::Data(_) => Ok(()),
@ -138,7 +138,7 @@ impl PrivateRoute {
} }
} }
pub fn validate(&self, crypto: Crypto) -> VeilidAPIResult<()> { pub fn validate(&self, crypto: &Crypto) -> VeilidAPIResult<()> {
self.hops.validate(crypto) self.hops.validate(crypto)
} }

View File

@ -34,85 +34,71 @@ struct RouteSpecStoreInner {
cache: RouteSpecStoreCache, cache: RouteSpecStoreCache,
} }
struct RouteSpecStoreUnlockedInner { /// The routing table's storage for private/safety routes
/// Handle to routing table #[derive(Debug)]
routing_table: RoutingTable, pub(crate) struct RouteSpecStore {
registry: VeilidComponentRegistry,
inner: Mutex<RouteSpecStoreInner>,
/// Maximum number of hops in a route /// Maximum number of hops in a route
max_route_hop_count: usize, max_route_hop_count: usize,
/// Default number of hops in a route /// Default number of hops in a route
default_route_hop_count: usize, default_route_hop_count: usize,
} }
impl fmt::Debug for RouteSpecStoreUnlockedInner { impl_veilid_component_registry_accessor!(RouteSpecStore);
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RouteSpecStoreUnlockedInner")
.field("max_route_hop_count", &self.max_route_hop_count)
.field("default_route_hop_count", &self.default_route_hop_count)
.finish()
}
}
/// The routing table's storage for private/safety routes
#[derive(Clone, Debug)]
pub(crate) struct RouteSpecStore {
inner: Arc<Mutex<RouteSpecStoreInner>>,
unlocked_inner: Arc<RouteSpecStoreUnlockedInner>,
}
impl RouteSpecStore { impl RouteSpecStore {
pub fn new(routing_table: RoutingTable) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
let config = routing_table.network_manager().config(); let config = registry.config();
let c = config.get(); let c = config.get();
Self { Self {
unlocked_inner: Arc::new(RouteSpecStoreUnlockedInner { registry,
max_route_hop_count: c.network.rpc.max_route_hop_count.into(), inner: Mutex::new(RouteSpecStoreInner {
default_route_hop_count: c.network.rpc.default_route_hop_count.into(),
routing_table,
}),
inner: Arc::new(Mutex::new(RouteSpecStoreInner {
content: RouteSpecStoreContent::new(), content: RouteSpecStoreContent::new(),
cache: Default::default(), cache: Default::default(),
})), }),
max_route_hop_count: c.network.rpc.max_route_hop_count.into(),
default_route_hop_count: c.network.rpc.default_route_hop_count.into(),
} }
} }
#[instrument(level = "trace", target = "route", skip(routing_table), err)] #[instrument(level = "trace", target = "route", skip_all)]
pub async fn load(routing_table: RoutingTable) -> EyreResult<RouteSpecStore> { pub fn reset(&self) {
let (max_route_hop_count, default_route_hop_count) = { *self.inner.lock() = RouteSpecStoreInner {
let config = routing_table.network_manager().config(); content: RouteSpecStoreContent::new(),
let c = config.get();
(
c.network.rpc.max_route_hop_count as usize,
c.network.rpc.default_route_hop_count as usize,
)
};
// Get frozen blob from table store
let content = RouteSpecStoreContent::load(routing_table.clone()).await?;
let mut inner = RouteSpecStoreInner {
content,
cache: Default::default(), cache: Default::default(),
}; };
}
// Rebuild the routespecstore cache #[instrument(level = "trace", target = "route", skip_all, err)]
let rti = &*routing_table.inner.read(); pub async fn load(&self) -> EyreResult<()> {
for (_, rssd) in inner.content.iter_details() { let inner = {
inner.cache.add_to_cache(rti, rssd); let table_store = self.table_store();
} let routing_table = self.routing_table();
// Return the loaded RouteSpecStore // Get frozen blob from table store
let rss = RouteSpecStore { let content = RouteSpecStoreContent::load(&table_store, &routing_table).await?;
unlocked_inner: Arc::new(RouteSpecStoreUnlockedInner {
max_route_hop_count, let mut inner = RouteSpecStoreInner {
default_route_hop_count, content,
routing_table: routing_table.clone(), cache: Default::default(),
}), };
inner: Arc::new(Mutex::new(inner)),
// Rebuild the routespecstore cache
let rti = &*routing_table.inner.read();
for (_, rssd) in inner.content.iter_details() {
inner.cache.add_to_cache(rti, rssd);
}
inner
}; };
Ok(rss) // Return the loaded RouteSpecStore
*self.inner.lock() = inner;
Ok(())
} }
#[instrument(level = "trace", target = "route", skip(self), err)] #[instrument(level = "trace", target = "route", skip(self), err)]
@ -123,9 +109,8 @@ impl RouteSpecStore {
}; };
// Save our content // Save our content
content let table_store = self.table_store();
.save(self.unlocked_inner.routing_table.clone()) content.save(&table_store).await?;
.await?;
Ok(()) Ok(())
} }
@ -146,16 +131,17 @@ impl RouteSpecStore {
dead_remote_routes, dead_remote_routes,
})); }));
let update_callback = self.unlocked_inner.routing_table.update_callback(); let update_callback = self.registry.update_callback();
update_callback(update); update_callback(update);
} }
/// Purge the route spec store /// Purge the route spec store
pub async fn purge(&self) -> VeilidAPIResult<()> { pub async fn purge(&self) -> VeilidAPIResult<()> {
// Briefly pause routing table ticker while changes are made // Briefly pause routing table ticker while changes are made
let _tick_guard = self.unlocked_inner.routing_table.pause_tasks().await; let routing_table = self.routing_table();
self.unlocked_inner.routing_table.cancel_tasks().await;
let _tick_guard = routing_table.pause_tasks().await;
routing_table.cancel_tasks().await;
{ {
let inner = &mut *self.inner.lock(); let inner = &mut *self.inner.lock();
inner.content = Default::default(); inner.content = Default::default();
@ -181,7 +167,7 @@ impl RouteSpecStore {
automatic: bool, automatic: bool,
) -> VeilidAPIResult<RouteId> { ) -> VeilidAPIResult<RouteId> {
let inner = &mut *self.inner.lock(); let inner = &mut *self.inner.lock();
let routing_table = self.unlocked_inner.routing_table.clone(); let routing_table = self.routing_table();
let rti = &mut *routing_table.inner.write(); let rti = &mut *routing_table.inner.write();
self.allocate_route_inner( self.allocate_route_inner(
@ -213,12 +199,10 @@ impl RouteSpecStore {
apibail_generic!("safety_spec.preferred_route must be empty when allocating new route"); apibail_generic!("safety_spec.preferred_route must be empty when allocating new route");
} }
let ip6_prefix_size = rti let ip6_prefix_size = self
.unlocked_inner .registry()
.config .config()
.get() .with(|c| c.network.max_connections_per_ip6_prefix_size as usize);
.network
.max_connections_per_ip6_prefix_size as usize;
if safety_spec.hop_count < 1 { if safety_spec.hop_count < 1 {
apibail_invalid_argument!( apibail_invalid_argument!(
@ -228,7 +212,7 @@ impl RouteSpecStore {
); );
} }
if safety_spec.hop_count > self.unlocked_inner.max_route_hop_count { if safety_spec.hop_count > self.max_route_hop_count {
apibail_invalid_argument!( apibail_invalid_argument!(
"Not allocating route longer than max route hop count", "Not allocating route longer than max route hop count",
"hop_count", "hop_count",
@ -492,9 +476,8 @@ impl RouteSpecStore {
}) })
}; };
let routing_table = self.unlocked_inner.routing_table.clone();
let transform = |_rti: &RoutingTableInner, entry: Option<Arc<BucketEntry>>| -> NodeRef { let transform = |_rti: &RoutingTableInner, entry: Option<Arc<BucketEntry>>| -> NodeRef {
NodeRef::new(routing_table.clone(), entry.unwrap()) NodeRef::new(self.registry(), entry.unwrap())
}; };
// Pull the whole routing table in sorted order // Pull the whole routing table in sorted order
@ -667,13 +650,9 @@ impl RouteSpecStore {
// Got a unique route, lets build the details, register it, and return it // Got a unique route, lets build the details, register it, and return it
let hop_node_refs: Vec<NodeRef> = route_nodes.iter().map(|k| nodes[*k].clone()).collect(); let hop_node_refs: Vec<NodeRef> = route_nodes.iter().map(|k| nodes[*k].clone()).collect();
let mut route_set = BTreeMap::<PublicKey, RouteSpecDetail>::new(); let mut route_set = BTreeMap::<PublicKey, RouteSpecDetail>::new();
let crypto = self.crypto();
for crypto_kind in crypto_kinds.iter().copied() { for crypto_kind in crypto_kinds.iter().copied() {
let vcrypto = self let vcrypto = crypto.get(crypto_kind).unwrap();
.unlocked_inner
.routing_table
.crypto()
.get(crypto_kind)
.unwrap();
let keypair = vcrypto.generate_keypair(); let keypair = vcrypto.generate_keypair();
let hops: Vec<PublicKey> = route_nodes let hops: Vec<PublicKey> = route_nodes
.iter() .iter()
@ -734,7 +713,7 @@ impl RouteSpecStore {
R: fmt::Debug, R: fmt::Debug,
{ {
let inner = &*self.inner.lock(); let inner = &*self.inner.lock();
let crypto = self.unlocked_inner.routing_table.crypto(); let crypto = self.crypto();
let Some(vcrypto) = crypto.get(public_key.kind) else { let Some(vcrypto) = crypto.get(public_key.kind) else {
log_rpc!(debug "can't handle route with public key: {:?}", public_key); log_rpc!(debug "can't handle route with public key: {:?}", public_key);
return None; return None;
@ -852,7 +831,7 @@ impl RouteSpecStore {
}; };
// Test with double-round trip ping to self // Test with double-round trip ping to self
let rpc_processor = self.unlocked_inner.routing_table.rpc_processor(); let rpc_processor = self.rpc_processor();
let _res = match rpc_processor.rpc_call_status(dest).await? { let _res = match rpc_processor.rpc_call_status(dest).await? {
NetworkResult::Value(v) => v, NetworkResult::Value(v) => v,
_ => { _ => {
@ -886,7 +865,7 @@ impl RouteSpecStore {
// Get a safety route that is good enough // Get a safety route that is good enough
let safety_spec = SafetySpec { let safety_spec = SafetySpec {
preferred_route: None, preferred_route: None,
hop_count: self.unlocked_inner.default_route_hop_count, hop_count: self.default_route_hop_count,
stability, stability,
sequencing, sequencing,
}; };
@ -900,8 +879,7 @@ impl RouteSpecStore {
}; };
// Test with double-round trip ping to self // Test with double-round trip ping to self
let rpc_processor = self.unlocked_inner.routing_table.rpc_processor(); let _res = match self.rpc_processor().rpc_call_status(dest).await? {
let _res = match rpc_processor.rpc_call_status(dest).await? {
NetworkResult::Value(v) => v, NetworkResult::Value(v) => v,
_ => { _ => {
// Did not error, but did not come back, just return false // Did not error, but did not come back, just return false
@ -921,7 +899,8 @@ impl RouteSpecStore {
}; };
// Remove from hop cache // Remove from hop cache
let rti = &*self.unlocked_inner.routing_table.inner.read(); let routing_table = self.routing_table();
let rti = &*routing_table.inner.read();
if !inner.cache.remove_from_cache(rti, id, &rssd) { if !inner.cache.remove_from_cache(rti, id, &rssd) {
panic!("hop cache should have contained cache key"); panic!("hop cache should have contained cache key");
} }
@ -1097,7 +1076,7 @@ impl RouteSpecStore {
) -> VeilidAPIResult<CompiledRoute> { ) -> VeilidAPIResult<CompiledRoute> {
// let profile_start_ts = get_timestamp(); // let profile_start_ts = get_timestamp();
let inner = &mut *self.inner.lock(); let inner = &mut *self.inner.lock();
let routing_table = self.unlocked_inner.routing_table.clone(); let routing_table = self.routing_table();
let rti = &mut *routing_table.inner.write(); let rti = &mut *routing_table.inner.write();
// Get useful private route properties // Get useful private route properties
@ -1108,7 +1087,7 @@ impl RouteSpecStore {
}; };
let pr_pubkey = private_route.public_key.value; let pr_pubkey = private_route.public_key.value;
let pr_hopcount = private_route.hop_count as usize; let pr_hopcount = private_route.hop_count as usize;
let max_route_hop_count = self.unlocked_inner.max_route_hop_count; let max_route_hop_count = self.max_route_hop_count;
// Check private route hop count isn't larger than the max route hop count plus one for the 'first hop' header // Check private route hop count isn't larger than the max route hop count plus one for the 'first hop' header
if pr_hopcount > (max_route_hop_count + 1) { if pr_hopcount > (max_route_hop_count + 1) {
@ -1130,10 +1109,10 @@ impl RouteSpecStore {
let opt_first_hop = match pr_first_hop_node { let opt_first_hop = match pr_first_hop_node {
RouteNode::NodeId(id) => rti RouteNode::NodeId(id) => rti
.lookup_node_ref(routing_table.clone(), TypedKey::new(crypto_kind, id)) .lookup_node_ref(TypedKey::new(crypto_kind, id))
.map_err(VeilidAPIError::internal)?, .map_err(VeilidAPIError::internal)?,
RouteNode::PeerInfo(pi) => Some( RouteNode::PeerInfo(pi) => Some(
rti.register_node_with_peer_info(routing_table.clone(), pi, false) rti.register_node_with_peer_info(pi, false)
.map_err(VeilidAPIError::internal)? .map_err(VeilidAPIError::internal)?
.unfiltered(), .unfiltered(),
), ),
@ -1362,7 +1341,7 @@ impl RouteSpecStore {
avoid_nodes: &[TypedKey], avoid_nodes: &[TypedKey],
) -> VeilidAPIResult<PublicKey> { ) -> VeilidAPIResult<PublicKey> {
// Ensure the total hop count isn't too long for our config // Ensure the total hop count isn't too long for our config
let max_route_hop_count = self.unlocked_inner.max_route_hop_count; let max_route_hop_count = self.max_route_hop_count;
if safety_spec.hop_count == 0 { if safety_spec.hop_count == 0 {
apibail_invalid_argument!( apibail_invalid_argument!(
"safety route hop count is zero", "safety route hop count is zero",
@ -1438,7 +1417,7 @@ impl RouteSpecStore {
avoid_nodes: &[TypedKey], avoid_nodes: &[TypedKey],
) -> VeilidAPIResult<PublicKey> { ) -> VeilidAPIResult<PublicKey> {
let inner = &mut *self.inner.lock(); let inner = &mut *self.inner.lock();
let routing_table = self.unlocked_inner.routing_table.clone(); let routing_table = self.routing_table();
let rti = &mut *routing_table.inner.write(); let rti = &mut *routing_table.inner.write();
self.get_route_for_safety_spec_inner( self.get_route_for_safety_spec_inner(
@ -1457,7 +1436,7 @@ impl RouteSpecStore {
rsd: &RouteSpecDetail, rsd: &RouteSpecDetail,
optimized: bool, optimized: bool,
) -> VeilidAPIResult<PrivateRoute> { ) -> VeilidAPIResult<PrivateRoute> {
let routing_table = self.unlocked_inner.routing_table.clone(); let routing_table = self.routing_table();
let rti = &*routing_table.inner.read(); let rti = &*routing_table.inner.read();
// Ensure we get the crypto for it // Ensure we get the crypto for it
@ -1732,8 +1711,7 @@ impl RouteSpecStore {
cur_ts: Timestamp, cur_ts: Timestamp,
) -> VeilidAPIResult<()> { ) -> VeilidAPIResult<()> {
let Some(our_node_info_ts) = self let Some(our_node_info_ts) = self
.unlocked_inner .routing_table()
.routing_table
.get_published_peer_info(RoutingDomain::PublicInternet) .get_published_peer_info(RoutingDomain::PublicInternet)
.map(|pi| pi.signed_node_info().timestamp()) .map(|pi| pi.signed_node_info().timestamp())
else { else {
@ -1767,11 +1745,7 @@ impl RouteSpecStore {
let inner = &mut *self.inner.lock(); let inner = &mut *self.inner.lock();
// Check for stub route // Check for stub route
if self if self.routing_table().matches_own_node_id_key(key) {
.unlocked_inner
.routing_table
.matches_own_node_id_key(key)
{
return None; return None;
} }
@ -1869,7 +1843,7 @@ impl RouteSpecStore {
/// Convert binary blob to private route vector /// Convert binary blob to private route vector
pub fn blob_to_private_routes(&self, blob: Vec<u8>) -> VeilidAPIResult<Vec<PrivateRoute>> { pub fn blob_to_private_routes(&self, blob: Vec<u8>) -> VeilidAPIResult<Vec<PrivateRoute>> {
// Get crypto // Get crypto
let crypto = self.unlocked_inner.routing_table.crypto(); let crypto = self.crypto();
// Deserialize count // Deserialize count
if blob.is_empty() { if blob.is_empty() {
@ -1904,7 +1878,7 @@ impl RouteSpecStore {
let private_route = decode_private_route(&decode_context, &pr_reader).map_err(|e| { let private_route = decode_private_route(&decode_context, &pr_reader).map_err(|e| {
VeilidAPIError::invalid_argument("failed to decode private route", "e", e) VeilidAPIError::invalid_argument("failed to decode private route", "e", e)
})?; })?;
private_route.validate(crypto.clone()).map_err(|e| { private_route.validate(&crypto).map_err(|e| {
VeilidAPIError::invalid_argument("failed to validate private route", "e", e) VeilidAPIError::invalid_argument("failed to validate private route", "e", e)
})?; })?;
@ -1920,7 +1894,7 @@ impl RouteSpecStore {
/// Generate RouteId from typed key set of route public keys /// Generate RouteId from typed key set of route public keys
fn generate_allocated_route_id(&self, rssd: &RouteSetSpecDetail) -> VeilidAPIResult<RouteId> { fn generate_allocated_route_id(&self, rssd: &RouteSetSpecDetail) -> VeilidAPIResult<RouteId> {
let route_set_keys = rssd.get_route_set_keys(); let route_set_keys = rssd.get_route_set_keys();
let crypto = self.unlocked_inner.routing_table.crypto(); let crypto = self.crypto();
let mut idbytes = Vec::with_capacity(PUBLIC_KEY_LENGTH * route_set_keys.len()); let mut idbytes = Vec::with_capacity(PUBLIC_KEY_LENGTH * route_set_keys.len());
let mut best_kind: Option<CryptoKind> = None; let mut best_kind: Option<CryptoKind> = None;
@ -1945,7 +1919,7 @@ impl RouteSpecStore {
&self, &self,
private_routes: &[PrivateRoute], private_routes: &[PrivateRoute],
) -> VeilidAPIResult<RouteId> { ) -> VeilidAPIResult<RouteId> {
let crypto = self.unlocked_inner.routing_table.crypto(); let crypto = self.crypto();
let mut idbytes = Vec::with_capacity(PUBLIC_KEY_LENGTH * private_routes.len()); let mut idbytes = Vec::with_capacity(PUBLIC_KEY_LENGTH * private_routes.len());
let mut best_kind: Option<CryptoKind> = None; let mut best_kind: Option<CryptoKind> = None;

View File

@ -17,9 +17,11 @@ impl RouteSpecStoreContent {
} }
} }
pub async fn load(routing_table: RoutingTable) -> EyreResult<RouteSpecStoreContent> { pub async fn load(
table_store: &TableStore,
routing_table: &RoutingTable,
) -> EyreResult<RouteSpecStoreContent> {
// Deserialize what we can // Deserialize what we can
let table_store = routing_table.network_manager().table_store();
let rsstdb = table_store.open("RouteSpecStore", 1).await?; let rsstdb = table_store.open("RouteSpecStore", 1).await?;
let mut content: RouteSpecStoreContent = let mut content: RouteSpecStoreContent =
rsstdb.load_json(0, b"content").await?.unwrap_or_default(); rsstdb.load_json(0, b"content").await?.unwrap_or_default();
@ -59,10 +61,9 @@ impl RouteSpecStoreContent {
Ok(content) Ok(content)
} }
pub async fn save(&self, routing_table: RoutingTable) -> EyreResult<()> { pub async fn save(&self, table_store: &TableStore) -> EyreResult<()> {
// Save all the fields we care about to the frozen blob in table storage // Save all the fields we care about to the frozen blob in table storage
// This skips #[with(Skip)] saving the secret keys, we save them in the protected store instead // This skips #[with(Skip)] saving the secret keys, we save them in the protected store instead
let table_store = routing_table.network_manager().table_store();
let rsstdb = table_store.open("RouteSpecStore", 1).await?; let rsstdb = table_store.open("RouteSpecStore", 1).await?;
rsstdb.store_json(0, b"content", self).await?; rsstdb.store_json(0, b"content", self).await?;

View File

@ -15,8 +15,8 @@ pub type EntryCounts = BTreeMap<(RoutingDomain, CryptoKind), usize>;
/// RoutingTable rwlock-internal data /// RoutingTable rwlock-internal data
pub struct RoutingTableInner { pub struct RoutingTableInner {
/// Extra pointer to unlocked members to simplify access /// Convenience accessor for the global component registry
pub(super) unlocked_inner: Arc<RoutingTableUnlockedInner>, pub(super) registry: VeilidComponentRegistry,
/// Routing table buckets that hold references to entries, per crypto kind /// Routing table buckets that hold references to entries, per crypto kind
pub(super) buckets: BTreeMap<CryptoKind, Vec<Bucket>>, pub(super) buckets: BTreeMap<CryptoKind, Vec<Bucket>>,
/// A weak set of all the entries we have in the buckets for faster iteration /// A weak set of all the entries we have in the buckets for faster iteration
@ -44,10 +44,12 @@ pub struct RoutingTableInner {
pub(super) opt_active_watch_keepalive_ts: Option<Timestamp>, pub(super) opt_active_watch_keepalive_ts: Option<Timestamp>,
} }
impl_veilid_component_registry_accessor!(RoutingTableInner);
impl RoutingTableInner { impl RoutingTableInner {
pub(super) fn new(unlocked_inner: Arc<RoutingTableUnlockedInner>) -> RoutingTableInner { pub(super) fn new(registry: VeilidComponentRegistry) -> RoutingTableInner {
RoutingTableInner { RoutingTableInner {
unlocked_inner, registry,
buckets: BTreeMap::new(), buckets: BTreeMap::new(),
public_internet_routing_domain: PublicInternetRoutingDomainDetail::default(), public_internet_routing_domain: PublicInternetRoutingDomainDetail::default(),
local_network_routing_domain: LocalNetworkRoutingDomainDetail::default(), local_network_routing_domain: LocalNetworkRoutingDomainDetail::default(),
@ -458,7 +460,6 @@ impl RoutingTableInner {
// Collect all entries that are 'needs_ping' and have some node info making them reachable somehow // Collect all entries that are 'needs_ping' and have some node info making them reachable somehow
pub(super) fn get_nodes_needing_ping( pub(super) fn get_nodes_needing_ping(
&self, &self,
outer_self: RoutingTable,
routing_domain: RoutingDomain, routing_domain: RoutingDomain,
cur_ts: Timestamp, cur_ts: Timestamp,
) -> Vec<FilteredNodeRef> { ) -> Vec<FilteredNodeRef> {
@ -559,7 +560,7 @@ impl RoutingTableInner {
let transform = |_rti: &RoutingTableInner, v: Option<Arc<BucketEntry>>| { let transform = |_rti: &RoutingTableInner, v: Option<Arc<BucketEntry>>| {
FilteredNodeRef::new( FilteredNodeRef::new(
outer_self.clone(), self.registry.clone(),
v.unwrap().clone(), v.unwrap().clone(),
NodeRefFilter::new().with_routing_domain(routing_domain), NodeRefFilter::new().with_routing_domain(routing_domain),
Sequencing::default(), Sequencing::default(),
@ -570,10 +571,10 @@ impl RoutingTableInner {
} }
#[expect(dead_code)] #[expect(dead_code)]
pub fn get_all_alive_nodes(&self, outer_self: RoutingTable, cur_ts: Timestamp) -> Vec<NodeRef> { pub fn get_all_alive_nodes(&self, cur_ts: Timestamp) -> Vec<NodeRef> {
let mut node_refs = Vec::<NodeRef>::with_capacity(self.bucket_entry_count()); let mut node_refs = Vec::<NodeRef>::with_capacity(self.bucket_entry_count());
self.with_entries(cur_ts, BucketEntryState::Unreliable, |_rti, entry| { self.with_entries(cur_ts, BucketEntryState::Unreliable, |_rti, entry| {
node_refs.push(NodeRef::new(outer_self.clone(), entry)); node_refs.push(NodeRef::new(self.registry(), entry));
Option::<()>::None Option::<()>::None
}); });
node_refs node_refs
@ -601,6 +602,8 @@ impl RoutingTableInner {
entry: Arc<BucketEntry>, entry: Arc<BucketEntry>,
node_ids: &[TypedKey], node_ids: &[TypedKey],
) -> EyreResult<()> { ) -> EyreResult<()> {
let routing_table = self.routing_table();
entry.with_mut_inner(|e| { entry.with_mut_inner(|e| {
let mut existing_node_ids = e.node_ids(); let mut existing_node_ids = e.node_ids();
@ -631,21 +634,21 @@ impl RoutingTableInner {
if let Some(old_node_id) = e.add_node_id(*node_id)? { if let Some(old_node_id) = e.add_node_id(*node_id)? {
// Remove any old node id for this crypto kind // Remove any old node id for this crypto kind
if VALID_CRYPTO_KINDS.contains(&ck) { if VALID_CRYPTO_KINDS.contains(&ck) {
let bucket_index = self.unlocked_inner.calculate_bucket_index(&old_node_id); let bucket_index = routing_table.calculate_bucket_index(&old_node_id);
let bucket = self.get_bucket_mut(bucket_index); let bucket = self.get_bucket_mut(bucket_index);
bucket.remove_entry(&old_node_id.value); bucket.remove_entry(&old_node_id.value);
self.unlocked_inner.kick_queue.lock().insert(bucket_index); routing_table.kick_queue.lock().insert(bucket_index);
} }
} }
// Bucket the entry appropriately // Bucket the entry appropriately
if VALID_CRYPTO_KINDS.contains(&ck) { if VALID_CRYPTO_KINDS.contains(&ck) {
let bucket_index = self.unlocked_inner.calculate_bucket_index(node_id); let bucket_index = routing_table.calculate_bucket_index(node_id);
let bucket = self.get_bucket_mut(bucket_index); let bucket = self.get_bucket_mut(bucket_index);
bucket.add_existing_entry(node_id.value, entry.clone()); bucket.add_existing_entry(node_id.value, entry.clone());
// Kick bucket // Kick bucket
self.unlocked_inner.kick_queue.lock().insert(bucket_index); routing_table.kick_queue.lock().insert(bucket_index);
} }
} }
@ -653,7 +656,7 @@ impl RoutingTableInner {
for node_id in existing_node_ids.iter() { for node_id in existing_node_ids.iter() {
let ck = node_id.kind; let ck = node_id.kind;
if VALID_CRYPTO_KINDS.contains(&ck) { if VALID_CRYPTO_KINDS.contains(&ck) {
let bucket_index = self.unlocked_inner.calculate_bucket_index(node_id); let bucket_index = routing_table.calculate_bucket_index(node_id);
let bucket = self.get_bucket_mut(bucket_index); let bucket = self.get_bucket_mut(bucket_index);
bucket.remove_entry(&node_id.value); bucket.remove_entry(&node_id.value);
entry.with_mut_inner(|e| e.remove_node_id(ck)); entry.with_mut_inner(|e| e.remove_node_id(ck));
@ -687,15 +690,16 @@ impl RoutingTableInner {
#[instrument(level = "trace", skip_all, err)] #[instrument(level = "trace", skip_all, err)]
fn create_node_ref<F>( fn create_node_ref<F>(
&mut self, &mut self,
outer_self: RoutingTable,
node_ids: &TypedKeyGroup, node_ids: &TypedKeyGroup,
update_func: F, update_func: F,
) -> EyreResult<NodeRef> ) -> EyreResult<NodeRef>
where where
F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner), F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner),
{ {
let routing_table = self.routing_table();
// Ensure someone isn't trying register this node itself // Ensure someone isn't trying register this node itself
if self.unlocked_inner.matches_own_node_id(node_ids) { if routing_table.matches_own_node_id(node_ids) {
bail!("can't register own node"); bail!("can't register own node");
} }
@ -708,7 +712,7 @@ impl RoutingTableInner {
continue; continue;
} }
// Find the first in crypto sort order // Find the first in crypto sort order
let bucket_index = self.unlocked_inner.calculate_bucket_index(node_id); let bucket_index = routing_table.calculate_bucket_index(node_id);
let bucket = self.get_bucket(bucket_index); let bucket = self.get_bucket(bucket_index);
if let Some(entry) = bucket.entry(&node_id.value) { if let Some(entry) = bucket.entry(&node_id.value) {
// Best entry is the first one in sorted order that exists from the node id list // Best entry is the first one in sorted order that exists from the node id list
@ -730,7 +734,7 @@ impl RoutingTableInner {
} }
// Make a noderef to return // Make a noderef to return
let nr = NodeRef::new(outer_self.clone(), best_entry.clone()); let nr = NodeRef::new(self.registry(), best_entry.clone());
// Update the entry with the update func // Update the entry with the update func
best_entry.with_mut_inner(|e| update_func(self, e)); best_entry.with_mut_inner(|e| update_func(self, e));
@ -741,11 +745,11 @@ impl RoutingTableInner {
// If no entry exists yet, add the first entry to a bucket, possibly evicting a bucket member // If no entry exists yet, add the first entry to a bucket, possibly evicting a bucket member
let first_node_id = node_ids[0]; let first_node_id = node_ids[0];
let bucket_entry = self.unlocked_inner.calculate_bucket_index(&first_node_id); let bucket_entry = routing_table.calculate_bucket_index(&first_node_id);
let bucket = self.get_bucket_mut(bucket_entry); let bucket = self.get_bucket_mut(bucket_entry);
let new_entry = bucket.add_new_entry(first_node_id.value); let new_entry = bucket.add_new_entry(first_node_id.value);
self.all_entries.insert(new_entry.clone()); self.all_entries.insert(new_entry.clone());
self.unlocked_inner.kick_queue.lock().insert(bucket_entry); routing_table.kick_queue.lock().insert(bucket_entry);
// Update the other bucket entries with the remaining node ids // Update the other bucket entries with the remaining node ids
if let Err(e) = self.update_bucket_entry_node_ids(new_entry.clone(), node_ids) { if let Err(e) = self.update_bucket_entry_node_ids(new_entry.clone(), node_ids) {
@ -753,7 +757,7 @@ impl RoutingTableInner {
} }
// Make node ref to return // Make node ref to return
let nr = NodeRef::new(outer_self.clone(), new_entry.clone()); let nr = NodeRef::new(self.registry(), new_entry.clone());
// Update the entry with the update func // Update the entry with the update func
new_entry.with_mut_inner(|e| update_func(self, e)); new_entry.with_mut_inner(|e| update_func(self, e));
@ -766,15 +770,9 @@ impl RoutingTableInner {
/// Resolve an existing routing table entry using any crypto kind and return a reference to it /// Resolve an existing routing table entry using any crypto kind and return a reference to it
#[instrument(level = "trace", skip_all, err)] #[instrument(level = "trace", skip_all, err)]
pub fn lookup_any_node_ref( pub fn lookup_any_node_ref(&self, node_id_key: PublicKey) -> EyreResult<Option<NodeRef>> {
&self,
outer_self: RoutingTable,
node_id_key: PublicKey,
) -> EyreResult<Option<NodeRef>> {
for ck in VALID_CRYPTO_KINDS { for ck in VALID_CRYPTO_KINDS {
if let Some(nr) = if let Some(nr) = self.lookup_node_ref(TypedKey::new(ck, node_id_key))? {
self.lookup_node_ref(outer_self.clone(), TypedKey::new(ck, node_id_key))?
{
return Ok(Some(nr)); return Ok(Some(nr));
} }
} }
@ -783,35 +781,30 @@ impl RoutingTableInner {
/// Resolve an existing routing table entry and return a reference to it /// Resolve an existing routing table entry and return a reference to it
#[instrument(level = "trace", skip_all, err)] #[instrument(level = "trace", skip_all, err)]
pub fn lookup_node_ref( pub fn lookup_node_ref(&self, node_id: TypedKey) -> EyreResult<Option<NodeRef>> {
&self, if self.routing_table().matches_own_node_id(&[node_id]) {
outer_self: RoutingTable,
node_id: TypedKey,
) -> EyreResult<Option<NodeRef>> {
if self.unlocked_inner.matches_own_node_id(&[node_id]) {
bail!("can't look up own node id in routing table"); bail!("can't look up own node id in routing table");
} }
if !VALID_CRYPTO_KINDS.contains(&node_id.kind) { if !VALID_CRYPTO_KINDS.contains(&node_id.kind) {
bail!("can't look up node id with invalid crypto kind"); bail!("can't look up node id with invalid crypto kind");
} }
let bucket_index = self.unlocked_inner.calculate_bucket_index(&node_id); let bucket_index = self.routing_table().calculate_bucket_index(&node_id);
let bucket = self.get_bucket(bucket_index); let bucket = self.get_bucket(bucket_index);
Ok(bucket Ok(bucket
.entry(&node_id.value) .entry(&node_id.value)
.map(|e| NodeRef::new(outer_self, e))) .map(|e| NodeRef::new(self.registry(), e)))
} }
/// Resolve an existing routing table entry and return a filtered reference to it /// Resolve an existing routing table entry and return a filtered reference to it
#[instrument(level = "trace", skip_all, err)] #[instrument(level = "trace", skip_all, err)]
pub fn lookup_and_filter_noderef( pub fn lookup_and_filter_noderef(
&self, &self,
outer_self: RoutingTable,
node_id: TypedKey, node_id: TypedKey,
routing_domain_set: RoutingDomainSet, routing_domain_set: RoutingDomainSet,
dial_info_filter: DialInfoFilter, dial_info_filter: DialInfoFilter,
) -> EyreResult<Option<FilteredNodeRef>> { ) -> EyreResult<Option<FilteredNodeRef>> {
let nr = self.lookup_node_ref(outer_self, node_id)?; let nr = self.lookup_node_ref(node_id)?;
Ok(nr.map(|nr| { Ok(nr.map(|nr| {
nr.custom_filtered( nr.custom_filtered(
NodeRefFilter::new() NodeRefFilter::new()
@ -826,7 +819,7 @@ impl RoutingTableInner {
where where
F: FnOnce(Arc<BucketEntry>) -> R, F: FnOnce(Arc<BucketEntry>) -> R,
{ {
if self.unlocked_inner.matches_own_node_id(&[node_id]) { if self.routing_table().matches_own_node_id(&[node_id]) {
log_rtab!(error "can't look up own node id in routing table"); log_rtab!(error "can't look up own node id in routing table");
return None; return None;
} }
@ -834,7 +827,7 @@ impl RoutingTableInner {
log_rtab!(error "can't look up node id with invalid crypto kind"); log_rtab!(error "can't look up node id with invalid crypto kind");
return None; return None;
} }
let bucket_entry = self.unlocked_inner.calculate_bucket_index(&node_id); let bucket_entry = self.routing_table().calculate_bucket_index(&node_id);
let bucket = self.get_bucket(bucket_entry); let bucket = self.get_bucket(bucket_entry);
bucket.entry(&node_id.value).map(f) bucket.entry(&node_id.value).map(f)
} }
@ -845,7 +838,6 @@ impl RoutingTableInner {
#[instrument(level = "trace", skip_all, err)] #[instrument(level = "trace", skip_all, err)]
pub fn register_node_with_peer_info( pub fn register_node_with_peer_info(
&mut self, &mut self,
outer_self: RoutingTable,
peer_info: Arc<PeerInfo>, peer_info: Arc<PeerInfo>,
allow_invalid: bool, allow_invalid: bool,
) -> EyreResult<FilteredNodeRef> { ) -> EyreResult<FilteredNodeRef> {
@ -853,7 +845,7 @@ impl RoutingTableInner {
// if our own node is in the list, then ignore it as we don't add ourselves to our own routing table // if our own node is in the list, then ignore it as we don't add ourselves to our own routing table
if self if self
.unlocked_inner .routing_table()
.matches_own_node_id(peer_info.node_ids()) .matches_own_node_id(peer_info.node_ids())
{ {
bail!("can't register own node id in routing table"); bail!("can't register own node id in routing table");
@ -891,10 +883,10 @@ impl RoutingTableInner {
if let Some(relay_peer_info) = peer_info.signed_node_info().relay_peer_info(routing_domain) if let Some(relay_peer_info) = peer_info.signed_node_info().relay_peer_info(routing_domain)
{ {
if !self if !self
.unlocked_inner .routing_table()
.matches_own_node_id(relay_peer_info.node_ids()) .matches_own_node_id(relay_peer_info.node_ids())
{ {
self.register_node_with_peer_info(outer_self.clone(), relay_peer_info, false)?; self.register_node_with_peer_info(relay_peer_info, false)?;
} }
} }
@ -902,7 +894,7 @@ impl RoutingTableInner {
Arc::unwrap_or_clone(peer_info).destructure(); Arc::unwrap_or_clone(peer_info).destructure();
let mut updated = false; let mut updated = false;
let mut old_peer_info = None; let mut old_peer_info = None;
let nr = self.create_node_ref(outer_self, &node_ids, |_rti, e| { let nr = self.create_node_ref(&node_ids, |_rti, e| {
old_peer_info = e.make_peer_info(routing_domain); old_peer_info = e.make_peer_info(routing_domain);
updated = e.update_signed_node_info(routing_domain, &signed_node_info); updated = e.update_signed_node_info(routing_domain, &signed_node_info);
})?; })?;
@ -922,12 +914,11 @@ impl RoutingTableInner {
#[instrument(level = "trace", skip_all, err)] #[instrument(level = "trace", skip_all, err)]
pub fn register_node_with_id( pub fn register_node_with_id(
&mut self, &mut self,
outer_self: RoutingTable,
routing_domain: RoutingDomain, routing_domain: RoutingDomain,
node_id: TypedKey, node_id: TypedKey,
timestamp: Timestamp, timestamp: Timestamp,
) -> EyreResult<FilteredNodeRef> { ) -> EyreResult<FilteredNodeRef> {
let nr = self.create_node_ref(outer_self, &TypedKeyGroup::from(node_id), |_rti, e| { let nr = self.create_node_ref(&TypedKeyGroup::from(node_id), |_rti, e| {
//e.make_not_dead(timestamp); //e.make_not_dead(timestamp);
e.touch_last_seen(timestamp); e.touch_last_seen(timestamp);
})?; })?;
@ -1057,7 +1048,7 @@ impl RoutingTableInner {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub fn find_fast_non_local_nodes_filtered( pub fn find_fast_non_local_nodes_filtered(
&self, &self,
outer_self: RoutingTable, registry: VeilidComponentRegistry,
routing_domain: RoutingDomain, routing_domain: RoutingDomain,
node_count: usize, node_count: usize,
mut filters: VecDeque<RoutingTableEntryFilter>, mut filters: VecDeque<RoutingTableEntryFilter>,
@ -1089,7 +1080,7 @@ impl RoutingTableInner {
node_count, node_count,
filters, filters,
|_rti: &RoutingTableInner, v: Option<Arc<BucketEntry>>| { |_rti: &RoutingTableInner, v: Option<Arc<BucketEntry>>| {
NodeRef::new(outer_self.clone(), v.unwrap().clone()) NodeRef::new(registry.clone(), v.unwrap().clone())
}, },
) )
} }
@ -1283,10 +1274,12 @@ impl RoutingTableInner {
T: for<'r> FnMut(&'r RoutingTableInner, Option<Arc<BucketEntry>>) -> O, T: for<'r> FnMut(&'r RoutingTableInner, Option<Arc<BucketEntry>>) -> O,
{ {
let cur_ts = Timestamp::now(); let cur_ts = Timestamp::now();
let routing_table = self.routing_table();
// Get the crypto kind // Get the crypto kind
let crypto_kind = node_id.kind; let crypto_kind = node_id.kind;
let Some(vcrypto) = self.unlocked_inner.crypto().get(crypto_kind) else { let crypto = self.crypto();
let Some(vcrypto) = crypto.get(crypto_kind) else {
apibail_generic!("invalid crypto kind"); apibail_generic!("invalid crypto kind");
}; };
@ -1338,12 +1331,12 @@ impl RoutingTableInner {
let a_key = if let Some(a_entry) = a_entry { let a_key = if let Some(a_entry) = a_entry {
a_entry.with_inner(|e| e.node_ids().get(crypto_kind).unwrap()) a_entry.with_inner(|e| e.node_ids().get(crypto_kind).unwrap())
} else { } else {
self.unlocked_inner.node_id(crypto_kind) routing_table.node_id(crypto_kind)
}; };
let b_key = if let Some(b_entry) = b_entry { let b_key = if let Some(b_entry) = b_entry {
b_entry.with_inner(|e| e.node_ids().get(crypto_kind).unwrap()) b_entry.with_inner(|e| e.node_ids().get(crypto_kind).unwrap())
} else { } else {
self.unlocked_inner.node_id(crypto_kind) routing_table.node_id(crypto_kind)
}; };
// distance is the next metric, closer nodes first // distance is the next metric, closer nodes first
@ -1379,7 +1372,8 @@ impl RoutingTableInner {
.collect(); .collect();
// Sort closest // Sort closest
let sort = make_closest_noderef_sort(self.unlocked_inner.crypto(), node_id); let crypto = self.crypto();
let sort = make_closest_noderef_sort(&crypto, node_id);
closest_nodes_locked.sort_by(sort); closest_nodes_locked.sort_by(sort);
// Unlock noderefs // Unlock noderefs
@ -1388,10 +1382,10 @@ impl RoutingTableInner {
} }
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub fn make_closest_noderef_sort( pub fn make_closest_noderef_sort<'a>(
crypto: Crypto, crypto: &'a Crypto,
node_id: TypedKey, node_id: TypedKey,
) -> impl Fn(&LockedNodeRef, &LockedNodeRef) -> core::cmp::Ordering { ) -> impl Fn(&LockedNodeRef, &LockedNodeRef) -> core::cmp::Ordering + 'a {
let kind = node_id.kind; let kind = node_id.kind;
// Get cryptoversion to check distance with // Get cryptoversion to check distance with
let vcrypto = crypto.get(kind).unwrap(); let vcrypto = crypto.get(kind).unwrap();
@ -1418,9 +1412,9 @@ pub fn make_closest_noderef_sort(
} }
pub fn make_closest_node_id_sort( pub fn make_closest_node_id_sort(
crypto: Crypto, crypto: &Crypto,
node_id: TypedKey, node_id: TypedKey,
) -> impl Fn(&CryptoKey, &CryptoKey) -> core::cmp::Ordering { ) -> impl Fn(&CryptoKey, &CryptoKey) -> core::cmp::Ordering + '_ {
let kind = node_id.kind; let kind = node_id.kind;
// Get cryptoversion to check distance with // Get cryptoversion to check distance with
let vcrypto = crypto.get(kind).unwrap(); let vcrypto = crypto.get(kind).unwrap();

View File

@ -7,7 +7,7 @@ pub trait RoutingDomainEditorCommonTrait {
protocol_type: Option<ProtocolType>, protocol_type: Option<ProtocolType>,
) -> &mut Self; ) -> &mut Self;
fn set_relay_node(&mut self, relay_node: Option<NodeRef>) -> &mut Self; fn set_relay_node(&mut self, relay_node: Option<NodeRef>) -> &mut Self;
#[cfg_attr(target_arch = "wasm32", expect(dead_code))] #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
fn add_dial_info(&mut self, dial_info: DialInfo, class: DialInfoClass) -> &mut Self; fn add_dial_info(&mut self, dial_info: DialInfo, class: DialInfoClass) -> &mut Self;
fn setup_network( fn setup_network(
&mut self, &mut self,
@ -83,7 +83,7 @@ pub(super) enum RoutingDomainChangeCommon {
AddDialInfo { AddDialInfo {
dial_info_detail: DialInfoDetail, dial_info_detail: DialInfoDetail,
}, },
// #[cfg_attr(target_arch = "wasm32", expect(dead_code))] // #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
// RemoveDialInfoDetail { // RemoveDialInfoDetail {
// dial_info_detail: DialInfoDetail, // dial_info_detail: DialInfoDetail,
// }, // },

View File

@ -1,4 +1,4 @@
#![cfg_attr(target_arch = "wasm32", expect(dead_code))] #![cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
use super::*; use super::*;
@ -10,15 +10,15 @@ enum RoutingDomainChangeLocalNetwork {
Common(RoutingDomainChangeCommon), Common(RoutingDomainChangeCommon),
} }
pub struct RoutingDomainEditorLocalNetwork { pub struct RoutingDomainEditorLocalNetwork<'a> {
routing_table: RoutingTable, routing_table: &'a RoutingTable,
changes: Vec<RoutingDomainChangeLocalNetwork>, changes: Vec<RoutingDomainChangeLocalNetwork>,
} }
impl RoutingDomainEditorLocalNetwork { impl<'a> RoutingDomainEditorLocalNetwork<'a> {
pub(in crate::routing_table) fn new(routing_table: RoutingTable) -> Self { pub(in crate::routing_table) fn new(routing_table: &'a RoutingTable) -> Self {
Self { Self {
routing_table: routing_table.clone(), routing_table,
changes: Vec::new(), changes: Vec::new(),
} }
} }
@ -30,7 +30,7 @@ impl RoutingDomainEditorLocalNetwork {
} }
} }
impl RoutingDomainEditorCommonTrait for RoutingDomainEditorLocalNetwork { impl<'a> RoutingDomainEditorCommonTrait for RoutingDomainEditorLocalNetwork<'a> {
#[instrument(level = "debug", skip(self))] #[instrument(level = "debug", skip(self))]
fn clear_dial_info_details( fn clear_dial_info_details(
&mut self, &mut self,

View File

@ -144,11 +144,7 @@ impl RoutingDomainDetail for LocalNetworkRoutingDomainDetail {
pi pi
}; };
if let Err(e) = rti if let Err(e) = rti.event_bus().post(PeerInfoChangeEvent { peer_info }) {
.unlocked_inner
.event_bus
.post(PeerInfoChangeEvent { peer_info })
{
log_rtab!(debug "Failed to post event: {}", e); log_rtab!(debug "Failed to post event: {}", e);
} }

View File

@ -143,7 +143,7 @@ impl RoutingDomainDetailCommon {
pub fn network_class(&self) -> Option<NetworkClass> { pub fn network_class(&self) -> Option<NetworkClass> {
cfg_if! { cfg_if! {
if #[cfg(target_arch = "wasm32")] { if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
Some(NetworkClass::WebApp) Some(NetworkClass::WebApp)
} else { } else {
if self.address_types.is_empty() { if self.address_types.is_empty() {
@ -312,6 +312,9 @@ impl RoutingDomainDetailCommon {
// Internal functions // Internal functions
fn make_peer_info(&self, rti: &RoutingTableInner) -> PeerInfo { fn make_peer_info(&self, rti: &RoutingTableInner) -> PeerInfo {
let crypto = rti.crypto();
let routing_table = rti.routing_table();
let node_info = NodeInfo::new( let node_info = NodeInfo::new(
self.network_class().unwrap_or(NetworkClass::Invalid), self.network_class().unwrap_or(NetworkClass::Invalid),
self.outbound_protocols, self.outbound_protocols,
@ -343,8 +346,8 @@ impl RoutingDomainDetailCommon {
let signed_node_info = match relay_info { let signed_node_info = match relay_info {
Some((relay_ids, relay_sdni)) => SignedNodeInfo::Relayed( Some((relay_ids, relay_sdni)) => SignedNodeInfo::Relayed(
SignedRelayedNodeInfo::make_signatures( SignedRelayedNodeInfo::make_signatures(
rti.unlocked_inner.crypto(), &crypto,
rti.unlocked_inner.node_id_typed_key_pairs(), routing_table.node_id_typed_key_pairs(),
node_info, node_info,
relay_ids, relay_ids,
relay_sdni, relay_sdni,
@ -353,8 +356,8 @@ impl RoutingDomainDetailCommon {
), ),
None => SignedNodeInfo::Direct( None => SignedNodeInfo::Direct(
SignedDirectNodeInfo::make_signatures( SignedDirectNodeInfo::make_signatures(
rti.unlocked_inner.crypto(), &crypto,
rti.unlocked_inner.node_id_typed_key_pairs(), routing_table.node_id_typed_key_pairs(),
node_info, node_info,
) )
.unwrap(), .unwrap(),
@ -363,7 +366,7 @@ impl RoutingDomainDetailCommon {
PeerInfo::new( PeerInfo::new(
self.routing_domain, self.routing_domain,
rti.unlocked_inner.node_ids(), routing_table.node_ids(),
signed_node_info, signed_node_info,
) )
} }

View File

@ -5,13 +5,13 @@ enum RoutingDomainChangePublicInternet {
Common(RoutingDomainChangeCommon), Common(RoutingDomainChangeCommon),
} }
pub struct RoutingDomainEditorPublicInternet { pub struct RoutingDomainEditorPublicInternet<'a> {
routing_table: RoutingTable, routing_table: &'a RoutingTable,
changes: Vec<RoutingDomainChangePublicInternet>, changes: Vec<RoutingDomainChangePublicInternet>,
} }
impl RoutingDomainEditorPublicInternet { impl<'a> RoutingDomainEditorPublicInternet<'a> {
pub(in crate::routing_table) fn new(routing_table: RoutingTable) -> Self { pub(in crate::routing_table) fn new(routing_table: &'a RoutingTable) -> Self {
Self { Self {
routing_table, routing_table,
changes: Vec::new(), changes: Vec::new(),
@ -41,7 +41,7 @@ impl RoutingDomainEditorPublicInternet {
} }
} }
impl RoutingDomainEditorCommonTrait for RoutingDomainEditorPublicInternet { impl<'a> RoutingDomainEditorCommonTrait for RoutingDomainEditorPublicInternet<'a> {
#[instrument(level = "debug", skip(self))] #[instrument(level = "debug", skip(self))]
fn clear_dial_info_details( fn clear_dial_info_details(
&mut self, &mut self,
@ -263,8 +263,7 @@ impl RoutingDomainEditorCommonTrait for RoutingDomainEditorPublicInternet {
if changed { if changed {
// Clear the routespecstore cache if our PublicInternet dial info has changed // Clear the routespecstore cache if our PublicInternet dial info has changed
let rss = self.routing_table.route_spec_store(); self.routing_table.route_spec_store().reset_cache();
rss.reset_cache();
} }
} }

View File

@ -122,11 +122,7 @@ impl RoutingDomainDetail for PublicInternetRoutingDomainDetail {
pi pi
}; };
if let Err(e) = rti if let Err(e) = rti.event_bus().post(PeerInfoChangeEvent { peer_info }) {
.unlocked_inner
.event_bus
.post(PeerInfoChangeEvent { peer_info })
{
log_rtab!(debug "Failed to post event: {}", e); log_rtab!(debug "Failed to post event: {}", e);
} }
@ -167,11 +163,8 @@ impl RoutingDomainDetail for PublicInternetRoutingDomainDetail {
dif_sort: Option<Arc<DialInfoDetailSort>>, dif_sort: Option<Arc<DialInfoDetailSort>>,
) -> ContactMethod { ) -> ContactMethod {
let ip6_prefix_size = rti let ip6_prefix_size = rti
.unlocked_inner .config()
.config .with(|c| c.network.max_connections_per_ip6_prefix_size as usize);
.get()
.network
.max_connections_per_ip6_prefix_size as usize;
// Get the nodeinfos for convenience // Get the nodeinfos for convenience
let node_a = peer_a.signed_node_info().node_info(); let node_a = peer_a.signed_node_info().node_info();

View File

@ -81,7 +81,7 @@ impl RoutingTable {
} }
// If this is our own node id, then we skip it for bootstrap, in case we are a bootstrap node // If this is our own node id, then we skip it for bootstrap, in case we are a bootstrap node
if self.unlocked_inner.matches_own_node_id(&node_ids) { if self.matches_own_node_id(&node_ids) {
return Ok(None); return Ok(None);
} }
@ -255,7 +255,7 @@ impl RoutingTable {
//#[instrument(level = "trace", skip(self), err)] //#[instrument(level = "trace", skip(self), err)]
pub fn bootstrap_with_peer( pub fn bootstrap_with_peer(
self, &self,
crypto_kinds: Vec<CryptoKind>, crypto_kinds: Vec<CryptoKind>,
pi: Arc<PeerInfo>, pi: Arc<PeerInfo>,
unord: &FuturesUnordered<SendPinBoxFuture<()>>, unord: &FuturesUnordered<SendPinBoxFuture<()>>,
@ -280,19 +280,20 @@ impl RoutingTable {
for crypto_kind in crypto_kinds { for crypto_kind in crypto_kinds {
// Bootstrap this crypto kind // Bootstrap this crypto kind
let nr = nr.unfiltered(); let nr = nr.unfiltered();
let routing_table = self.clone();
unord.push(Box::pin( unord.push(Box::pin(
async move { async move {
let network_manager = nr.network_manager();
let routing_table = nr.routing_table();
// Get what contact method would be used for contacting the bootstrap // Get what contact method would be used for contacting the bootstrap
let bsdi = match routing_table let bsdi = match network_manager
.network_manager()
.get_node_contact_method(nr.default_filtered()) .get_node_contact_method(nr.default_filtered())
{ {
Ok(NodeContactMethod::Direct(v)) => v, Ok(NodeContactMethod::Direct(v)) => v,
Ok(v) => { Ok(v) => {
log_rtab!(debug "invalid contact method for bootstrap, ignoring peer: {:?}", v); log_rtab!(debug "invalid contact method for bootstrap, ignoring peer: {:?}", v);
// let _ = routing_table // let _ =
// .network_manager() // network_manager
// .get_node_contact_method(nr.clone()); // .get_node_contact_method(nr.clone());
return; return;
} }
@ -312,7 +313,7 @@ impl RoutingTable {
log_rtab!(debug "bootstrap server is not responding for dialinfo: {}", bsdi); log_rtab!(debug "bootstrap server is not responding for dialinfo: {}", bsdi);
// Try a different dialinfo next time // Try a different dialinfo next time
routing_table.network_manager().address_filter().set_dial_info_failed(bsdi); network_manager.address_filter().set_dial_info_failed(bsdi);
} else { } else {
// otherwise this bootstrap is valid, lets ask it to find ourselves now // otherwise this bootstrap is valid, lets ask it to find ourselves now
routing_table.reverse_find_node(crypto_kind, nr, true, vec![]).await routing_table.reverse_find_node(crypto_kind, nr, true, vec![]).await
@ -325,7 +326,7 @@ impl RoutingTable {
#[instrument(level = "trace", skip(self), err)] #[instrument(level = "trace", skip(self), err)]
pub async fn bootstrap_with_peer_list( pub async fn bootstrap_with_peer_list(
self, &self,
peers: Vec<Arc<PeerInfo>>, peers: Vec<Arc<PeerInfo>>,
stop_token: StopToken, stop_token: StopToken,
) -> EyreResult<()> { ) -> EyreResult<()> {
@ -339,8 +340,7 @@ impl RoutingTable {
// Run all bootstrap operations concurrently // Run all bootstrap operations concurrently
let mut unord = FuturesUnordered::<SendPinBoxFuture<()>>::new(); let mut unord = FuturesUnordered::<SendPinBoxFuture<()>>::new();
for peer in peers { for peer in peers {
self.clone() self.bootstrap_with_peer(crypto_kinds.clone(), peer, &unord);
.bootstrap_with_peer(crypto_kinds.clone(), peer, &unord);
} }
// Wait for all bootstrap operations to complete before we complete the singlefuture // Wait for all bootstrap operations to complete before we complete the singlefuture
@ -364,10 +364,15 @@ impl RoutingTable {
} }
#[instrument(level = "trace", skip(self), err)] #[instrument(level = "trace", skip(self), err)]
pub async fn bootstrap_task_routine(self, stop_token: StopToken) -> EyreResult<()> { pub async fn bootstrap_task_routine(
&self,
stop_token: StopToken,
_last_ts: Timestamp,
_cur_ts: Timestamp,
) -> EyreResult<()> {
let bootstrap = self let bootstrap = self
.unlocked_inner .config()
.with_config(|c| c.network.routing_table.bootstrap.clone()); .with(|c| c.network.routing_table.bootstrap.clone());
// Don't bother if bootstraps aren't configured // Don't bother if bootstraps aren't configured
if bootstrap.is_empty() { if bootstrap.is_empty() {
@ -445,8 +450,6 @@ impl RoutingTable {
peers peers
}; };
self.clone() self.bootstrap_with_peer_list(peers, stop_token).await
.bootstrap_with_peer_list(peers, stop_token)
.await
} }
} }

Some files were not shown because too many files have changed in this diff Show More