checkpoint merge of network-shim branch

Christien Rioux 2025-02-10 03:06:41 +00:00
parent 079b665230
commit a2b0214b8e
276 changed files with 17493 additions and 7193 deletions

Cargo.lock (generated, 1912 changed lines): diff suppressed because it is too large.

View File

@ -9,6 +9,13 @@ members = [
]
resolver = "2"
[workspace.package]
repository = "https://gitlab.com/veilid/veilid"
authors = ["Veilid Team <contact@veilid.com>"]
license = "MPL-2.0"
edition = "2021"
rust-version = "1.81.0"
[patch.crates-io]
cursive = { git = "https://gitlab.com/veilid/cursive.git" }
cursive_core = { git = "https://gitlab.com/veilid/cursive.git" }
@ -26,3 +33,31 @@ lto = true
[profile.dev.package.backtrace]
opt-level = 3
[profile.dev.package.argon2]
opt-level = 3
debug-assertions = false
[profile.dev.package.ed25519-dalek]
opt-level = 3
debug-assertions = false
[profile.dev.package.x25519-dalek]
opt-level = 3
debug-assertions = false
[profile.dev.package.curve25519-dalek]
opt-level = 3
debug-assertions = false
[profile.dev.package.chacha20poly1305]
opt-level = 3
debug-assertions = false
[profile.dev.package.blake3]
opt-level = 3
debug-assertions = false
[profile.dev.package.chacha20]
opt-level = 3
debug-assertions = false

View File

@ -36,7 +36,7 @@ command line without it. If you do so, you may skip to
[Run Veilid setup script](#Run Veilid setup script).
- build-tools;34.0.0
- ndk;26.3.11579264
- ndk;27.0.12077973
- cmake;3.22.1
- platform-tools
- platforms;android-34
@ -58,7 +58,7 @@ the command line to install the requisite package versions:
sdkmanager --install "platform-tools"
sdkmanager --install "platforms;android-34"
sdkmanager --install "build-tools;34.0.0"
sdkmanager --install "ndk;26.3.11579264"
sdkmanager --install "ndk;27.0.12077973"
sdkmanager --install "cmake;3.22.1"
```
@ -110,7 +110,7 @@ You will need to use Android Studio [here](https://developer.android.com/studio)
to maintain your Android dependencies. Use the SDK Manager in the IDE to install the following packages (use package details view to select version):
- Android SDK Build Tools (34.0.0)
- NDK (Side-by-side) (26.3.11579264)
- NDK (Side-by-side) (27.0.12077973)
- Cmake (3.22.1)
- Android SDK 34
- Android SDK Command Line Tools (latest) (7.0/latest)

View File

@ -18,7 +18,7 @@ FROM ubuntu:18.04
ENV ZIG_VERSION=0.13.0
ENV CMAKE_VERSION_MINOR=3.30
ENV CMAKE_VERSION_PATCH=3.30.1
ENV WASM_BINDGEN_CLI_VERSION=0.2.93
ENV WASM_BINDGEN_CLI_VERSION=0.2.100
ENV RUST_VERSION=1.81.0
ENV RUSTUP_HOME=/usr/local/rustup
ENV RUSTUP_DIST_SERVER=https://static.rust-lang.org
@ -82,7 +82,7 @@ deps-android:
RUN mkdir /Android; mkdir /Android/Sdk
RUN curl -o /Android/cmdline-tools.zip https://dl.google.com/android/repository/commandlinetools-linux-9123335_latest.zip
RUN cd /Android; unzip /Android/cmdline-tools.zip
RUN yes | /Android/cmdline-tools/bin/sdkmanager --sdk_root=/Android/Sdk build-tools\;34.0.0 ndk\;26.3.11579264 cmake\;3.22.1 platform-tools platforms\;android-34 cmdline-tools\;latest
RUN yes | /Android/cmdline-tools/bin/sdkmanager --sdk_root=/Android/Sdk build-tools\;34.0.0 ndk\;27.0.12077973 cmake\;3.22.1 platform-tools platforms\;android-34 cmdline-tools\;latest
RUN rm -rf /Android/cmdline-tools
RUN apt-get clean
@ -170,7 +170,7 @@ build-linux-arm64:
build-android:
FROM +code-android
WORKDIR /veilid/veilid-core
ENV PATH=$PATH:/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/
ENV PATH=$PATH:/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/
RUN cargo build --target aarch64-linux-android --release
RUN cargo build --target armv7-linux-androideabi --release
RUN cargo build --target i686-linux-android --release

View File

@ -43,14 +43,14 @@ while true; do
curl -o $HOME/Android/cmdline-tools.zip https://dl.google.com/android/repository/commandlinetools-linux-9123335_latest.zip
cd $HOME/Android
unzip $HOME/Android/cmdline-tools.zip
$HOME/Android/cmdline-tools/bin/sdkmanager --sdk_root=$HOME/Android/Sdk build-tools\;34.0.0 ndk\;26.3.11579264 cmake\;3.22.1 platform-tools platforms\;android-34 cmdline-tools\;latest emulator
$HOME/Android/cmdline-tools/bin/sdkmanager --sdk_root=$HOME/Android/Sdk build-tools\;34.0.0 ndk\;27.0.12077973 cmake\;3.22.1 platform-tools platforms\;android-34 cmdline-tools\;latest emulator
cd $HOME
rm -rf $HOME/Android/cmdline-tools $HOME/Android/cmdline-tools.zip
# Add environment variables
cat >>$HOME/.profile <<END
source "\$HOME/.cargo/env"
export PATH=\$PATH:\$HOME/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin:\$HOME/Android/Sdk/platform-tools:\$HOME/Android/Sdk/cmdline-tools/latest/bin
export PATH=\$PATH:\$HOME/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin:\$HOME/Android/Sdk/platform-tools:\$HOME/Android/Sdk/cmdline-tools/latest/bin
export ANDROID_HOME=\$HOME/Android/Sdk
END
break

View File

@ -42,7 +42,7 @@ while true; do
fi
# ensure ndk is installed
ANDROID_NDK_HOME="$ANDROID_HOME/ndk/26.3.11579264"
ANDROID_NDK_HOME="$ANDROID_HOME/ndk/27.0.12077973"
if [ -f "$ANDROID_NDK_HOME/ndk-build" ]; then
echo '[X] Android NDK is installed at the location $ANDROID_NDK_HOME'
else

View File

@ -31,10 +31,10 @@ while true; do
fi
# ensure Android SDK packages are installed
$ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager build-tools\;34.0.0 ndk\;26.3.11579264 cmake\;3.22.1 platform-tools platforms\;android-34
$ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager build-tools\;34.0.0 ndk\;27.0.12077973 cmake\;3.22.1 platform-tools platforms\;android-34
# ensure ndk is installed
ANDROID_NDK_HOME="$ANDROID_HOME/ndk/26.3.11579264"
ANDROID_NDK_HOME="$ANDROID_HOME/ndk/27.0.12077973"
if [ -f "$ANDROID_NDK_HOME/ndk-build" ]; then
echo '[X] Android NDK is installed at the location $ANDROID_NDK_HOME'
else

View File

@ -27,6 +27,7 @@ logging:
enabled: false
testing:
subnode_index: 0
subnode_count: 1
core:
protected_store:
allow_insecure_fallback: true

View File

@ -138,6 +138,7 @@ otlp:
```yaml
testing:
subnode_index: 0
subnode_count: 1
```
### core

View File

@ -1,8 +1,8 @@
[target.aarch64-linux-android]
linker = "/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android34-clang"
linker = "/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android34-clang"
[target.armv7-linux-androideabi]
linker = "/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/armv7a-linux-androideabi33-clang"
linker = "/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/armv7a-linux-androideabi33-clang"
[target.x86_64-linux-android]
linker = "/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/x86_64-linux-android34-clang"
linker = "/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/x86_64-linux-android34-clang"
[target.i686-linux-android]
linker = "/Android/Sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/bin/i686-linux-android34-clang"
linker = "/Android/Sdk/ndk/27.0.12077973/toolchains/llvm/prebuilt/linux-x86_64/bin/i686-linux-android34-clang"

View File

@ -4,12 +4,12 @@ name = "veilid-cli"
version = "0.4.1"
# ---
description = "Client application for connecting to a Veilid headless node"
repository = "https://gitlab.com/veilid/veilid"
authors = ["Veilid Team <contact@veilid.com>"]
edition = "2021"
license = "MPL-2.0"
resolver = "2"
rust-version = "1.81.0"
repository.workspace = true
authors.workspace = true
license.workspace = true
edition.workspace = true
rust-version.workspace = true
[[bin]]
name = "veilid-cli"
@ -17,6 +17,8 @@ path = "src/main.rs"
[features]
default = ["rt-tokio"]
default-async-std = ["rt-async-std"]
rt-async-std = [
"async-std",
"veilid-tools/rt-async-std",

View File

@ -223,13 +223,12 @@ impl ClientApiConnection {
trace!("ClientApiConnection::handle_tcp_connection");
// Connect the TCP socket
let stream = TcpStream::connect(connect_addr)
let stream = connect_async_tcp_stream(None, connect_addr, 10_000)
.await
.map_err(map_to_string)?
.into_timeout_error()
.map_err(map_to_string)?;
// If it succeeds, disable nagle algorithm
stream.set_nodelay(true).map_err(map_to_string)?;
// State we connected
let comproc = self.inner.lock().comproc.clone();
comproc.set_connection_state(ConnectionState::ConnectedTCP(
@ -239,16 +238,8 @@ impl ClientApiConnection {
// Split into reader and writer halves
// with line buffering on the reader
cfg_if! {
if #[cfg(feature="rt-async-std")] {
use futures::AsyncReadExt;
let (reader, writer) = stream.split();
let reader = BufReader::new(reader);
} else {
let (reader, writer) = stream.into_split();
let reader = BufReader::new(reader);
}
}
let (reader, writer) = split_async_tcp_stream(stream);
let reader = BufReader::new(reader);
self.clone().run_json_api_processor(reader, writer).await
}

View File

@ -57,6 +57,7 @@ struct CommandProcessorInner {
#[derive(Clone)]
pub struct CommandProcessor {
inner: Arc<Mutex<CommandProcessorInner>>,
settings: Arc<Settings>,
}
impl CommandProcessor {
@ -75,6 +76,7 @@ impl CommandProcessor {
last_call_id: None,
enable_app_messages: false,
})),
settings: Arc::new(settings.clone()),
}
}
pub fn set_client_api_connection(&self, capi: ClientApiConnection) {
@ -186,6 +188,54 @@ Core Debug Commands:
Ok(())
}
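/// Handle the `connect` command.
/// The optional argument is interpreted, in order, as a subnode index (u16),
/// an IPC socket path, or a network address; with no argument the connection
/// is restarted using the current settings.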
pub fn cmd_connect(&self, rest: Option<String>, callback: UICallback) -> Result<(), String> {
trace!("CommandProcessor::cmd_connect");
let capi = self.capi();
let ui = self.ui_sender();
let this = self.clone();
spawn_detached_local("cmd connect", async move {
capi.disconnect().await;
if let Some(rest) = rest {
if let Ok(subnode_index) = u16::from_str(&rest) {
let ipc_path = this
.settings
.resolve_ipc_path(this.settings.ipc_path.clone(), subnode_index);
this.set_ipc_path(ipc_path);
this.set_network_address(None);
} else if let Some(ipc_path) =
this.settings.resolve_ipc_path(Some(rest.clone().into()), 0)
{
this.set_ipc_path(Some(ipc_path));
this.set_network_address(None);
} else if let Ok(Some(network_address)) =
this.settings.resolve_network_address(Some(rest.clone()))
{
if let Some(addr) = network_address.first() {
this.set_network_address(Some(*addr));
this.set_ipc_path(None);
} else {
ui.add_node_event(
Level::Error,
&format!("Invalid network address: {}", rest),
);
}
} else {
ui.add_node_event(
Level::Error,
&format!("Invalid connection string: {}", rest),
);
}
}
this.start_connection();
ui.send_callback(callback);
});
Ok(())
}
pub fn cmd_debug(&self, command_line: String, callback: UICallback) -> Result<(), String> {
trace!("CommandProcessor::cmd_debug");
let capi = self.capi();
@ -331,6 +381,7 @@ Core Debug Commands:
"exit" => self.cmd_exit(callback),
"quit" => self.cmd_exit(callback),
"disconnect" => self.cmd_disconnect(callback),
"connect" => self.cmd_connect(rest, callback),
"shutdown" => self.cmd_shutdown(callback),
"change_log_level" => self.cmd_change_log_level(rest, callback),
"change_log_ignore" => self.cmd_change_log_ignore(rest, callback),

View File

@ -28,10 +28,11 @@ pub struct InteractiveUIInner {
#[derive(Clone)]
pub struct InteractiveUI {
inner: Arc<Mutex<InteractiveUIInner>>,
_settings: Arc<Settings>,
}
impl InteractiveUI {
pub fn new(_settings: &Settings) -> (Self, InteractiveUISender) {
pub fn new(settings: &Settings) -> (Self, InteractiveUISender) {
let (cssender, csreceiver) = flume::unbounded::<ConnectionState>();
let term = Term::stdout();
@ -45,9 +46,10 @@ impl InteractiveUI {
error: None,
done: Some(StopSource::new()),
connection_state_receiver: csreceiver,
log_enabled: false,
log_enabled: true,
enable_color,
})),
_settings: Arc::new(settings.clone()),
};
let ui_sender = InteractiveUISender {
@ -169,7 +171,6 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e);
self.inner.lock().done.take();
}
self.inner.lock().log_enabled = true;
}
} else if line == "log warn" {
let opt_cmdproc = self.inner.lock().cmdproc.clone();
@ -181,7 +182,6 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e);
self.inner.lock().done.take();
}
self.inner.lock().log_enabled = true;
}
} else if line == "log info" {
let opt_cmdproc = self.inner.lock().cmdproc.clone();
@ -193,7 +193,6 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e);
self.inner.lock().done.take();
}
self.inner.lock().log_enabled = true;
}
} else if line == "log debug" || line == "log" {
let opt_cmdproc = self.inner.lock().cmdproc.clone();
@ -205,6 +204,8 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e);
self.inner.lock().done.take();
}
}
if line == "log" {
self.inner.lock().log_enabled = true;
}
} else if line == "log trace" {
@ -217,7 +218,6 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e);
self.inner.lock().done.take();
}
self.inner.lock().log_enabled = true;
}
} else if line == "log off" {
let opt_cmdproc = self.inner.lock().cmdproc.clone();
@ -229,9 +229,27 @@ impl InteractiveUI {
eprintln!("Error: {:?}", e);
self.inner.lock().done.take();
}
self.inner.lock().log_enabled = false;
}
} else if line == "log hide" || line == "log disable" {
self.inner.lock().log_enabled = false;
} else if line == "log show" || line == "log enable" {
self.inner.lock().log_enabled = true;
} else if !line.is_empty() {
if line == "help" {
let _ = writeln!(
stdout,
r#"
Interactive Mode Commands:
help - Display this help
clear - Clear the screen
log [level] - Set the client api log level for the node to one of: error,warn,info,debug,trace,off
hide|disable - Turn off viewing the log without changing the log level for the node
show|enable - Turn on viewing the log without changing the log level for the node
- With no option, 'log' turns on viewing the log and sets the level to 'debug'
"#
);
}
let cmdproc = self.inner.lock().cmdproc.clone();
if let Some(cmdproc) = &cmdproc {
if let Err(e) = cmdproc.run_command(

View File

@ -3,7 +3,7 @@
#![deny(unused_must_use)]
#![recursion_limit = "256"]
use crate::{settings::NamedSocketAddrs, tools::*, ui::*};
use crate::{tools::*, ui::*};
use clap::{Parser, ValueEnum};
use flexi_logger::*;
@ -37,7 +37,7 @@ struct CmdlineArgs {
ipc_path: Option<PathBuf>,
/// Subnode index to use when connecting
#[arg(short('n'), long, default_value = "0")]
subnode_index: usize,
subnode_index: u16,
/// Address to connect to
#[arg(long, short = 'a')]
address: Option<String>,
@ -47,9 +47,9 @@ struct CmdlineArgs {
/// Specify a configuration file to use
#[arg(short = 'c', long, value_name = "FILE")]
config_file: Option<PathBuf>,
/// log level
#[arg(value_enum)]
log_level: Option<LogLevel>,
/// Log level for the CLI itself (not for the Veilid node)
#[arg(long, value_enum)]
cli_log_level: Option<LogLevel>,
/// interactive
#[arg(long, short = 'i', group = "execution_mode")]
interactive: bool,
@ -93,11 +93,11 @@ fn main() -> Result<(), String> {
.map_err(|e| format!("configuration is invalid: {}", e))?;
// Set config from command line
if let Some(LogLevel::Debug) = args.log_level {
if let Some(LogLevel::Debug) = args.cli_log_level {
settings.logging.level = settings::LogLevel::Debug;
settings.logging.terminal.enabled = true;
}
if let Some(LogLevel::Trace) = args.log_level {
if let Some(LogLevel::Trace) = args.cli_log_level {
settings.logging.level = settings::LogLevel::Trace;
settings.logging.terminal.enabled = true;
}
@ -248,59 +248,14 @@ fn main() -> Result<(), String> {
// Determine IPC path to try
let mut client_api_ipc_path = None;
if enable_ipc {
cfg_if::cfg_if! {
if #[cfg(windows)] {
if let Some(ipc_path) = args.ipc_path.or(settings.ipc_path.clone()) {
if is_ipc_socket_path(&ipc_path) {
// try direct path
enable_network = false;
client_api_ipc_path = Some(ipc_path);
} else {
// try subnode index inside path
let ipc_path = ipc_path.join(args.subnode_index.to_string());
if is_ipc_socket_path(&ipc_path) {
// subnode indexed path exists
enable_network = false;
client_api_ipc_path = Some(ipc_path);
}
}
}
} else {
if let Some(ipc_path) = args.ipc_path.or(settings.ipc_path.clone()) {
if is_ipc_socket_path(&ipc_path) {
// try direct path
enable_network = false;
client_api_ipc_path = Some(ipc_path);
} else if ipc_path.exists() && ipc_path.is_dir() {
// try subnode index inside path
let ipc_path = ipc_path.join(args.subnode_index.to_string());
if is_ipc_socket_path(&ipc_path) {
// subnode indexed path exists
enable_network = false;
client_api_ipc_path = Some(ipc_path);
}
}
}
}
client_api_ipc_path = settings.resolve_ipc_path(args.ipc_path, args.subnode_index);
if client_api_ipc_path.is_some() {
enable_network = false;
}
}
let mut client_api_network_addresses = None;
if enable_network {
let args_address = if let Some(args_address) = args.address {
match NamedSocketAddrs::try_from(args_address) {
Ok(v) => Some(v),
Err(e) => {
return Err(format!("Invalid server address: {}", e));
}
}
} else {
None
};
if let Some(address_arg) = args_address.or(settings.address.clone()) {
client_api_network_addresses = Some(address_arg.addrs);
} else if let Some(address) = settings.address.clone() {
client_api_network_addresses = Some(address.addrs.clone());
}
client_api_network_addresses = settings.resolve_network_address(args.address)?;
}
// Create command processor

View File

@ -1,5 +1,6 @@
use directories::*;
use crate::tools::*;
use serde_derive::*;
use std::ffi::OsStr;
use std::net::{SocketAddr, ToSocketAddrs};
@ -118,7 +119,7 @@ pub fn convert_loglevel(log_level: LogLevel) -> log::LevelFilter {
}
}
#[derive(Debug, Clone)]
#[derive(Clone, Debug)]
pub struct NamedSocketAddrs {
pub _name: String,
pub addrs: Vec<SocketAddr>,
@ -148,26 +149,26 @@ impl<'de> serde::Deserialize<'de> for NamedSocketAddrs {
}
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct Terminal {
pub enabled: bool,
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct File {
pub enabled: bool,
pub directory: String,
pub append: bool,
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct Logging {
pub terminal: Terminal,
pub file: File,
pub level: LogLevel,
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct Colors {
pub background: String,
pub shadow: String,
@ -182,7 +183,7 @@ pub struct Colors {
pub highlight_text: String,
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct LogColors {
pub trace: String,
pub debug: String,
@ -191,7 +192,7 @@ pub struct LogColors {
pub error: String,
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct Theme {
pub shadow: bool,
pub borders: String,
@ -199,24 +200,24 @@ pub struct Theme {
pub log_colors: LogColors,
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct NodeLog {
pub scrollback: usize,
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct CommandLine {
pub history_size: usize,
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct Interface {
pub theme: Theme,
pub node_log: NodeLog,
pub command_line: CommandLine,
}
#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct Settings {
pub enable_ipc: bool,
pub ipc_path: Option<PathBuf>,
@ -229,6 +230,90 @@ pub struct Settings {
}
impl Settings {
//////////////////////////////////////////////////////////////////////////////////
pub fn new(config_file: Option<&OsStr>) -> Result<Self, config::ConfigError> {
// Load the default config
let mut cfg = load_default_config()?;
// Merge in the config file if we have one
if let Some(config_file) = config_file {
let config_file_path = Path::new(config_file);
// If the user specifies a config file on the command line then it must exist
cfg = load_config(cfg, config_file_path)?;
}
// Generate config
cfg.try_deserialize()
}
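/// Resolve the IPC socket path to connect to, preferring an explicitly
/// provided path over the configured one. If the path is not itself a socket,
/// a subnode-indexed socket inside that path is tried. Returns None when no
/// usable IPC socket is found.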
pub fn resolve_ipc_path(
&self,
ipc_path: Option<PathBuf>,
subnode_index: u16,
) -> Option<PathBuf> {
let mut client_api_ipc_path = None;
// Determine IPC path to try
cfg_if::cfg_if! {
if #[cfg(windows)] {
if let Some(ipc_path) = ipc_path.or(self.ipc_path.clone()) {
if is_ipc_socket_path(&ipc_path) {
// try direct path
client_api_ipc_path = Some(ipc_path);
} else {
// try subnode index inside path
let ipc_path = ipc_path.join(subnode_index.to_string());
if is_ipc_socket_path(&ipc_path) {
// subnode indexed path exists
client_api_ipc_path = Some(ipc_path);
}
}
}
} else {
if let Some(ipc_path) = ipc_path.or(self.ipc_path.clone()) {
if is_ipc_socket_path(&ipc_path) {
// try direct path
client_api_ipc_path = Some(ipc_path);
} else if ipc_path.exists() && ipc_path.is_dir() {
// try subnode index inside path
let ipc_path = ipc_path.join(subnode_index.to_string());
if is_ipc_socket_path(&ipc_path) {
// subnode indexed path exists
client_api_ipc_path = Some(ipc_path);
}
}
}
}
}
client_api_ipc_path
}
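/// Resolve the network address(es) to connect to, preferring an explicitly
/// provided address string over the configured one. Returns an error if the
/// provided address cannot be parsed.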
pub fn resolve_network_address(
&self,
address: Option<String>,
) -> Result<Option<Vec<SocketAddr>>, String> {
let mut client_api_network_addresses = None;
let args_address = if let Some(args_address) = address {
match NamedSocketAddrs::try_from(args_address) {
Ok(v) => Some(v),
Err(e) => {
return Err(format!("Invalid server address: {}", e));
}
}
} else {
None
};
if let Some(address_arg) = args_address.or(self.address.clone()) {
client_api_network_addresses = Some(address_arg.addrs);
} else if let Some(address) = self.address.clone() {
client_api_network_addresses = Some(address.addrs.clone());
}
Ok(client_api_network_addresses)
}
////////////////////////////////////////////////////////////////////////////
#[cfg_attr(windows, expect(dead_code))]
fn get_server_default_directory(subpath: &str) -> PathBuf {
#[cfg(unix)]
@ -284,21 +369,6 @@ impl Settings {
default_log_directory
}
pub fn new(config_file: Option<&OsStr>) -> Result<Self, config::ConfigError> {
// Load the default config
let mut cfg = load_default_config()?;
// Merge in the config file if we have one
if let Some(config_file) = config_file {
let config_file_path = Path::new(config_file);
// If the user specifies a config file on the command line then it must exist
cfg = load_config(cfg, config_file_path)?;
}
// Generate config
cfg.try_deserialize()
}
}
#[test]

View File

@ -8,12 +8,10 @@ use core::str::FromStr;
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub use async_std::net::TcpStream;
pub fn block_on<F: Future<Output = T>, T>(f: F) -> T {
async_std::task::block_on(f)
}
} else if #[cfg(feature="rt-tokio")] {
pub use tokio::net::TcpStream;
pub fn block_on<F: Future<Output = T>, T>(f: F) -> T {
let rt = tokio::runtime::Runtime::new().unwrap();
let local = tokio::task::LocalSet::new();

View File

@ -4,13 +4,13 @@ name = "veilid-core"
version = "0.4.1"
# ---
description = "Core library used to create a Veilid node and operate it as part of an application"
repository = "https://gitlab.com/veilid/veilid"
authors = ["Veilid Team <contact@veilid.com>"]
edition = "2021"
build = "build.rs"
license = "MPL-2.0"
resolver = "2"
rust-version = "1.81.0"
repository.workspace = true
authors.workspace = true
license.workspace = true
edition.workspace = true
rust-version.workspace = true
[lib]
crate-type = ["cdylib", "staticlib", "rlib"]
@ -56,6 +56,8 @@ veilid_core_ios_tests = ["dep:tracing-oslog"]
debug-locks = ["veilid-tools/debug-locks"]
unstable-blockstore = []
unstable-tunnels = []
virtual-network = ["veilid-tools/virtual-network"]
virtual-network-server = ["veilid-tools/virtual-network-server"]
# GeoIP
geolocation = ["maxminddb", "reqwest"]
@ -133,8 +135,8 @@ hickory-resolver = { version = "0.24.1", optional = true }
# Serialization
capnp = { version = "0.19.6", default-features = false, features = ["alloc"] }
serde = { version = "1.0.204", features = ["derive", "rc"] }
serde_json = { version = "1.0.120" }
serde = { version = "1.0.214", features = ["derive", "rc"] }
serde_json = { version = "1.0.132" }
serde-big-array = "0.5.1"
json = "0.12.4"
data-encoding = { version = "2.6.0" }
@ -148,7 +150,7 @@ sanitize-filename = "0.5.0"
# Dependencies for native builds only
# Linux, Windows, Mac, iOS, Android
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
# Tools
config = { version = "0.13.4", default-features = false, features = ["yaml"] }
@ -164,7 +166,6 @@ sysinfo = { version = "^0.30.13", default-features = false }
tokio = { version = "1.38.1", features = ["full"], optional = true }
tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
async-io = { version = "1.13.0" }
futures-util = { version = "0.3.30", default-features = false, features = [
"async-await",
"sink",
@ -184,10 +185,9 @@ webpki = "0.22.4"
webpki-roots = "0.25.4"
rustls = "0.21.12"
rustls-pemfile = "1.0.4"
socket2 = { version = "0.5.7", features = ["all"] }
# Dependencies for WASM builds only
[target.'cfg(target_arch = "wasm32")'.dependencies]
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
veilid-tools = { version = "0.4.1", path = "../veilid-tools", default-features = false, features = [
"rt-wasm-bindgen",
@ -222,7 +222,7 @@ tracing-wasm = "0.2.1"
keyvaluedb-web = "0.1.2"
### Configuration for WASM32 'web-sys' crate
[target.'cfg(target_arch = "wasm32")'.dependencies.web-sys]
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies.web-sys]
version = "0.3.69"
features = [
'Document',
@ -260,12 +260,12 @@ tracing-oslog = { version = "0.1.2", optional = true }
### DEV DEPENDENCIES
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
simplelog = { version = "0.12.2", features = ["test"] }
serial_test = "2.0.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies]
serial_test = { version = "2.0.0", default-features = false, features = [
"async",
] }

View File

@ -178,7 +178,7 @@ fn fix_android_emulator() {
.or(env::var("ANDROID_SDK_ROOT"))
.expect("ANDROID_HOME or ANDROID_SDK_ROOT not set");
let lib_path = glob(&format!(
"{android_home}/ndk/26.3.11579264/**/lib{missing_library}.a"
"{android_home}/ndk/27.0.12077973/**/lib{missing_library}.a"
))
.expect("failed to glob")
.next()

View File

@ -1,54 +1,43 @@
use crate::*;
use crypto::Crypto;
use network_manager::*;
use routing_table::*;
use storage_manager::*;
use crate::{network_manager::StartupDisposition, *};
use routing_table::RoutingTableHealth;
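/// Startup/shutdown context shared by the attachment manager.
/// The startup lock ensures startup and shutdown cannot overlap and that
/// shutdown only happens after a successful startup.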
#[derive(Debug, Clone)]
pub struct AttachmentManagerStartupContext {
pub startup_lock: Arc<StartupLock>,
}
impl AttachmentManagerStartupContext {
pub fn new() -> Self {
Self {
startup_lock: Arc::new(StartupLock::new()),
}
}
}
impl Default for AttachmentManagerStartupContext {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
struct AttachmentManagerInner {
last_attachment_state: AttachmentState,
last_routing_table_health: Option<RoutingTableHealth>,
maintain_peers: bool,
started_ts: Timestamp,
attach_ts: Option<Timestamp>,
update_callback: Option<UpdateCallback>,
attachment_maintainer_jh: Option<MustJoinHandle<()>>,
}
struct AttachmentManagerUnlockedInner {
_event_bus: EventBus,
config: VeilidConfig,
network_manager: NetworkManager,
#[derive(Debug)]
pub struct AttachmentManager {
registry: VeilidComponentRegistry,
inner: Mutex<AttachmentManagerInner>,
startup_context: AttachmentManagerStartupContext,
}
#[derive(Clone)]
pub struct AttachmentManager {
inner: Arc<Mutex<AttachmentManagerInner>>,
unlocked_inner: Arc<AttachmentManagerUnlockedInner>,
}
impl_veilid_component!(AttachmentManager);
impl AttachmentManager {
fn new_unlocked_inner(
event_bus: EventBus,
config: VeilidConfig,
storage_manager: StorageManager,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
) -> AttachmentManagerUnlockedInner {
AttachmentManagerUnlockedInner {
_event_bus: event_bus.clone(),
config: config.clone(),
network_manager: NetworkManager::new(
event_bus,
config,
storage_manager,
table_store,
#[cfg(feature = "unstable-blockstore")]
block_store,
crypto,
),
}
}
fn new_inner() -> AttachmentManagerInner {
AttachmentManagerInner {
last_attachment_state: AttachmentState::Detached,
@ -56,52 +45,35 @@ impl AttachmentManager {
maintain_peers: false,
started_ts: Timestamp::now(),
attach_ts: None,
update_callback: None,
attachment_maintainer_jh: None,
}
}
pub fn new(
event_bus: EventBus,
config: VeilidConfig,
storage_manager: StorageManager,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
registry: VeilidComponentRegistry,
startup_context: AttachmentManagerStartupContext,
) -> Self {
Self {
inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner(
event_bus,
config,
storage_manager,
table_store,
#[cfg(feature = "unstable-blockstore")]
block_store,
crypto,
)),
registry,
inner: Mutex::new(Self::new_inner()),
startup_context,
}
}
pub fn config(&self) -> VeilidConfig {
self.unlocked_inner.config.clone()
pub fn is_attached(&self) -> bool {
let s = self.inner.lock().last_attachment_state;
!matches!(s, AttachmentState::Detached | AttachmentState::Detaching)
}
pub fn network_manager(&self) -> NetworkManager {
self.unlocked_inner.network_manager.clone()
#[allow(dead_code)]
pub fn is_detached(&self) -> bool {
let s = self.inner.lock().last_attachment_state;
matches!(s, AttachmentState::Detached)
}
// pub fn is_attached(&self) -> bool {
// let s = self.inner.lock().last_attachment_state;
// !matches!(s, AttachmentState::Detached | AttachmentState::Detaching)
// }
// pub fn is_detached(&self) -> bool {
// let s = self.inner.lock().last_attachment_state;
// matches!(s, AttachmentState::Detached)
// }
// pub fn get_attach_timestamp(&self) -> Option<Timestamp> {
// self.inner.lock().attach_ts
// }
#[allow(dead_code)]
pub fn get_attach_timestamp(&self) -> Option<Timestamp> {
self.inner.lock().attach_ts
}
fn translate_routing_table_health(
health: &RoutingTableHealth,
@ -155,11 +127,6 @@ impl AttachmentManager {
inner.last_attachment_state =
AttachmentManager::translate_routing_table_health(&health, routing_table_config);
// If we don't have an update callback yet for some reason, just return now
let Some(update_callback) = inner.update_callback.clone() else {
return;
};
// Send update if one of:
// * the attachment state has changed
// * routing domain readiness has changed
@ -172,7 +139,7 @@ impl AttachmentManager {
})
.unwrap_or(true);
if send_update {
Some((update_callback, Self::get_veilid_state_inner(&inner)))
Some(Self::get_veilid_state_inner(&inner))
} else {
None
}
@ -180,15 +147,14 @@ impl AttachmentManager {
// Send the update outside of the lock
if let Some(update) = opt_update {
(update.0)(VeilidUpdate::Attachment(update.1));
(self.update_callback())(VeilidUpdate::Attachment(update));
}
}
fn update_attaching_detaching_state(&self, state: AttachmentState) {
let uptime;
let attached_uptime;
let update_callback = {
{
let mut inner = self.inner.lock();
// Clear routing table health so when we start measuring it we start from scratch
@ -211,29 +177,98 @@ impl AttachmentManager {
let now = Timestamp::now();
uptime = now - inner.started_ts;
attached_uptime = inner.attach_ts.map(|ts| now - ts);
// Get callback
inner.update_callback.clone()
};
// Send update
if let Some(update_callback) = update_callback {
update_callback(VeilidUpdate::Attachment(Box::new(VeilidStateAttachment {
state,
public_internet_ready: false,
local_network_ready: false,
uptime,
attached_uptime,
})))
(self.update_callback())(VeilidUpdate::Attachment(Box::new(VeilidStateAttachment {
state,
public_internet_ready: false,
local_network_ready: false,
uptime,
attached_uptime,
})))
}
async fn startup(&self) -> EyreResult<StartupDisposition> {
let guard = self.startup_context.startup_lock.startup()?;
let rpc_processor = self.rpc_processor();
let network_manager = self.network_manager();
// Startup network manager
network_manager.startup().await?;
// Startup rpc processor
if let Err(e) = rpc_processor.startup().await {
network_manager.shutdown().await;
return Err(e);
}
// Startup routing table
let routing_table = self.routing_table();
if let Err(e) = routing_table.startup().await {
rpc_processor.shutdown().await;
network_manager.shutdown().await;
return Err(e);
}
// Startup successful
guard.success();
// Inform api clients that things have changed
log_net!(debug "sending network state update to api clients");
network_manager.send_network_update();
Ok(StartupDisposition::Success)
}
async fn shutdown(&self) {
let guard = self
.startup_context
.startup_lock
.shutdown()
.await
.expect("should be started up");
let routing_table = self.routing_table();
let rpc_processor = self.rpc_processor();
let network_manager = self.network_manager();
// Shutdown RoutingTable
routing_table.shutdown().await;
// Shutdown NetworkManager
network_manager.shutdown().await;
// Shutdown RPCProcessor
rpc_processor.shutdown().await;
// Shutdown successful
guard.success();
// send update
log_net!(debug "sending network state update to api clients");
network_manager.send_network_update();
}
async fn tick(&self) -> EyreResult<()> {
// Run the network manager tick
let network_manager = self.network_manager();
network_manager.tick().await?;
// Run the routing table tick
let routing_table = self.routing_table();
routing_table.tick().await?;
Ok(())
}
#[instrument(parent = None, level = "debug", skip_all)]
async fn attachment_maintainer(self) {
async fn attachment_maintainer(&self) {
log_net!(debug "attachment starting");
self.update_attaching_detaching_state(AttachmentState::Attaching);
let netman = self.network_manager();
let network_manager = self.network_manager();
let mut restart;
let mut restart_delay;
@ -241,9 +276,9 @@ impl AttachmentManager {
restart = false;
restart_delay = 1;
match netman.startup().await {
match self.startup().await {
Err(err) => {
error!("network startup failed: {}", err);
error!("attachment startup failed: {}", err);
restart = true;
}
Ok(StartupDisposition::BindRetry) => {
@ -257,15 +292,15 @@ impl AttachmentManager {
while self.inner.lock().maintain_peers {
// tick network manager
let next_tick_ts = get_timestamp() + 1_000_000u64;
if let Err(err) = netman.tick().await {
error!("Error in network manager: {}", err);
if let Err(err) = self.tick().await {
error!("Error in attachment tick: {}", err);
self.inner.lock().maintain_peers = false;
restart = true;
break;
}
// see if we need to restart the network
if netman.network_needs_restart() {
if network_manager.network_needs_restart() {
info!("Restarting network");
restart = true;
break;
@ -288,8 +323,8 @@ impl AttachmentManager {
log_net!(debug "attachment stopping");
}
log_net!(debug "stopping network");
netman.shutdown().await;
log_net!(debug "shutting down attachment");
self.shutdown().await;
}
}
@ -313,25 +348,24 @@ impl AttachmentManager {
}
#[instrument(level = "debug", skip_all, err)]
pub async fn init(&self, update_callback: UpdateCallback) -> EyreResult<()> {
{
let mut inner = self.inner.lock();
inner.update_callback = Some(update_callback.clone());
}
self.network_manager().init(update_callback).await?;
pub async fn init_async(&self) -> EyreResult<()> {
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
pub async fn post_init_async(&self) -> EyreResult<()> {
Ok(())
}
#[instrument(level = "debug", skip_all)]
pub async fn terminate(&self) {
pub async fn pre_terminate_async(&self) {
// Ensure we detached
self.detach().await;
self.network_manager().terminate().await;
self.inner.lock().update_callback = None;
}
#[instrument(level = "debug", skip_all)]
pub async fn terminate_async(&self) {}
#[instrument(level = "trace", skip_all)]
pub async fn attach(&self) -> bool {
// Create long-running connection maintenance routine
@ -340,10 +374,11 @@ impl AttachmentManager {
return false;
}
inner.maintain_peers = true;
inner.attachment_maintainer_jh = Some(spawn(
"attachment maintainer",
self.clone().attachment_maintainer(),
));
let registry = self.registry();
inner.attachment_maintainer_jh = Some(spawn("attachment maintainer", async move {
let this = registry.attachment_manager();
this.attachment_maintainer().await;
}));
true
}

View File

@ -0,0 +1,336 @@
use std::marker::PhantomData;
use super::*;
pub trait AsAnyArcSendSync {
fn as_any_arc_send_sync(self: Arc<Self>) -> Arc<dyn core::any::Any + Send + Sync>;
}
impl<T: Send + Sync + 'static> AsAnyArcSendSync for T {
fn as_any_arc_send_sync(self: Arc<Self>) -> Arc<dyn core::any::Any + Send + Sync> {
self
}
}
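/// Lifecycle interface implemented by every registered component.
/// Components are brought up in registration order (init, then post_init)
/// and torn down in reverse order (pre_terminate, then terminate).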
pub trait VeilidComponent:
AsAnyArcSendSync + VeilidComponentRegistryAccessor + core::fmt::Debug
{
fn init(&self) -> SendPinBoxFutureLifetime<'_, EyreResult<()>>;
fn post_init(&self) -> SendPinBoxFutureLifetime<'_, EyreResult<()>>;
fn pre_terminate(&self) -> SendPinBoxFutureLifetime<'_, ()>;
fn terminate(&self) -> SendPinBoxFutureLifetime<'_, ()>;
}
pub trait VeilidComponentRegistryAccessor {
fn registry(&self) -> VeilidComponentRegistry;
fn config(&self) -> VeilidConfig {
self.registry().config.clone()
}
fn update_callback(&self) -> UpdateCallback {
self.registry().config.update_callback()
}
fn event_bus(&self) -> EventBus {
self.registry().event_bus.clone()
}
}
pub struct VeilidComponentGuard<'a, T: VeilidComponent + Send + Sync + 'static> {
component: Arc<T>,
_phantom: core::marker::PhantomData<&'a T>,
}
impl<'a, T> core::ops::Deref for VeilidComponentGuard<'a, T>
where
T: VeilidComponent + Send + Sync + 'static,
{
type Target = T;
fn deref(&self) -> &Self::Target {
&self.component
}
}
#[derive(Debug)]
struct VeilidComponentRegistryInner {
type_map: HashMap<core::any::TypeId, Arc<dyn VeilidComponent + Send + Sync>>,
init_order: Vec<core::any::TypeId>,
mock: bool,
}
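/// Registry of all core components, sharing a single config and event bus.
/// Tracks registration order so initialization runs forward and termination
/// runs in reverse.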
#[derive(Clone, Debug)]
pub struct VeilidComponentRegistry {
inner: Arc<Mutex<VeilidComponentRegistryInner>>,
config: VeilidConfig,
event_bus: EventBus,
init_lock: Arc<AsyncMutex<bool>>,
}
impl VeilidComponentRegistry {
pub fn new(config: VeilidConfig) -> Self {
Self {
inner: Arc::new(Mutex::new(VeilidComponentRegistryInner {
type_map: HashMap::new(),
init_order: Vec::new(),
mock: false,
})),
config,
event_bus: EventBus::new(),
init_lock: Arc::new(AsyncMutex::new(false)),
}
}
pub fn enable_mock(&self) {
let mut inner = self.inner.lock();
inner.mock = true;
}
pub fn register<
T: VeilidComponent + Send + Sync + 'static,
F: FnOnce(VeilidComponentRegistry) -> T,
>(
&self,
component_constructor: F,
) {
let component = Arc::new(component_constructor(self.clone()));
let component_type_id = core::any::TypeId::of::<T>();
let mut inner = self.inner.lock();
assert!(
inner
.type_map
.insert(component_type_id, component)
.is_none(),
"should not register same component twice"
);
inner.init_order.push(component_type_id);
}
pub fn register_with_context<
C,
T: VeilidComponent + Send + Sync + 'static,
F: FnOnce(VeilidComponentRegistry, C) -> T,
>(
&self,
component_constructor: F,
context: C,
) {
let component = Arc::new(component_constructor(self.clone(), context));
let component_type_id = core::any::TypeId::of::<T>();
let mut inner = self.inner.lock();
assert!(
inner
.type_map
.insert(component_type_id, component)
.is_none(),
"should not register same component twice"
);
inner.init_order.push(component_type_id);
}
pub async fn init(&self) -> EyreResult<()> {
let Some(mut _init_guard) = asyncmutex_try_lock!(self.init_lock) else {
bail!("init should only happen one at a time");
};
if *_init_guard {
bail!("already initialized");
}
// Event bus starts up early
self.event_bus.startup().await?;
// Process components in initialization order
let init_order = self.get_init_order();
let mut initialized = vec![];
for component in init_order {
if let Err(e) = component.init().await {
self.terminate_inner(initialized).await;
self.event_bus.shutdown().await;
return Err(e);
}
initialized.push(component);
}
*_init_guard = true;
Ok(())
}
pub async fn post_init(&self) -> EyreResult<()> {
let Some(mut _init_guard) = asyncmutex_try_lock!(self.init_lock) else {
bail!("init should only happen one at a time");
};
if !*_init_guard {
bail!("not initialized");
}
let init_order = self.get_init_order();
let mut post_initialized = vec![];
for component in init_order {
if let Err(e) = component.post_init().await {
self.pre_terminate_inner(post_initialized).await;
return Err(e);
}
post_initialized.push(component)
}
Ok(())
}
pub async fn pre_terminate(&self) {
let Some(mut _init_guard) = asyncmutex_try_lock!(self.init_lock) else {
panic!("terminate should only happen one at a time");
};
if !*_init_guard {
panic!("not initialized");
}
let init_order = self.get_init_order();
self.pre_terminate_inner(init_order).await;
}
pub async fn terminate(&self) {
let Some(mut _init_guard) = asyncmutex_try_lock!(self.init_lock) else {
panic!("terminate should only happen one at a time");
};
if !*_init_guard {
panic!("not initialized");
}
// Terminate components in reverse initialization order
let init_order = self.get_init_order();
self.terminate_inner(init_order).await;
// Event bus shuts down last
self.event_bus.shutdown().await;
*_init_guard = false;
}
async fn pre_terminate_inner(
&self,
pre_initialized: Vec<Arc<dyn VeilidComponent + Send + Sync>>,
) {
for component in pre_initialized.iter().rev() {
component.pre_terminate().await;
}
}
async fn terminate_inner(&self, initialized: Vec<Arc<dyn VeilidComponent + Send + Sync>>) {
for component in initialized.iter().rev() {
component.terminate().await;
}
}
fn get_init_order(&self) -> Vec<Arc<dyn VeilidComponent + Send + Sync>> {
let inner = self.inner.lock();
inner
.init_order
.iter()
.map(|id| inner.type_map.get(id).unwrap().clone())
.collect::<Vec<_>>()
}
//////////////////////////////////////////////////////////////
pub fn lookup<'a, T: VeilidComponent + Send + Sync + 'static>(
&self,
) -> Option<VeilidComponentGuard<'a, T>> {
let inner = self.inner.lock();
let component_type_id = core::any::TypeId::of::<T>();
let component_dyn = inner.type_map.get(&component_type_id)?.clone();
let component = component_dyn
.as_any_arc_send_sync()
.downcast::<T>()
.unwrap();
Some(VeilidComponentGuard {
component,
_phantom: PhantomData {},
})
}
}
impl VeilidComponentRegistryAccessor for VeilidComponentRegistry {
fn registry(&self) -> VeilidComponentRegistry {
self.clone()
}
}
////////////////////////////////////////////////////////////////////
macro_rules! impl_veilid_component_registry_accessor {
($struct_name:ident) => {
impl VeilidComponentRegistryAccessor for $struct_name {
fn registry(&self) -> VeilidComponentRegistry {
self.registry.clone()
}
}
};
}
pub(crate) use impl_veilid_component_registry_accessor;
/////////////////////////////////////////////////////////////////////
macro_rules! impl_veilid_component {
($component_name:ident) => {
impl_veilid_component_registry_accessor!($component_name);
impl VeilidComponent for $component_name {
fn init(&self) -> SendPinBoxFutureLifetime<'_, EyreResult<()>> {
Box::pin(async { self.init_async().await })
}
fn post_init(&self) -> SendPinBoxFutureLifetime<'_, EyreResult<()>> {
Box::pin(async { self.post_init_async().await })
}
fn pre_terminate(&self) -> SendPinBoxFutureLifetime<'_, ()> {
Box::pin(async { self.pre_terminate_async().await })
}
fn terminate(&self) -> SendPinBoxFutureLifetime<'_, ()> {
Box::pin(async { self.terminate_async().await })
}
}
};
}
pub(crate) use impl_veilid_component;
/////////////////////////////////////////////////////////////////////
// Utility macro for setting up a background TickTask
// Should be called during new/construction of a component with background tasks
// and before any post-init 'tick' operations are started
macro_rules! impl_setup_task {
($this:expr, $this_type:ty, $task_name:ident, $task_routine:ident ) => {{
let registry = $this.registry();
$this.$task_name.set_routine(move |s, l, t| {
let registry = registry.clone();
Box::pin(async move {
let this = registry.lookup::<$this_type>().unwrap();
this.$task_routine(s, Timestamp::new(l), Timestamp::new(t))
.await
})
});
}};
}
pub(crate) use impl_setup_task;
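As a usage note, the following is a hedged sketch of how `impl_setup_task!` might be invoked from a component's constructor; the component, field, and routine names are illustrative and are not taken from this commit.
```rust
// Illustrative sketch only. Assumes a component `ExampleManager` that is
// registered with the VeilidComponentRegistry, owns a TickTask field named
// `flush_task`, and defines a routine compatible with the macro body above,
// i.e. one that accepts the stop token plus two Timestamps (the macro wraps
// the raw u64 values with Timestamp::new before calling it).
impl ExampleManager {
    fn setup_tasks(&self) {
        // Wires the task routine up once, cloning the registry into the
        // closure and looking the component back up when the task fires.
        impl_setup_task!(self, ExampleManager, flush_task, flush_task_routine);
    }
}
```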
// Utility macro for setting up an event bus handler
// Should be called after init, during post-init or later
// Subscription should be unsubscribed before termination
macro_rules! impl_subscribe_event_bus {
($this:expr, $this_type:ty, $event_handler:ident ) => {{
let registry = $this.registry();
$this.event_bus().subscribe(move |evt| {
let registry = registry.clone();
Box::pin(async move {
let this = registry.lookup::<$this_type>().unwrap();
this.$event_handler(evt);
})
})
}};
}
pub(crate) use impl_subscribe_event_bus;
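To show how these pieces fit together, here is a minimal sketch of a component that could be registered with the registry above. It assumes the code lives inside veilid-core where `VeilidComponentRegistry`, `EyreResult`, and these macros are in scope; the names are illustrative and not part of this commit.
```rust
// Illustrative sketch only, mirroring the pattern used by AttachmentManager.
#[derive(Debug)]
pub struct ExampleComponent {
    registry: VeilidComponentRegistry,
}

impl ExampleComponent {
    pub fn new(registry: VeilidComponentRegistry) -> Self {
        Self { registry }
    }

    // The four lifecycle hooks that impl_veilid_component! forwards to:
    async fn init_async(&self) -> EyreResult<()> {
        Ok(())
    }
    async fn post_init_async(&self) -> EyreResult<()> {
        Ok(())
    }
    async fn pre_terminate_async(&self) {}
    async fn terminate_async(&self) {}
}

// Implements VeilidComponent and VeilidComponentRegistryAccessor using the
// `registry` field and the lifecycle hooks above.
impl_veilid_component!(ExampleComponent);

// Registration and lookup, as done for the real components in core_context.rs:
//   registry.register(ExampleComponent::new);
//   registry.init().await?;
//   registry.post_init().await?;
//   let example = registry.lookup::<ExampleComponent>().unwrap();
```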

View File

@ -1,242 +1,26 @@
use crate::attachment_manager::*;
use crate::attachment_manager::{AttachmentManager, AttachmentManagerStartupContext};
use crate::crypto::Crypto;
use crate::logging::*;
use crate::storage_manager::*;
use crate::network_manager::{NetworkManager, NetworkManagerStartupContext};
use crate::routing_table::RoutingTable;
use crate::rpc_processor::{RPCProcessor, RPCProcessorStartupContext};
use crate::storage_manager::StorageManager;
use crate::veilid_api::*;
use crate::veilid_config::*;
use crate::*;
pub type UpdateCallback = Arc<dyn Fn(VeilidUpdate) + Send + Sync>;
/// Internal services startup mechanism.
/// Ensures that everything is started up, and shut down in the right order
/// and provides an atomic state for if the system is properly operational.
struct StartupShutdownContext {
pub config: VeilidConfig,
pub update_callback: UpdateCallback,
pub event_bus: Option<EventBus>,
pub protected_store: Option<ProtectedStore>,
pub table_store: Option<TableStore>,
#[cfg(feature = "unstable-blockstore")]
pub block_store: Option<BlockStore>,
pub crypto: Option<Crypto>,
pub attachment_manager: Option<AttachmentManager>,
pub storage_manager: Option<StorageManager>,
}
impl StartupShutdownContext {
pub fn new_empty(config: VeilidConfig, update_callback: UpdateCallback) -> Self {
Self {
config,
update_callback,
event_bus: None,
protected_store: None,
table_store: None,
#[cfg(feature = "unstable-blockstore")]
block_store: None,
crypto: None,
attachment_manager: None,
storage_manager: None,
}
}
#[allow(clippy::too_many_arguments)]
pub fn new_full(
config: VeilidConfig,
update_callback: UpdateCallback,
event_bus: EventBus,
protected_store: ProtectedStore,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
attachment_manager: AttachmentManager,
storage_manager: StorageManager,
) -> Self {
Self {
config,
update_callback,
event_bus: Some(event_bus),
protected_store: Some(protected_store),
table_store: Some(table_store),
#[cfg(feature = "unstable-blockstore")]
block_store: Some(block_store),
crypto: Some(crypto),
attachment_manager: Some(attachment_manager),
storage_manager: Some(storage_manager),
}
}
#[instrument(level = "trace", target = "core_context", err, skip_all)]
pub async fn startup(&mut self) -> EyreResult<()> {
info!("Veilid API starting up");
info!("init api tracing");
let (program_name, namespace) = {
let config = self.config.get();
(config.program_name.clone(), config.namespace.clone())
};
ApiTracingLayer::add_callback(program_name, namespace, self.update_callback.clone())
.await?;
// Add the event bus
let event_bus = EventBus::new();
if let Err(e) = event_bus.startup().await {
error!("failed to start up event bus: {}", e);
self.shutdown().await;
return Err(e.into());
}
self.event_bus = Some(event_bus.clone());
// Set up protected store
let protected_store = ProtectedStore::new(event_bus.clone(), self.config.clone());
if let Err(e) = protected_store.init().await {
error!("failed to init protected store: {}", e);
self.shutdown().await;
return Err(e);
}
self.protected_store = Some(protected_store.clone());
// Set up tablestore and crypto system
let table_store = TableStore::new(
event_bus.clone(),
self.config.clone(),
protected_store.clone(),
);
let crypto = Crypto::new(event_bus.clone(), self.config.clone(), table_store.clone());
table_store.set_crypto(crypto.clone());
// Initialize table store first, so crypto code can load caches
// Tablestore can use crypto during init, just not any cached operations or things
// that require flushing back to the tablestore
if let Err(e) = table_store.init().await {
error!("failed to init table store: {}", e);
self.shutdown().await;
return Err(e);
}
self.table_store = Some(table_store.clone());
// Set up crypto
if let Err(e) = crypto.init().await {
error!("failed to init crypto: {}", e);
self.shutdown().await;
return Err(e);
}
self.crypto = Some(crypto.clone());
// Set up block store
#[cfg(feature = "unstable-blockstore")]
{
let block_store = BlockStore::new(event_bus.clone(), self.config.clone());
if let Err(e) = block_store.init().await {
error!("failed to init block store: {}", e);
self.shutdown().await;
return Err(e);
}
self.block_store = Some(block_store.clone());
}
// Set up storage manager
let update_callback = self.update_callback.clone();
let storage_manager = StorageManager::new(
event_bus.clone(),
self.config.clone(),
self.crypto.clone().unwrap(),
self.table_store.clone().unwrap(),
#[cfg(feature = "unstable-blockstore")]
self.block_store.clone().unwrap(),
);
if let Err(e) = storage_manager.init(update_callback).await {
error!("failed to init storage manager: {}", e);
self.shutdown().await;
return Err(e);
}
self.storage_manager = Some(storage_manager.clone());
// Set up attachment manager
let update_callback = self.update_callback.clone();
let attachment_manager = AttachmentManager::new(
event_bus.clone(),
self.config.clone(),
storage_manager,
table_store,
#[cfg(feature = "unstable-blockstore")]
block_store,
crypto,
);
if let Err(e) = attachment_manager.init(update_callback).await {
error!("failed to init attachment manager: {}", e);
self.shutdown().await;
return Err(e);
}
self.attachment_manager = Some(attachment_manager);
info!("Veilid API startup complete");
Ok(())
}
#[instrument(level = "trace", target = "core_context", skip_all)]
pub async fn shutdown(&mut self) {
info!("Veilid API shutting down");
if let Some(attachment_manager) = &mut self.attachment_manager {
attachment_manager.terminate().await;
}
if let Some(storage_manager) = &mut self.storage_manager {
storage_manager.terminate().await;
}
#[cfg(feature = "unstable-blockstore")]
if let Some(block_store) = &mut self.block_store {
block_store.terminate().await;
}
if let Some(crypto) = &mut self.crypto {
crypto.terminate().await;
}
if let Some(table_store) = &mut self.table_store {
table_store.terminate().await;
}
if let Some(protected_store) = &mut self.protected_store {
protected_store.terminate().await;
}
if let Some(event_bus) = &mut self.event_bus {
event_bus.shutdown().await;
}
info!("Veilid API shutdown complete");
// api logger terminate is idempotent
let (program_name, namespace) = {
let config = self.config.get();
(config.program_name.clone(), config.namespace.clone())
};
if let Err(e) = ApiTracingLayer::remove_callback(program_name, namespace).await {
error!("Error removing callback from ApiTracingLayer: {}", e);
}
// send final shutdown update
(self.update_callback)(VeilidUpdate::Shutdown);
}
}
type InitKey = (String, String);
/////////////////////////////////////////////////////////////////////////////
pub struct VeilidCoreContext {
pub config: VeilidConfig,
pub update_callback: UpdateCallback,
// Event bus
pub event_bus: EventBus,
// Services
pub storage_manager: StorageManager,
pub protected_store: ProtectedStore,
pub table_store: TableStore,
#[cfg(feature = "unstable-blockstore")]
pub block_store: BlockStore,
pub crypto: Crypto,
pub attachment_manager: AttachmentManager,
#[derive(Clone, Debug)]
pub(crate) struct VeilidCoreContext {
registry: VeilidComponentRegistry,
}
impl_veilid_component_registry_accessor!(VeilidCoreContext);
impl VeilidCoreContext {
#[instrument(level = "trace", target = "core_context", err, skip_all)]
async fn new_with_config_callback(
@ -244,10 +28,9 @@ impl VeilidCoreContext {
config_callback: ConfigCallback,
) -> VeilidAPIResult<VeilidCoreContext> {
// Set up config from callback
let mut config = VeilidConfig::new();
config.setup(config_callback, update_callback.clone())?;
let config = VeilidConfig::new_from_callback(config_callback, update_callback)?;
Self::new_common(update_callback, config).await
Self::new_common(config).await
}
#[instrument(level = "trace", target = "core_context", err, skip_all)]
@ -256,16 +39,12 @@ impl VeilidCoreContext {
config_inner: VeilidConfigInner,
) -> VeilidAPIResult<VeilidCoreContext> {
// Set up config from json
let mut config = VeilidConfig::new();
config.setup_from_config(config_inner, update_callback.clone())?;
Self::new_common(update_callback, config).await
let config = VeilidConfig::new_from_config(config_inner, update_callback);
Self::new_common(config).await
}
#[instrument(level = "trace", target = "core_context", err, skip_all)]
async fn new_common(
update_callback: UpdateCallback,
config: VeilidConfig,
) -> VeilidAPIResult<VeilidCoreContext> {
async fn new_common(config: VeilidConfig) -> VeilidAPIResult<VeilidCoreContext> {
cfg_if! {
if #[cfg(target_os = "android")] {
if !crate::intf::android::is_android_ready() {
@ -274,45 +53,134 @@ impl VeilidCoreContext {
}
}
let mut sc = StartupShutdownContext::new_empty(config.clone(), update_callback);
sc.startup().await.map_err(VeilidAPIError::generic)?;
info!("Veilid API starting up");
Ok(VeilidCoreContext {
config: sc.config,
update_callback: sc.update_callback,
event_bus: sc.event_bus.unwrap(),
storage_manager: sc.storage_manager.unwrap(),
protected_store: sc.protected_store.unwrap(),
table_store: sc.table_store.unwrap(),
#[cfg(feature = "unstable-blockstore")]
block_store: sc.block_store.unwrap(),
crypto: sc.crypto.unwrap(),
attachment_manager: sc.attachment_manager.unwrap(),
})
let (program_name, namespace, update_callback) = {
let cfginner = config.get();
(
cfginner.program_name.clone(),
cfginner.namespace.clone(),
config.update_callback(),
)
};
ApiTracingLayer::add_callback(program_name, namespace, update_callback.clone()).await?;
// Create component registry
let registry = VeilidComponentRegistry::new(config);
// Register all components
registry.register(ProtectedStore::new);
registry.register(Crypto::new);
registry.register(TableStore::new);
#[cfg(feature = "unstable-blockstore")]
registry.register(BlockStore::new);
registry.register(StorageManager::new);
registry.register(RoutingTable::new);
registry
.register_with_context(NetworkManager::new, NetworkManagerStartupContext::default());
registry.register_with_context(RPCProcessor::new, RPCProcessorStartupContext::default());
registry.register_with_context(
AttachmentManager::new,
AttachmentManagerStartupContext::default(),
);
// Run initialization
// This should make the majority of subsystems functional
registry.init().await.map_err(VeilidAPIError::internal)?;
// Run post-initialization
// This should resolve any inter-subsystem dependencies
// required for background processes that utilize multiple subsystems
// Background processes also often require registry lookup of the
// current subsystem, which is not available until after init succeeds
if let Err(e) = registry.post_init().await {
registry.terminate().await;
return Err(VeilidAPIError::internal(e));
}
info!("Veilid API startup complete");
Ok(Self { registry })
}
#[instrument(level = "trace", target = "core_context", skip_all)]
async fn shutdown(self) {
let mut sc = StartupShutdownContext::new_full(
self.config.clone(),
self.update_callback.clone(),
self.event_bus,
self.protected_store,
self.table_store,
#[cfg(feature = "unstable-blockstore")]
self.block_store,
self.crypto,
self.attachment_manager,
self.storage_manager,
);
sc.shutdown().await;
info!("Veilid API shutdown complete");
let (program_name, namespace, update_callback) = {
let config = self.registry.config();
let cfginner = config.get();
(
cfginner.program_name.clone(),
cfginner.namespace.clone(),
config.update_callback(),
)
};
// Run pre-termination
// This should shut down background processes that may require the existence of
// other subsystems that may not exist during final termination
self.registry.pre_terminate().await;
// Run termination
// This should finish any shutdown operations for the subsystems
self.registry.terminate().await;
if let Err(e) = ApiTracingLayer::remove_callback(program_name, namespace).await {
error!("Error removing callback from ApiTracingLayer: {}", e);
}
// send final shutdown update
update_callback(VeilidUpdate::Shutdown);
}
}
/////////////////////////////////////////////////////////////////////////////
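/// Convenience accessors for the registered core components, available to
/// anything that can reach the registry. Each accessor panics if the
/// component was never registered.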
pub trait RegisteredComponents {
fn protected_store<'a>(&self) -> VeilidComponentGuard<'a, ProtectedStore>;
fn crypto<'a>(&self) -> VeilidComponentGuard<'a, Crypto>;
fn table_store<'a>(&self) -> VeilidComponentGuard<'a, TableStore>;
fn storage_manager<'a>(&self) -> VeilidComponentGuard<'a, StorageManager>;
fn routing_table<'a>(&self) -> VeilidComponentGuard<'a, RoutingTable>;
fn network_manager<'a>(&self) -> VeilidComponentGuard<'a, NetworkManager>;
fn rpc_processor<'a>(&self) -> VeilidComponentGuard<'a, RPCProcessor>;
fn attachment_manager<'a>(&self) -> VeilidComponentGuard<'a, AttachmentManager>;
}
impl<T: VeilidComponentRegistryAccessor> RegisteredComponents for T {
fn protected_store<'a>(&self) -> VeilidComponentGuard<'a, ProtectedStore> {
self.registry().lookup::<ProtectedStore>().unwrap()
}
fn crypto<'a>(&self) -> VeilidComponentGuard<'a, Crypto> {
self.registry().lookup::<Crypto>().unwrap()
}
fn table_store<'a>(&self) -> VeilidComponentGuard<'a, TableStore> {
self.registry().lookup::<TableStore>().unwrap()
}
fn storage_manager<'a>(&self) -> VeilidComponentGuard<'a, StorageManager> {
self.registry().lookup::<StorageManager>().unwrap()
}
fn routing_table<'a>(&self) -> VeilidComponentGuard<'a, RoutingTable> {
self.registry().lookup::<RoutingTable>().unwrap()
}
fn network_manager<'a>(&self) -> VeilidComponentGuard<'a, NetworkManager> {
self.registry().lookup::<NetworkManager>().unwrap()
}
fn rpc_processor<'a>(&self) -> VeilidComponentGuard<'a, RPCProcessor> {
self.registry().lookup::<RPCProcessor>().unwrap()
}
fn attachment_manager<'a>(&self) -> VeilidComponentGuard<'a, AttachmentManager> {
self.registry().lookup::<AttachmentManager>().unwrap()
}
}
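Because of the blanket impl above, any type that exposes the registry through VeilidComponentRegistryAccessor picks up these typed accessors for free. A minimal caller-side sketch, assuming VeilidComponentGuard dereferences to the component it wraps (as its use elsewhere in this diff suggests):

```rust
// Sketch only: typed lookups on anything that implements RegisteredComponents.
fn best_crypto_kind<T: RegisteredComponents>(subsystem: &T) -> CryptoKind {
    let crypto = subsystem.crypto(); // VeilidComponentGuard<'_, Crypto>
    crypto.best().kind()
}
```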
/////////////////////////////////////////////////////////////////////////////
lazy_static::lazy_static! {
static ref INITIALIZED: AsyncMutex<HashSet<(String,String)>> = AsyncMutex::new(HashSet::new());
static ref INITIALIZED: Mutex<HashSet<InitKey>> = Mutex::new(HashSet::new());
static ref STARTUP_TABLE: AsyncTagLockTable<InitKey> = AsyncTagLockTable::new();
}
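The two statics split responsibilities: STARTUP_TABLE serializes startup/shutdown per (program_name, namespace) key, while INITIALIZED records which keys currently have a live API. A stand-alone sketch of the same double-guard pattern follows, substituting a map of per-key tokio mutexes for Veilid's AsyncTagLockTable (an assumption made for illustration only):

```rust
// Stand-alone sketch of the startup double guard; the real AsyncTagLockTable may differ.
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use tokio::sync::Mutex as AsyncMutex;

type InitKey = (String, String); // (program_name, namespace)

#[derive(Default)]
struct StartupGuards {
    // Serializes startup/shutdown per key.
    tags: Mutex<HashMap<InitKey, Arc<AsyncMutex<()>>>>,
    // Tracks which keys currently have a live API.
    initialized: Mutex<HashSet<InitKey>>,
}

impl StartupGuards {
    fn tag(&self, key: &InitKey) -> Arc<AsyncMutex<()>> {
        self.tags.lock().unwrap().entry(key.clone()).or_default().clone()
    }

    async fn startup(&self, key: InitKey) -> Result<(), &'static str> {
        let tag = self.tag(&key);
        let _serialized = tag.lock().await; // one startup/shutdown at a time per key
        if self.initialized.lock().unwrap().contains(&key) {
            return Err("already initialized");
        }
        // ... build the core context here ...
        self.initialized.lock().unwrap().insert(key);
        Ok(())
    }
}
```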
/// Initialize a Veilid node.
@ -345,9 +213,11 @@ pub async fn api_startup(
})?;
let init_key = (program_name, namespace);
// Only allow one startup/shutdown per program_name+namespace combination simultaneously
let _tag_guard = STARTUP_TABLE.lock_tag(init_key.clone()).await;
// See if we have an API started up already
let mut initialized_lock = INITIALIZED.lock().await;
if initialized_lock.contains(&init_key) {
if INITIALIZED.lock().contains(&init_key) {
apibail_already_initialized!();
}
@ -358,7 +228,8 @@ pub async fn api_startup(
// Return an API object around our context
let veilid_api = VeilidAPI::new(context);
initialized_lock.insert(init_key);
// Add to the initialized set
INITIALIZED.lock().insert(init_key);
Ok(veilid_api)
}
@ -403,12 +274,13 @@ pub async fn api_startup_config(
// Get the program_name and namespace we're starting up in
let program_name = config.program_name.clone();
let namespace = config.namespace.clone();
let init_key = (program_name, namespace);
// Only allow one startup/shutdown per program_name+namespace combination simultaneously
let _tag_guard = STARTUP_TABLE.lock_tag(init_key.clone()).await;
// See if we have an API started up already
let mut initialized_lock = INITIALIZED.lock().await;
if initialized_lock.contains(&init_key) {
if INITIALIZED.lock().contains(&init_key) {
apibail_already_initialized!();
}
@ -418,20 +290,32 @@ pub async fn api_startup_config(
// Return an API object around our context
let veilid_api = VeilidAPI::new(context);
initialized_lock.insert(init_key);
// Add to the initialized set
INITIALIZED.lock().insert(init_key);
Ok(veilid_api)
}
#[instrument(level = "trace", target = "core_context", skip_all)]
pub async fn api_shutdown(context: VeilidCoreContext) {
let mut initialized_lock = INITIALIZED.lock().await;
pub(crate) async fn api_shutdown(context: VeilidCoreContext) {
let init_key = {
let config = context.config.get();
(config.program_name.clone(), config.namespace.clone())
let registry = context.registry();
let config = registry.config();
let cfginner = config.get();
(cfginner.program_name.clone(), cfginner.namespace.clone())
};
// Only allow one startup/shutdown per program_name+namespace combination simultaneously
let _tag_guard = STARTUP_TABLE.lock_tag(init_key.clone()).await;
// See if we have an API started up already
if !INITIALIZED.lock().contains(&init_key) {
return;
}
// Shutdown the context
context.shutdown().await;
initialized_lock.remove(&init_key);
// Remove from the initialized set
INITIALIZED.lock().remove(&init_key);
}

View File

@ -1,11 +1,11 @@
use super::*;
const VEILID_DOMAIN_API: &[u8] = b"VEILID_API";
pub(crate) const VEILID_DOMAIN_API: &[u8] = b"VEILID_API";
pub trait CryptoSystem {
// Accessors
fn kind(&self) -> CryptoKind;
fn crypto(&self) -> Crypto;
fn crypto(&self) -> VeilidComponentGuard<'_, Crypto>;
// Cached Operations
fn cached_dh(&self, key: &PublicKey, secret: &SecretKey) -> VeilidAPIResult<SharedSecret>;

View File

@ -67,7 +67,7 @@ impl Envelope {
#[instrument(level = "trace", target = "envelope", skip_all)]
pub fn from_signed_data(
crypto: Crypto,
crypto: &Crypto,
data: &[u8],
network_key: &Option<SharedSecret>,
) -> VeilidAPIResult<Envelope> {
@ -193,7 +193,7 @@ impl Envelope {
#[instrument(level = "trace", target = "envelope", skip_all)]
pub fn decrypt_body(
&self,
crypto: Crypto,
crypto: &Crypto,
data: &[u8],
node_id_secret: &SecretKey,
network_key: &Option<SharedSecret>,
@ -226,7 +226,7 @@ impl Envelope {
#[instrument(level = "trace", target = "envelope", skip_all, err)]
pub fn to_encrypted_data(
&self,
crypto: Crypto,
crypto: &Crypto,
body: &[u8],
node_id_secret: &SecretKey,
network_key: &Option<SharedSecret>,

View File

@ -0,0 +1,276 @@
use super::*;
/// Guard to access a particular cryptosystem
pub struct CryptoSystemGuard<'a> {
crypto_system: Arc<dyn CryptoSystem + Send + Sync>,
_phantom: core::marker::PhantomData<&'a (dyn CryptoSystem + Send + Sync)>,
}
impl<'a> CryptoSystemGuard<'a> {
pub(super) fn new(crypto_system: Arc<dyn CryptoSystem + Send + Sync>) -> Self {
Self {
crypto_system,
_phantom: PhantomData,
}
}
pub fn as_async(self) -> AsyncCryptoSystemGuard<'a> {
AsyncCryptoSystemGuard { guard: self }
}
}
impl<'a> core::ops::Deref for CryptoSystemGuard<'a> {
type Target = dyn CryptoSystem + Send + Sync;
fn deref(&self) -> &Self::Target {
self.crypto_system.as_ref()
}
}
/// Async cryptosystem guard to help break up heavy blocking operations
pub struct AsyncCryptoSystemGuard<'a> {
guard: CryptoSystemGuard<'a>,
}
async fn yielding<R, T: FnOnce() -> R>(x: T) -> R {
let out = x();
sleep(0).await;
out
}
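The yielding helper runs a blocking closure to completion and then awaits a zero-length sleep, giving the executor a chance to schedule other tasks between heavy cryptographic operations. A hypothetical stand-alone equivalent using tokio, wrapping an arbitrary CPU-bound hash:

```rust
// Hypothetical stand-alone equivalent using tokio: run the closure inline,
// then yield so other tasks get a turn between heavy operations.
async fn yielding<R, F: FnOnce() -> R>(f: F) -> R {
    let out = f();
    tokio::task::yield_now().await;
    out
}

#[tokio::main]
async fn main() {
    // e.g. wrap a CPU-bound hash the same way AsyncCryptoSystemGuard wraps
    // its CryptoSystem calls
    let digest = yielding(|| blake3::hash(b"some large body")).await;
    println!("{digest}");
}
```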
impl<'a> AsyncCryptoSystemGuard<'a> {
// Accessors
pub fn kind(&self) -> CryptoKind {
self.guard.kind()
}
pub fn crypto(&self) -> VeilidComponentGuard<'_, Crypto> {
self.guard.crypto()
}
// Cached Operations
pub async fn cached_dh(
&self,
key: &PublicKey,
secret: &SecretKey,
) -> VeilidAPIResult<SharedSecret> {
yielding(|| self.guard.cached_dh(key, secret)).await
}
// Generation
pub async fn random_bytes(&self, len: u32) -> Vec<u8> {
yielding(|| self.guard.random_bytes(len)).await
}
pub fn default_salt_length(&self) -> u32 {
self.guard.default_salt_length()
}
pub async fn hash_password(&self, password: &[u8], salt: &[u8]) -> VeilidAPIResult<String> {
yielding(|| self.guard.hash_password(password, salt)).await
}
pub async fn verify_password(
&self,
password: &[u8],
password_hash: &str,
) -> VeilidAPIResult<bool> {
yielding(|| self.guard.verify_password(password, password_hash)).await
}
pub async fn derive_shared_secret(
&self,
password: &[u8],
salt: &[u8],
) -> VeilidAPIResult<SharedSecret> {
yielding(|| self.guard.derive_shared_secret(password, salt)).await
}
pub async fn random_nonce(&self) -> Nonce {
yielding(|| self.guard.random_nonce()).await
}
pub async fn random_shared_secret(&self) -> SharedSecret {
yielding(|| self.guard.random_shared_secret()).await
}
pub async fn compute_dh(
&self,
key: &PublicKey,
secret: &SecretKey,
) -> VeilidAPIResult<SharedSecret> {
yielding(|| self.guard.compute_dh(key, secret)).await
}
pub async fn generate_shared_secret(
&self,
key: &PublicKey,
secret: &SecretKey,
domain: &[u8],
) -> VeilidAPIResult<SharedSecret> {
let dh = self.compute_dh(key, secret).await?;
Ok(self
.generate_hash(&[&dh.bytes, domain, VEILID_DOMAIN_API].concat())
.await)
}
pub async fn generate_keypair(&self) -> KeyPair {
yielding(|| self.guard.generate_keypair()).await
}
pub async fn generate_hash(&self, data: &[u8]) -> HashDigest {
yielding(|| self.guard.generate_hash(data)).await
}
pub async fn generate_hash_reader(
&self,
reader: &mut dyn std::io::Read,
) -> VeilidAPIResult<HashDigest> {
yielding(|| self.guard.generate_hash_reader(reader)).await
}
// Validation
pub async fn validate_keypair(&self, key: &PublicKey, secret: &SecretKey) -> bool {
yielding(|| self.guard.validate_keypair(key, secret)).await
}
pub async fn validate_hash(&self, data: &[u8], hash: &HashDigest) -> bool {
yielding(|| self.guard.validate_hash(data, hash)).await
}
pub async fn validate_hash_reader(
&self,
reader: &mut dyn std::io::Read,
hash: &HashDigest,
) -> VeilidAPIResult<bool> {
yielding(|| self.guard.validate_hash_reader(reader, hash)).await
}
// Distance Metric
pub async fn distance(&self, key1: &CryptoKey, key2: &CryptoKey) -> CryptoKeyDistance {
yielding(|| self.guard.distance(key1, key2)).await
}
// Authentication
pub async fn sign(
&self,
key: &PublicKey,
secret: &SecretKey,
data: &[u8],
) -> VeilidAPIResult<Signature> {
yielding(|| self.guard.sign(key, secret, data)).await
}
pub async fn verify(
&self,
key: &PublicKey,
data: &[u8],
signature: &Signature,
) -> VeilidAPIResult<bool> {
yielding(|| self.guard.verify(key, data, signature)).await
}
// AEAD Encrypt/Decrypt
pub fn aead_overhead(&self) -> usize {
self.guard.aead_overhead()
}
pub async fn decrypt_in_place_aead(
&self,
body: &mut Vec<u8>,
nonce: &Nonce,
shared_secret: &SharedSecret,
associated_data: Option<&[u8]>,
) -> VeilidAPIResult<()> {
yielding(|| {
self.guard
.decrypt_in_place_aead(body, nonce, shared_secret, associated_data)
})
.await
}
pub async fn decrypt_aead(
&self,
body: &[u8],
nonce: &Nonce,
shared_secret: &SharedSecret,
associated_data: Option<&[u8]>,
) -> VeilidAPIResult<Vec<u8>> {
yielding(|| {
self.guard
.decrypt_aead(body, nonce, shared_secret, associated_data)
})
.await
}
pub async fn encrypt_in_place_aead(
&self,
body: &mut Vec<u8>,
nonce: &Nonce,
shared_secret: &SharedSecret,
associated_data: Option<&[u8]>,
) -> VeilidAPIResult<()> {
yielding(|| {
self.guard
.encrypt_in_place_aead(body, nonce, shared_secret, associated_data)
})
.await
}
pub async fn encrypt_aead(
&self,
body: &[u8],
nonce: &Nonce,
shared_secret: &SharedSecret,
associated_data: Option<&[u8]>,
) -> VeilidAPIResult<Vec<u8>> {
yielding(|| {
self.guard
.encrypt_aead(body, nonce, shared_secret, associated_data)
})
.await
}
// NoAuth Encrypt/Decrypt
pub async fn crypt_in_place_no_auth(
&self,
body: &mut [u8],
nonce: &[u8; NONCE_LENGTH],
shared_secret: &SharedSecret,
) {
yielding(|| {
self.guard
.crypt_in_place_no_auth(body, nonce, shared_secret)
})
.await
}
pub async fn crypt_b2b_no_auth(
&self,
in_buf: &[u8],
out_buf: &mut [u8],
nonce: &[u8; NONCE_LENGTH],
shared_secret: &SharedSecret,
) {
yielding(|| {
self.guard
.crypt_b2b_no_auth(in_buf, out_buf, nonce, shared_secret)
})
.await
}
pub async fn crypt_no_auth_aligned_8(
&self,
body: &[u8],
nonce: &[u8; NONCE_LENGTH],
shared_secret: &SharedSecret,
) -> Vec<u8> {
yielding(|| {
self.guard
.crypt_no_auth_aligned_8(body, nonce, shared_secret)
})
.await
}
pub async fn crypt_no_auth_unaligned(
&self,
body: &[u8],
nonce: &[u8; NONCE_LENGTH],
shared_secret: &SharedSecret,
) -> Vec<u8> {
yielding(|| {
self.guard
.crypt_no_auth_unaligned(body, nonce, shared_secret)
})
.await
}
}

View File

@ -1,6 +1,7 @@
mod blake3digest512;
mod dh_cache;
mod envelope;
mod guard;
mod receipt;
mod types;
@ -16,6 +17,7 @@ pub use blake3digest512::*;
pub use crypto_system::*;
pub use envelope::*;
pub use guard::*;
pub use receipt::*;
pub use types::*;
@ -29,9 +31,7 @@ use core::convert::TryInto;
use dh_cache::*;
use hashlink::linked_hash_map::Entry;
use hashlink::LruCache;
/// Handle to a particular cryptosystem
pub type CryptoSystemVersion = Arc<dyn CryptoSystem + Send + Sync>;
use std::marker::PhantomData;
cfg_if! {
if #[cfg(all(feature = "enable-crypto-none", feature = "enable-crypto-vld0"))] {
@ -72,23 +72,40 @@ pub fn best_envelope_version() -> EnvelopeVersion {
struct CryptoInner {
dh_cache: DHCache,
flush_future: Option<SendPinBoxFuture<()>>,
#[cfg(feature = "enable-crypto-vld0")]
crypto_vld0: Option<Arc<dyn CryptoSystem + Send + Sync>>,
#[cfg(feature = "enable-crypto-none")]
crypto_none: Option<Arc<dyn CryptoSystem + Send + Sync>>,
}
struct CryptoUnlockedInner {
_event_bus: EventBus,
config: VeilidConfig,
table_store: TableStore,
impl fmt::Debug for CryptoInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CryptoInner")
//.field("dh_cache", &self.dh_cache)
// .field("flush_future", &self.flush_future)
// .field("crypto_vld0", &self.crypto_vld0)
// .field("crypto_none", &self.crypto_none)
.finish()
}
}
/// Crypto factory implementation
#[derive(Clone)]
pub struct Crypto {
unlocked_inner: Arc<CryptoUnlockedInner>,
inner: Arc<Mutex<CryptoInner>>,
registry: VeilidComponentRegistry,
inner: Mutex<CryptoInner>,
#[cfg(feature = "enable-crypto-vld0")]
crypto_vld0: Arc<dyn CryptoSystem + Send + Sync>,
#[cfg(feature = "enable-crypto-none")]
crypto_none: Arc<dyn CryptoSystem + Send + Sync>,
}
impl_veilid_component!(Crypto);
impl fmt::Debug for Crypto {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Crypto")
//.field("registry", &self.registry)
.field("inner", &self.inner)
// .field("crypto_vld0", &self.crypto_vld0)
// .field("crypto_none", &self.crypto_none)
.finish()
}
}
impl Crypto {
@ -96,63 +113,43 @@ impl Crypto {
CryptoInner {
dh_cache: DHCache::new(DH_CACHE_SIZE),
flush_future: None,
}
}
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self {
registry: registry.clone(),
inner: Mutex::new(Self::new_inner()),
#[cfg(feature = "enable-crypto-vld0")]
crypto_vld0: None,
crypto_vld0: Arc::new(vld0::CryptoSystemVLD0::new(registry.clone())),
#[cfg(feature = "enable-crypto-none")]
crypto_none: None,
crypto_none: Arc::new(none::CryptoSystemNONE::new(registry.clone())),
}
}
pub fn new(event_bus: EventBus, config: VeilidConfig, table_store: TableStore) -> Self {
let out = Self {
unlocked_inner: Arc::new(CryptoUnlockedInner {
_event_bus: event_bus,
config,
table_store,
}),
inner: Arc::new(Mutex::new(Self::new_inner())),
};
#[cfg(feature = "enable-crypto-vld0")]
{
out.inner.lock().crypto_vld0 = Some(Arc::new(vld0::CryptoSystemVLD0::new(out.clone())));
}
#[cfg(feature = "enable-crypto-none")]
{
out.inner.lock().crypto_none = Some(Arc::new(none::CryptoSystemNONE::new(out.clone())));
}
out
}
pub fn config(&self) -> VeilidConfig {
self.unlocked_inner.config.clone()
}
#[instrument(level = "trace", target = "crypto", skip_all, err)]
pub async fn init(&self) -> EyreResult<()> {
let table_store = self.unlocked_inner.table_store.clone();
async fn init_async(&self) -> EyreResult<()> {
// Nothing to initialize at this time
Ok(())
}
    // Setup called by the table store after it gets initialized
#[instrument(level = "trace", target = "crypto", skip_all, err)]
pub(crate) async fn table_store_setup(&self, table_store: &TableStore) -> EyreResult<()> {
// Init node id from config
if let Err(e) = self
.unlocked_inner
.config
.init_node_ids(self.clone(), table_store.clone())
.await
{
if let Err(e) = self.setup_node_ids(table_store).await {
return Err(e).wrap_err("init node id failed");
}
// make local copy of node id for easy access
let mut cache_validity_key: Vec<u8> = Vec::new();
{
let c = self.unlocked_inner.config.get();
self.config().with(|c| {
for ck in VALID_CRYPTO_KINDS {
if let Some(nid) = c.network.routing_table.node_id.get(ck) {
cache_validity_key.append(&mut nid.value.bytes.to_vec());
}
}
};
});
// load caches if they are valid for this node id
let mut db = table_store
@ -175,13 +172,17 @@ impl Crypto {
db.store(0, b"cache_validity_key", &cache_validity_key)
.await?;
}
Ok(())
}
#[instrument(level = "trace", target = "crypto", skip_all, err)]
async fn post_init_async(&self) -> EyreResult<()> {
// Schedule flushing
let this = self.clone();
let registry = self.registry();
let flush_future = interval("crypto flush", 60000, move || {
let this = this.clone();
let crypto = registry.crypto();
async move {
if let Err(e) = this.flush().await {
if let Err(e) = crypto.flush().await {
warn!("flush failed: {}", e);
}
}
@ -197,16 +198,12 @@ impl Crypto {
cache_to_bytes(&inner.dh_cache)
};
let db = self
.unlocked_inner
.table_store
.open("crypto_caches", 1)
.await?;
let db = self.table_store().open("crypto_caches", 1).await?;
db.store(0, b"dh_cache", &cache_bytes).await?;
Ok(())
}
pub async fn terminate(&self) {
async fn pre_terminate_async(&self) {
let flush_future = self.inner.lock().flush_future.take();
if let Some(f) = flush_future {
f.await;
@ -222,23 +219,36 @@ impl Crypto {
};
}
async fn terminate_async(&self) {
// Nothing to terminate at this time
}
/// Factory method to get a specific crypto version
pub fn get(&self, kind: CryptoKind) -> Option<CryptoSystemVersion> {
let inner = self.inner.lock();
pub fn get(&self, kind: CryptoKind) -> Option<CryptoSystemGuard<'_>> {
match kind {
#[cfg(feature = "enable-crypto-vld0")]
CRYPTO_KIND_VLD0 => Some(inner.crypto_vld0.clone().unwrap()),
CRYPTO_KIND_VLD0 => Some(CryptoSystemGuard::new(self.crypto_vld0.clone())),
#[cfg(feature = "enable-crypto-none")]
CRYPTO_KIND_NONE => Some(inner.crypto_none.clone().unwrap()),
CRYPTO_KIND_NONE => Some(CryptoSystemGuard::new(self.crypto_none.clone())),
_ => None,
}
}
/// Factory method to get a specific crypto version for async use
pub fn get_async(&self, kind: CryptoKind) -> Option<AsyncCryptoSystemGuard<'_>> {
self.get(kind).map(|x| x.as_async())
}
// Factory method to get the best crypto version
pub fn best(&self) -> CryptoSystemVersion {
pub fn best(&self) -> CryptoSystemGuard<'_> {
self.get(best_crypto_kind()).unwrap()
}
// Factory method to get the best crypto version for async use
pub fn best_async(&self) -> AsyncCryptoSystemGuard<'_> {
self.get_async(best_crypto_kind()).unwrap()
}
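Putting the factory methods and the async guard together, a caller-side sketch looks like the following; it assumes a &Crypto handle already obtained from the component registry and only uses calls shown in this diff.

```rust
// Minimal sketch, assuming `crypto: &Crypto` came from the registry.
async fn aead_roundtrip(crypto: &Crypto) -> VeilidAPIResult<()> {
    let vcrypto = crypto.best_async();
    let nonce = vcrypto.random_nonce().await;
    let secret = vcrypto.random_shared_secret().await;
    let cipher = vcrypto.encrypt_aead(b"hello", &nonce, &secret, None).await?;
    let plain = vcrypto.decrypt_aead(&cipher, &nonce, &secret, None).await?;
    assert_eq!(plain, b"hello".to_vec());
    Ok(())
}
```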
/// Signature set verification
    /// Returns Some() with the set of signature cryptokinds that are supported and validate
    /// Returns None if any supported cryptokind fails to validate
@ -331,4 +341,120 @@ impl Crypto {
}
Ok(())
}
#[cfg(not(test))]
async fn setup_node_id(
&self,
vcrypto: AsyncCryptoSystemGuard<'_>,
table_store: &TableStore,
) -> VeilidAPIResult<(TypedKey, TypedSecret)> {
let config = self.config();
let ck = vcrypto.kind();
let (mut node_id, mut node_id_secret) = config.with(|c| {
(
c.network.routing_table.node_id.get(ck),
c.network.routing_table.node_id_secret.get(ck),
)
});
// See if node id was previously stored in the table store
let config_table = table_store.open("__veilid_config", 1).await?;
let table_key_node_id = format!("node_id_{}", ck);
let table_key_node_id_secret = format!("node_id_secret_{}", ck);
if node_id.is_none() {
log_crypto!(debug "pulling {} from storage", table_key_node_id);
if let Ok(Some(stored_node_id)) = config_table
.load_json::<TypedKey>(0, table_key_node_id.as_bytes())
.await
{
log_crypto!(debug "{} found in storage", table_key_node_id);
node_id = Some(stored_node_id);
} else {
log_crypto!(debug "{} not found in storage", table_key_node_id);
}
}
        // See if node id secret was previously stored in the table store
if node_id_secret.is_none() {
log_crypto!(debug "pulling {} from storage", table_key_node_id_secret);
if let Ok(Some(stored_node_id_secret)) = config_table
.load_json::<TypedSecret>(0, table_key_node_id_secret.as_bytes())
.await
{
log_crypto!(debug "{} found in storage", table_key_node_id_secret);
node_id_secret = Some(stored_node_id_secret);
} else {
log_crypto!(debug "{} not found in storage", table_key_node_id_secret);
}
}
// If we have a node id from storage, check it
let (node_id, node_id_secret) =
if let (Some(node_id), Some(node_id_secret)) = (node_id, node_id_secret) {
// Validate node id
if !vcrypto
.validate_keypair(&node_id.value, &node_id_secret.value)
.await
{
apibail_generic!(format!(
"node_id_secret_{} and node_id_key_{} don't match",
ck, ck
));
}
(node_id, node_id_secret)
} else {
// If we still don't have a valid node id, generate one
log_crypto!(debug "generating new node_id_{}", ck);
let kp = vcrypto.generate_keypair().await;
(TypedKey::new(ck, kp.key), TypedSecret::new(ck, kp.secret))
};
info!("Node Id: {}", node_id);
// Save the node id / secret in storage
config_table
.store_json(0, table_key_node_id.as_bytes(), &node_id)
.await?;
config_table
.store_json(0, table_key_node_id_secret.as_bytes(), &node_id_secret)
.await?;
Ok((node_id, node_id_secret))
}
/// Get the node id from config if one is specified.
/// Must be done -after- protected store is initialized, during table store init
#[cfg_attr(test, allow(unused_variables))]
async fn setup_node_ids(&self, table_store: &TableStore) -> VeilidAPIResult<()> {
let mut out_node_id = TypedKeyGroup::new();
let mut out_node_id_secret = TypedSecretGroup::new();
for ck in VALID_CRYPTO_KINDS {
let vcrypto = self
.get_async(ck)
.expect("Valid crypto kind is not actually valid.");
#[cfg(test)]
let (node_id, node_id_secret) = {
let kp = vcrypto.generate_keypair().await;
(TypedKey::new(ck, kp.key), TypedSecret::new(ck, kp.secret))
};
#[cfg(not(test))]
let (node_id, node_id_secret) = self.setup_node_id(vcrypto, table_store).await?;
// Save for config
out_node_id.add(node_id);
out_node_id_secret.add(node_id_secret);
}
// Commit back to config
self.config().try_with_mut(|c| {
c.network.routing_table.node_id = out_node_id;
c.network.routing_table.node_id_secret = out_node_id_secret;
Ok(())
})?;
Ok(())
}
}
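For reference, the per-kind persistence keys used by setup_node_id follow a simple naming scheme; a hypothetical helper that mirrors it:

```rust
// Hypothetical helper mirroring the per-kind table keys used by setup_node_id above.
fn node_id_table_keys(kind: CryptoKind) -> (String, String) {
    (
        format!("node_id_{}", kind),
        format!("node_id_secret_{}", kind),
    )
}
```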

View File

@ -49,14 +49,13 @@ fn is_bytes_eq_32(a: &[u8], v: u8) -> bool {
}
/// None CryptoSystem
#[derive(Clone)]
pub struct CryptoSystemNONE {
crypto: Crypto,
registry: VeilidComponentRegistry,
}
impl CryptoSystemNONE {
pub fn new(crypto: Crypto) -> Self {
Self { crypto }
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { registry }
}
}
@ -66,13 +65,13 @@ impl CryptoSystem for CryptoSystemNONE {
CRYPTO_KIND_NONE
}
fn crypto(&self) -> Crypto {
self.crypto.clone()
fn crypto(&self) -> VeilidComponentGuard<'_, Crypto> {
self.registry().lookup::<Crypto>().unwrap()
}
// Cached Operations
fn cached_dh(&self, key: &PublicKey, secret: &SecretKey) -> VeilidAPIResult<SharedSecret> {
self.crypto
self.crypto()
.cached_dh_internal::<CryptoSystemNONE>(self, key, secret)
}

View File

@ -68,7 +68,7 @@ impl Receipt {
}
#[instrument(level = "trace", target = "receipt", skip_all, err)]
pub fn from_signed_data(crypto: Crypto, data: &[u8]) -> VeilidAPIResult<Receipt> {
pub fn from_signed_data(crypto: &Crypto, data: &[u8]) -> VeilidAPIResult<Receipt> {
// Ensure we are at least the length of the envelope
if data.len() < MIN_RECEIPT_SIZE {
apibail_parse_error!("receipt too small", data.len());
@ -157,7 +157,7 @@ impl Receipt {
}
#[instrument(level = "trace", target = "receipt", skip_all, err)]
pub fn to_signed_data(&self, crypto: Crypto, secret: &SecretKey) -> VeilidAPIResult<Vec<u8>> {
pub fn to_signed_data(&self, crypto: &Crypto, secret: &SecretKey) -> VeilidAPIResult<Vec<u8>> {
// Ensure extra data isn't too long
let receipt_size: usize = self.extra_data.len() + MIN_RECEIPT_SIZE;
if receipt_size > MAX_RECEIPT_SIZE {

View File

@ -2,20 +2,20 @@ use super::*;
static LOREM_IPSUM:&[u8] = b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. ";
pub async fn test_aead(vcrypto: CryptoSystemVersion) {
pub async fn test_aead(vcrypto: &AsyncCryptoSystemGuard<'_>) {
trace!("test_aead");
let n1 = vcrypto.random_nonce();
let n1 = vcrypto.random_nonce().await;
let n2 = loop {
let n = vcrypto.random_nonce();
let n = vcrypto.random_nonce().await;
if n != n1 {
break n;
}
};
let ss1 = vcrypto.random_shared_secret();
let ss1 = vcrypto.random_shared_secret().await;
let ss2 = loop {
let ss = vcrypto.random_shared_secret();
let ss = vcrypto.random_shared_secret().await;
if ss != ss1 {
break ss;
}
@ -27,6 +27,7 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!(
vcrypto
.encrypt_in_place_aead(&mut body, &n1, &ss1, None)
.await
.is_ok(),
"encrypt should succeed"
);
@ -41,6 +42,7 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!(
vcrypto
.decrypt_in_place_aead(&mut body, &n1, &ss1, None)
.await
.is_ok(),
"decrypt should succeed"
);
@ -49,6 +51,7 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!(
vcrypto
.decrypt_in_place_aead(&mut body3, &n2, &ss1, None)
.await
.is_err(),
"decrypt with wrong nonce should fail"
);
@ -57,6 +60,7 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!(
vcrypto
.decrypt_in_place_aead(&mut body4, &n1, &ss2, None)
.await
.is_err(),
"decrypt with wrong secret should fail"
);
@ -65,37 +69,47 @@ pub async fn test_aead(vcrypto: CryptoSystemVersion) {
assert!(
vcrypto
.decrypt_in_place_aead(&mut body5, &n1, &ss2, Some(b"foobar"))
.await
.is_err(),
"decrypt with wrong associated data should fail"
);
assert_ne!(body5, body, "failure changes data");
assert!(
vcrypto.decrypt_aead(LOREM_IPSUM, &n1, &ss1, None).is_err(),
vcrypto
.decrypt_aead(LOREM_IPSUM, &n1, &ss1, None)
.await
.is_err(),
"should fail authentication"
);
let body5 = vcrypto.encrypt_aead(LOREM_IPSUM, &n1, &ss1, None).unwrap();
let body6 = vcrypto.decrypt_aead(&body5, &n1, &ss1, None).unwrap();
let body7 = vcrypto.encrypt_aead(LOREM_IPSUM, &n1, &ss1, None).unwrap();
let body5 = vcrypto
.encrypt_aead(LOREM_IPSUM, &n1, &ss1, None)
.await
.unwrap();
let body6 = vcrypto.decrypt_aead(&body5, &n1, &ss1, None).await.unwrap();
let body7 = vcrypto
.encrypt_aead(LOREM_IPSUM, &n1, &ss1, None)
.await
.unwrap();
assert_eq!(body6, LOREM_IPSUM);
assert_eq!(body5, body7);
}
pub async fn test_no_auth(vcrypto: CryptoSystemVersion) {
pub async fn test_no_auth(vcrypto: &AsyncCryptoSystemGuard<'_>) {
trace!("test_no_auth");
let n1 = vcrypto.random_nonce();
let n1 = vcrypto.random_nonce().await;
let n2 = loop {
let n = vcrypto.random_nonce();
let n = vcrypto.random_nonce().await;
if n != n1 {
break n;
}
};
let ss1 = vcrypto.random_shared_secret();
let ss1 = vcrypto.random_shared_secret().await;
let ss2 = loop {
let ss = vcrypto.random_shared_secret();
let ss = vcrypto.random_shared_secret().await;
if ss != ss1 {
break ss;
}
@ -104,7 +118,7 @@ pub async fn test_no_auth(vcrypto: CryptoSystemVersion) {
let mut body = LOREM_IPSUM.to_vec();
let body2 = body.clone();
let size_before_encrypt = body.len();
vcrypto.crypt_in_place_no_auth(&mut body, &n1, &ss1);
vcrypto.crypt_in_place_no_auth(&mut body, &n1, &ss1).await;
let size_after_encrypt = body.len();
assert_eq!(
@ -114,49 +128,69 @@ pub async fn test_no_auth(vcrypto: CryptoSystemVersion) {
let mut body3 = body.clone();
let mut body4 = body.clone();
vcrypto.crypt_in_place_no_auth(&mut body, &n1, &ss1);
vcrypto.crypt_in_place_no_auth(&mut body, &n1, &ss1).await;
assert_eq!(body, body2, "result after decrypt should be the same");
vcrypto.crypt_in_place_no_auth(&mut body3, &n2, &ss1);
vcrypto.crypt_in_place_no_auth(&mut body3, &n2, &ss1).await;
assert_ne!(body3, body, "decrypt should not be equal with wrong nonce");
vcrypto.crypt_in_place_no_auth(&mut body4, &n1, &ss2);
vcrypto.crypt_in_place_no_auth(&mut body4, &n1, &ss2).await;
assert_ne!(body4, body, "decrypt should not be equal with wrong secret");
let body5 = vcrypto.crypt_no_auth_unaligned(LOREM_IPSUM, &n1, &ss1);
let body6 = vcrypto.crypt_no_auth_unaligned(&body5, &n1, &ss1);
let body7 = vcrypto.crypt_no_auth_unaligned(LOREM_IPSUM, &n1, &ss1);
let body5 = vcrypto
.crypt_no_auth_unaligned(LOREM_IPSUM, &n1, &ss1)
.await;
let body6 = vcrypto.crypt_no_auth_unaligned(&body5, &n1, &ss1).await;
let body7 = vcrypto
.crypt_no_auth_unaligned(LOREM_IPSUM, &n1, &ss1)
.await;
assert_eq!(body6, LOREM_IPSUM);
assert_eq!(body5, body7);
let body5 = vcrypto.crypt_no_auth_aligned_8(LOREM_IPSUM, &n1, &ss1);
let body6 = vcrypto.crypt_no_auth_aligned_8(&body5, &n1, &ss1);
let body7 = vcrypto.crypt_no_auth_aligned_8(LOREM_IPSUM, &n1, &ss1);
let body5 = vcrypto
.crypt_no_auth_aligned_8(LOREM_IPSUM, &n1, &ss1)
.await;
let body6 = vcrypto.crypt_no_auth_aligned_8(&body5, &n1, &ss1).await;
let body7 = vcrypto
.crypt_no_auth_aligned_8(LOREM_IPSUM, &n1, &ss1)
.await;
assert_eq!(body6, LOREM_IPSUM);
assert_eq!(body5, body7);
}
pub async fn test_dh(vcrypto: CryptoSystemVersion) {
pub async fn test_dh(vcrypto: &AsyncCryptoSystemGuard<'_>) {
trace!("test_dh");
let (dht_key, dht_key_secret) = vcrypto.generate_keypair().into_split();
assert!(vcrypto.validate_keypair(&dht_key, &dht_key_secret));
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split();
assert!(vcrypto.validate_keypair(&dht_key2, &dht_key_secret2));
let (dht_key, dht_key_secret) = vcrypto.generate_keypair().await.into_split();
assert!(vcrypto.validate_keypair(&dht_key, &dht_key_secret).await);
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
assert!(vcrypto.validate_keypair(&dht_key2, &dht_key_secret2).await);
let r1 = vcrypto.compute_dh(&dht_key, &dht_key_secret2).unwrap();
let r2 = vcrypto.compute_dh(&dht_key2, &dht_key_secret).unwrap();
let r3 = vcrypto.compute_dh(&dht_key, &dht_key_secret2).unwrap();
let r4 = vcrypto.compute_dh(&dht_key2, &dht_key_secret).unwrap();
let r1 = vcrypto
.compute_dh(&dht_key, &dht_key_secret2)
.await
.unwrap();
let r2 = vcrypto
.compute_dh(&dht_key2, &dht_key_secret)
.await
.unwrap();
let r3 = vcrypto
.compute_dh(&dht_key, &dht_key_secret2)
.await
.unwrap();
let r4 = vcrypto
.compute_dh(&dht_key2, &dht_key_secret)
.await
.unwrap();
assert_eq!(r1, r2);
assert_eq!(r3, r4);
assert_eq!(r2, r3);
trace!("dh: {:?}", r1);
// test cache
let r5 = vcrypto.cached_dh(&dht_key, &dht_key_secret2).unwrap();
let r6 = vcrypto.cached_dh(&dht_key2, &dht_key_secret).unwrap();
let r7 = vcrypto.cached_dh(&dht_key, &dht_key_secret2).unwrap();
let r8 = vcrypto.cached_dh(&dht_key2, &dht_key_secret).unwrap();
let r5 = vcrypto.cached_dh(&dht_key, &dht_key_secret2).await.unwrap();
let r6 = vcrypto.cached_dh(&dht_key2, &dht_key_secret).await.unwrap();
let r7 = vcrypto.cached_dh(&dht_key, &dht_key_secret2).await.unwrap();
let r8 = vcrypto.cached_dh(&dht_key2, &dht_key_secret).await.unwrap();
assert_eq!(r1, r5);
assert_eq!(r2, r6);
assert_eq!(r3, r7);
@ -164,63 +198,67 @@ pub async fn test_dh(vcrypto: CryptoSystemVersion) {
trace!("cached_dh: {:?}", r5);
}
pub async fn test_generation(vcrypto: CryptoSystemVersion) {
let b1 = vcrypto.random_bytes(32);
let b2 = vcrypto.random_bytes(32);
pub async fn test_generation(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let b1 = vcrypto.random_bytes(32).await;
let b2 = vcrypto.random_bytes(32).await;
assert_ne!(b1, b2);
assert_eq!(b1.len(), 32);
assert_eq!(b2.len(), 32);
let b3 = vcrypto.random_bytes(0);
let b4 = vcrypto.random_bytes(0);
let b3 = vcrypto.random_bytes(0).await;
let b4 = vcrypto.random_bytes(0).await;
assert_eq!(b3, b4);
assert_eq!(b3.len(), 0);
assert_ne!(vcrypto.default_salt_length(), 0);
let pstr1 = vcrypto.hash_password(b"abc123", b"qwerasdf").unwrap();
let pstr2 = vcrypto.hash_password(b"abc123", b"qwerasdf").unwrap();
let pstr1 = vcrypto.hash_password(b"abc123", b"qwerasdf").await.unwrap();
let pstr2 = vcrypto.hash_password(b"abc123", b"qwerasdf").await.unwrap();
assert_eq!(pstr1, pstr2);
let pstr3 = vcrypto.hash_password(b"abc123", b"qwerasdg").unwrap();
let pstr3 = vcrypto.hash_password(b"abc123", b"qwerasdg").await.unwrap();
assert_ne!(pstr1, pstr3);
let pstr4 = vcrypto.hash_password(b"abc124", b"qwerasdf").unwrap();
let pstr4 = vcrypto.hash_password(b"abc124", b"qwerasdf").await.unwrap();
assert_ne!(pstr1, pstr4);
let pstr5 = vcrypto.hash_password(b"abc124", b"qwerasdg").unwrap();
let pstr5 = vcrypto.hash_password(b"abc124", b"qwerasdg").await.unwrap();
assert_ne!(pstr3, pstr5);
vcrypto
.hash_password(b"abc123", b"qwe")
.await
.expect_err("should reject short salt");
vcrypto
.hash_password(
b"abc123",
b"qwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerz",
)
.await
.expect_err("should reject long salt");
assert!(vcrypto.verify_password(b"abc123", &pstr1).unwrap());
assert!(vcrypto.verify_password(b"abc123", &pstr2).unwrap());
assert!(vcrypto.verify_password(b"abc123", &pstr3).unwrap());
assert!(!vcrypto.verify_password(b"abc123", &pstr4).unwrap());
assert!(!vcrypto.verify_password(b"abc123", &pstr5).unwrap());
assert!(vcrypto.verify_password(b"abc123", &pstr1).await.unwrap());
assert!(vcrypto.verify_password(b"abc123", &pstr2).await.unwrap());
assert!(vcrypto.verify_password(b"abc123", &pstr3).await.unwrap());
assert!(!vcrypto.verify_password(b"abc123", &pstr4).await.unwrap());
assert!(!vcrypto.verify_password(b"abc123", &pstr5).await.unwrap());
let ss1 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdf");
let ss2 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdf");
let ss1 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdf").await;
let ss2 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdf").await;
assert_eq!(ss1, ss2);
let ss3 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdg");
let ss3 = vcrypto.derive_shared_secret(b"abc123", b"qwerasdg").await;
assert_ne!(ss1, ss3);
let ss4 = vcrypto.derive_shared_secret(b"abc124", b"qwerasdf");
let ss4 = vcrypto.derive_shared_secret(b"abc124", b"qwerasdf").await;
assert_ne!(ss1, ss4);
let ss5 = vcrypto.derive_shared_secret(b"abc124", b"qwerasdg");
let ss5 = vcrypto.derive_shared_secret(b"abc124", b"qwerasdg").await;
assert_ne!(ss3, ss5);
vcrypto
.derive_shared_secret(b"abc123", b"qwe")
.await
.expect_err("should reject short salt");
vcrypto
.derive_shared_secret(
b"abc123",
b"qwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerqwerz",
)
.await
.expect_err("should reject long salt");
}
@ -230,11 +268,11 @@ pub async fn test_all() {
// Test versions
for v in VALID_CRYPTO_KINDS {
let vcrypto = crypto.get(v).unwrap();
test_aead(vcrypto.clone()).await;
test_no_auth(vcrypto.clone()).await;
test_dh(vcrypto.clone()).await;
test_generation(vcrypto).await;
let vcrypto = crypto.get_async(v).unwrap();
test_aead(&vcrypto).await;
test_no_auth(&vcrypto).await;
test_dh(&vcrypto).await;
test_generation(&vcrypto).await;
}
crypto_tests_shutdown(api.clone()).await;

View File

@ -2,9 +2,10 @@ use super::*;
pub async fn test_envelope_round_trip(
envelope_version: EnvelopeVersion,
vcrypto: CryptoSystemVersion,
vcrypto: &AsyncCryptoSystemGuard<'_>,
network_key: Option<SharedSecret>,
) {
let crypto = vcrypto.crypto();
if network_key.is_some() {
info!(
"--- test envelope round trip {} w/network key ---",
@ -16,9 +17,9 @@ pub async fn test_envelope_round_trip(
// Create envelope
let ts = Timestamp::from(0x12345678ABCDEF69u64);
let nonce = vcrypto.random_nonce();
let (sender_id, sender_secret) = vcrypto.generate_keypair().into_split();
let (recipient_id, recipient_secret) = vcrypto.generate_keypair().into_split();
let nonce = vcrypto.random_nonce().await;
let (sender_id, sender_secret) = vcrypto.generate_keypair().await.into_split();
let (recipient_id, recipient_secret) = vcrypto.generate_keypair().await.into_split();
let envelope = Envelope::new(
envelope_version,
vcrypto.kind(),
@ -33,15 +34,15 @@ pub async fn test_envelope_round_trip(
// Serialize to bytes
let enc_data = envelope
.to_encrypted_data(vcrypto.crypto(), body, &sender_secret, &network_key)
.to_encrypted_data(&crypto, body, &sender_secret, &network_key)
.expect("failed to encrypt data");
// Deserialize from bytes
let envelope2 = Envelope::from_signed_data(vcrypto.crypto(), &enc_data, &network_key)
let envelope2 = Envelope::from_signed_data(&crypto, &enc_data, &network_key)
.expect("failed to deserialize envelope from data");
let body2 = envelope2
.decrypt_body(vcrypto.crypto(), &enc_data, &recipient_secret, &network_key)
.decrypt_body(&crypto, &enc_data, &recipient_secret, &network_key)
.expect("failed to decrypt envelope body");
// Compare envelope and body
@ -53,43 +54,44 @@ pub async fn test_envelope_round_trip(
let mut mod_enc_data = enc_data.clone();
mod_enc_data[enc_data_len - 1] ^= 0x80u8;
assert!(
Envelope::from_signed_data(vcrypto.crypto(), &mod_enc_data, &network_key).is_err(),
Envelope::from_signed_data(&crypto, &mod_enc_data, &network_key).is_err(),
"should have failed to decode envelope with modified signature"
);
let mut mod_enc_data2 = enc_data.clone();
mod_enc_data2[enc_data_len - 65] ^= 0x80u8;
assert!(
Envelope::from_signed_data(vcrypto.crypto(), &mod_enc_data2, &network_key).is_err(),
Envelope::from_signed_data(&crypto, &mod_enc_data2, &network_key).is_err(),
"should have failed to decode envelope with modified data"
);
}
pub async fn test_receipt_round_trip(
envelope_version: EnvelopeVersion,
vcrypto: CryptoSystemVersion,
vcrypto: &AsyncCryptoSystemGuard<'_>,
) {
let crypto = vcrypto.crypto();
info!("--- test receipt round trip ---");
// Create arbitrary body
let body = b"This is an arbitrary body";
// Create receipt
let nonce = vcrypto.random_nonce();
let (sender_id, sender_secret) = vcrypto.generate_keypair().into_split();
let nonce = vcrypto.random_nonce().await;
let (sender_id, sender_secret) = vcrypto.generate_keypair().await.into_split();
let receipt = Receipt::try_new(envelope_version, vcrypto.kind(), nonce, sender_id, body)
.expect("should not fail");
// Serialize to bytes
let mut enc_data = receipt
.to_signed_data(vcrypto.crypto(), &sender_secret)
.to_signed_data(&crypto, &sender_secret)
.expect("failed to make signed data");
// Deserialize from bytes
let receipt2 = Receipt::from_signed_data(vcrypto.crypto(), &enc_data)
let receipt2 = Receipt::from_signed_data(&crypto, &enc_data)
.expect("failed to deserialize envelope from data");
// Should not validate even when a single bit is changed
enc_data[5] = 0x01;
Receipt::from_signed_data(vcrypto.crypto(), &enc_data)
Receipt::from_signed_data(&crypto, &enc_data)
.expect_err("should have failed to decrypt using wrong secret");
// Compare receipts
@ -103,12 +105,12 @@ pub async fn test_all() {
// Test versions
for ev in VALID_ENVELOPE_VERSIONS {
for v in VALID_CRYPTO_KINDS {
let vcrypto = crypto.get(v).unwrap();
let vcrypto = crypto.get_async(v).unwrap();
test_envelope_round_trip(ev, vcrypto.clone(), None).await;
test_envelope_round_trip(ev, vcrypto.clone(), Some(vcrypto.random_shared_secret()))
test_envelope_round_trip(ev, &vcrypto, None).await;
test_envelope_round_trip(ev, &vcrypto, Some(vcrypto.random_shared_secret().await))
.await;
test_receipt_round_trip(ev, vcrypto).await;
test_receipt_round_trip(ev, &vcrypto).await;
}
}

View File

@ -6,10 +6,10 @@ static CHEEZBURGER: &str = "I can has cheezburger";
static EMPTY_KEY: [u8; PUBLIC_KEY_LENGTH] = [0u8; PUBLIC_KEY_LENGTH];
static EMPTY_KEY_SECRET: [u8; SECRET_KEY_LENGTH] = [0u8; SECRET_KEY_LENGTH];
pub async fn test_generate_secret(vcrypto: CryptoSystemVersion) {
pub async fn test_generate_secret(vcrypto: &AsyncCryptoSystemGuard<'_>) {
// Verify keys generate
let (dht_key, dht_key_secret) = vcrypto.generate_keypair().into_split();
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split();
let (dht_key, dht_key_secret) = vcrypto.generate_keypair().await.into_split();
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
// Verify byte patterns are different between public and secret
assert_ne!(dht_key.bytes, dht_key_secret.bytes);
@ -20,21 +20,24 @@ pub async fn test_generate_secret(vcrypto: CryptoSystemVersion) {
assert_ne!(dht_key_secret, dht_key_secret2);
}
pub async fn test_sign_and_verify(vcrypto: CryptoSystemVersion) {
pub async fn test_sign_and_verify(vcrypto: &AsyncCryptoSystemGuard<'_>) {
// Make two keys
let (dht_key, dht_key_secret) = vcrypto.generate_keypair().into_split();
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split();
let (dht_key, dht_key_secret) = vcrypto.generate_keypair().await.into_split();
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
// Sign the same message twice
let dht_sig = vcrypto
.sign(&dht_key, &dht_key_secret, LOREM_IPSUM.as_bytes())
.await
.unwrap();
trace!("dht_sig: {:?}", dht_sig);
let dht_sig_b = vcrypto
.sign(&dht_key, &dht_key_secret, LOREM_IPSUM.as_bytes())
.await
.unwrap();
// Sign a second message
let dht_sig_c = vcrypto
.sign(&dht_key, &dht_key_secret, CHEEZBURGER.as_bytes())
.await
.unwrap();
trace!("dht_sig_c: {:?}", dht_sig_c);
// Verify they are the same signature
@ -42,6 +45,7 @@ pub async fn test_sign_and_verify(vcrypto: CryptoSystemVersion) {
// Sign the same message with a different key
let dht_sig2 = vcrypto
.sign(&dht_key2, &dht_key_secret2, LOREM_IPSUM.as_bytes())
.await
.unwrap();
// Verify a different key gives a different signature
assert_ne!(dht_sig2, dht_sig_b);
@ -49,73 +53,93 @@ pub async fn test_sign_and_verify(vcrypto: CryptoSystemVersion) {
// Try using the wrong secret to sign
let a1 = vcrypto
.sign(&dht_key, &dht_key_secret, LOREM_IPSUM.as_bytes())
.await
.unwrap();
let a2 = vcrypto
.sign(&dht_key2, &dht_key_secret2, LOREM_IPSUM.as_bytes())
.await
.unwrap();
let _b1 = vcrypto
.sign(&dht_key, &dht_key_secret2, LOREM_IPSUM.as_bytes())
.await
.unwrap_err();
let _b2 = vcrypto
.sign(&dht_key2, &dht_key_secret, LOREM_IPSUM.as_bytes())
.await
.unwrap_err();
assert_ne!(a1, a2);
assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &a1),
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &a1).await,
Ok(true)
);
assert_eq!(
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &a2),
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &a2).await,
Ok(true)
);
assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &a2),
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &a2).await,
Ok(false)
);
assert_eq!(
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &a1),
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &a1).await,
Ok(false)
);
// Try verifications that should work
assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig),
vcrypto
.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig)
.await,
Ok(true)
);
assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig_b),
vcrypto
.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig_b)
.await,
Ok(true)
);
assert_eq!(
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &dht_sig2),
vcrypto
.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &dht_sig2)
.await,
Ok(true)
);
assert_eq!(
vcrypto.verify(&dht_key, CHEEZBURGER.as_bytes(), &dht_sig_c),
vcrypto
.verify(&dht_key, CHEEZBURGER.as_bytes(), &dht_sig_c)
.await,
Ok(true)
);
// Try verifications that shouldn't work
assert_eq!(
vcrypto.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &dht_sig),
vcrypto
.verify(&dht_key2, LOREM_IPSUM.as_bytes(), &dht_sig)
.await,
Ok(false)
);
assert_eq!(
vcrypto.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig2),
vcrypto
.verify(&dht_key, LOREM_IPSUM.as_bytes(), &dht_sig2)
.await,
Ok(false)
);
assert_eq!(
vcrypto.verify(&dht_key2, CHEEZBURGER.as_bytes(), &dht_sig_c),
vcrypto
.verify(&dht_key2, CHEEZBURGER.as_bytes(), &dht_sig_c)
.await,
Ok(false)
);
assert_eq!(
vcrypto.verify(&dht_key, CHEEZBURGER.as_bytes(), &dht_sig),
vcrypto
.verify(&dht_key, CHEEZBURGER.as_bytes(), &dht_sig)
.await,
Ok(false)
);
}
pub async fn test_key_conversions(vcrypto: CryptoSystemVersion) {
pub async fn test_key_conversions(vcrypto: &AsyncCryptoSystemGuard<'_>) {
// Test default key
let (dht_key, dht_key_secret) = (PublicKey::default(), SecretKey::default());
assert_eq!(dht_key.bytes, EMPTY_KEY);
@ -131,10 +155,10 @@ pub async fn test_key_conversions(vcrypto: CryptoSystemVersion) {
assert_eq!(dht_key_secret_string, dht_key_string);
// Make different keys
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split();
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
trace!("dht_key2: {:?}", dht_key2);
trace!("dht_key_secret2: {:?}", dht_key_secret2);
let (dht_key3, _dht_key_secret3) = vcrypto.generate_keypair().into_split();
let (dht_key3, _dht_key_secret3) = vcrypto.generate_keypair().await.into_split();
trace!("dht_key3: {:?}", dht_key3);
trace!("_dht_key_secret3: {:?}", _dht_key_secret3);
@ -185,7 +209,7 @@ pub async fn test_key_conversions(vcrypto: CryptoSystemVersion) {
.is_err());
}
pub async fn test_encode_decode(vcrypto: CryptoSystemVersion) {
pub async fn test_encode_decode(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let dht_key = PublicKey::try_decode("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA").unwrap();
let dht_key_secret =
SecretKey::try_decode("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA").unwrap();
@ -194,7 +218,7 @@ pub async fn test_encode_decode(vcrypto: CryptoSystemVersion) {
assert_eq!(dht_key, dht_key_b);
assert_eq!(dht_key_secret, dht_key_secret_b);
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().into_split();
let (dht_key2, dht_key_secret2) = vcrypto.generate_keypair().await.into_split();
let e1 = dht_key.encode();
trace!("e1: {:?}", e1);
@ -229,7 +253,7 @@ pub async fn test_encode_decode(vcrypto: CryptoSystemVersion) {
assert!(f2.is_err());
}
pub async fn test_typed_convert(vcrypto: CryptoSystemVersion) {
pub async fn test_typed_convert(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let tks1 = format!(
"{}:7lxDEabK_qgjbe38RtBa3IZLrud84P6NhGP-pRTZzdQ",
vcrypto.kind()
@ -261,15 +285,15 @@ pub async fn test_typed_convert(vcrypto: CryptoSystemVersion) {
assert!(tks6x.ends_with(&tks6));
}
async fn test_hash(vcrypto: CryptoSystemVersion) {
async fn test_hash(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let mut s = BTreeSet::<PublicKey>::new();
let k1 = vcrypto.generate_hash("abc".as_bytes());
let k2 = vcrypto.generate_hash("abcd".as_bytes());
let k3 = vcrypto.generate_hash("".as_bytes());
let k4 = vcrypto.generate_hash(" ".as_bytes());
let k5 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes());
let k6 = vcrypto.generate_hash(CHEEZBURGER.as_bytes());
let k1 = vcrypto.generate_hash("abc".as_bytes()).await;
let k2 = vcrypto.generate_hash("abcd".as_bytes()).await;
let k3 = vcrypto.generate_hash("".as_bytes()).await;
let k4 = vcrypto.generate_hash(" ".as_bytes()).await;
let k5 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes()).await;
let k6 = vcrypto.generate_hash(CHEEZBURGER.as_bytes()).await;
s.insert(k1);
s.insert(k2);
@ -279,12 +303,12 @@ async fn test_hash(vcrypto: CryptoSystemVersion) {
s.insert(k6);
assert_eq!(s.len(), 6);
let v1 = vcrypto.generate_hash("abc".as_bytes());
let v2 = vcrypto.generate_hash("abcd".as_bytes());
let v3 = vcrypto.generate_hash("".as_bytes());
let v4 = vcrypto.generate_hash(" ".as_bytes());
let v5 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes());
let v6 = vcrypto.generate_hash(CHEEZBURGER.as_bytes());
let v1 = vcrypto.generate_hash("abc".as_bytes()).await;
let v2 = vcrypto.generate_hash("abcd".as_bytes()).await;
let v3 = vcrypto.generate_hash("".as_bytes()).await;
let v4 = vcrypto.generate_hash(" ".as_bytes()).await;
let v5 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes()).await;
let v6 = vcrypto.generate_hash(CHEEZBURGER.as_bytes()).await;
assert_eq!(k1, v1);
assert_eq!(k2, v2);
@ -293,24 +317,24 @@ async fn test_hash(vcrypto: CryptoSystemVersion) {
assert_eq!(k5, v5);
assert_eq!(k6, v6);
vcrypto.validate_hash("abc".as_bytes(), &v1);
vcrypto.validate_hash("abcd".as_bytes(), &v2);
vcrypto.validate_hash("".as_bytes(), &v3);
vcrypto.validate_hash(" ".as_bytes(), &v4);
vcrypto.validate_hash(LOREM_IPSUM.as_bytes(), &v5);
vcrypto.validate_hash(CHEEZBURGER.as_bytes(), &v6);
vcrypto.validate_hash("abc".as_bytes(), &v1).await;
vcrypto.validate_hash("abcd".as_bytes(), &v2).await;
vcrypto.validate_hash("".as_bytes(), &v3).await;
vcrypto.validate_hash(" ".as_bytes(), &v4).await;
vcrypto.validate_hash(LOREM_IPSUM.as_bytes(), &v5).await;
vcrypto.validate_hash(CHEEZBURGER.as_bytes(), &v6).await;
}
async fn test_operations(vcrypto: CryptoSystemVersion) {
let k1 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes());
let k2 = vcrypto.generate_hash(CHEEZBURGER.as_bytes());
let k3 = vcrypto.generate_hash("abc".as_bytes());
async fn test_operations(vcrypto: &AsyncCryptoSystemGuard<'_>) {
let k1 = vcrypto.generate_hash(LOREM_IPSUM.as_bytes()).await;
let k2 = vcrypto.generate_hash(CHEEZBURGER.as_bytes()).await;
let k3 = vcrypto.generate_hash("abc".as_bytes()).await;
// Get distance
let d1 = vcrypto.distance(&k1, &k2);
let d2 = vcrypto.distance(&k2, &k1);
let d3 = vcrypto.distance(&k1, &k3);
let d4 = vcrypto.distance(&k2, &k3);
let d1 = vcrypto.distance(&k1, &k2).await;
let d2 = vcrypto.distance(&k2, &k1).await;
let d3 = vcrypto.distance(&k1, &k3).await;
let d4 = vcrypto.distance(&k2, &k3).await;
trace!("d1={:?}", d1);
trace!("d2={:?}", d2);
@ -393,15 +417,15 @@ pub async fn test_all() {
// Test versions
for v in VALID_CRYPTO_KINDS {
let vcrypto = crypto.get(v).unwrap();
let vcrypto = crypto.get_async(v).unwrap();
test_generate_secret(vcrypto.clone()).await;
test_sign_and_verify(vcrypto.clone()).await;
test_key_conversions(vcrypto.clone()).await;
test_encode_decode(vcrypto.clone()).await;
test_typed_convert(vcrypto.clone()).await;
test_hash(vcrypto.clone()).await;
test_operations(vcrypto).await;
test_generate_secret(&vcrypto).await;
test_sign_and_verify(&vcrypto).await;
test_key_conversions(&vcrypto).await;
test_encode_decode(&vcrypto).await;
test_typed_convert(&vcrypto).await;
test_hash(&vcrypto).await;
test_operations(&vcrypto).await;
}
crypto_tests_shutdown(api.clone()).await;

View File

@ -78,7 +78,11 @@ where
macro_rules! byte_array_type {
($name:ident, $size:expr, $encoded_size:expr) => {
#[derive(Clone, Copy, Hash, PartialOrd, Ord, PartialEq, Eq)]
#[cfg_attr(target_arch = "wasm32", derive(Tsify), tsify(into_wasm_abi))]
#[cfg_attr(
all(target_arch = "wasm32", target_os = "unknown"),
derive(Tsify),
tsify(into_wasm_abi)
)]
pub struct $name {
pub bytes: [u8; $size],
}
@ -280,17 +284,17 @@ macro_rules! byte_array_type {
byte_array_type!(CryptoKey, CRYPTO_KEY_LENGTH, CRYPTO_KEY_LENGTH_ENCODED);
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type PublicKey = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type SecretKey = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type HashDigest = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type SharedSecret = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type RouteId = CryptoKey;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type CryptoKeyDistance = CryptoKey;
byte_array_type!(Signature, SIGNATURE_LENGTH, SIGNATURE_LENGTH_ENCODED);

View File

@ -2,7 +2,7 @@ use super::*;
#[derive(Clone, Copy, Default, PartialOrd, Ord, PartialEq, Eq, Hash)]
#[cfg_attr(
target_arch = "wasm32",
all(target_arch = "wasm32", target_os = "unknown"),
derive(Tsify),
tsify(from_wasm_abi, into_wasm_abi)
)]

View File

@ -6,7 +6,7 @@ use core::fmt;
use core::hash::Hash;
/// Cryptography version fourcc code
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type CryptoKind = FourCC;
/// Sort best crypto kinds first
@ -52,24 +52,24 @@ pub use crypto_typed::*;
pub use crypto_typed_group::*;
pub use keypair::*;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedKey = CryptoTyped<PublicKey>;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSecret = CryptoTyped<SecretKey>;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedKeyPair = CryptoTyped<KeyPair>;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSignature = CryptoTyped<Signature>;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSharedSecret = CryptoTyped<SharedSecret>;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedKeyGroup = CryptoTypedGroup<PublicKey>;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSecretGroup = CryptoTypedGroup<SecretKey>;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedKeyPairGroup = CryptoTypedGroup<KeyPair>;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSignatureGroup = CryptoTypedGroup<Signature>;
#[cfg_attr(target_arch = "wasm32", declare)]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), declare)]
pub type TypedSharedSecretGroup = CryptoTypedGroup<SharedSecret>;

View File

@ -47,14 +47,13 @@ pub fn vld0_generate_keypair() -> KeyPair {
}
/// V0 CryptoSystem
#[derive(Clone)]
pub struct CryptoSystemVLD0 {
crypto: Crypto,
registry: VeilidComponentRegistry,
}
impl CryptoSystemVLD0 {
pub fn new(crypto: Crypto) -> Self {
Self { crypto }
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { registry }
}
}
@ -64,14 +63,14 @@ impl CryptoSystem for CryptoSystemVLD0 {
CRYPTO_KIND_VLD0
}
fn crypto(&self) -> Crypto {
self.crypto.clone()
fn crypto(&self) -> VeilidComponentGuard<'_, Crypto> {
self.registry.lookup::<Crypto>().unwrap()
}
// Cached Operations
#[instrument(level = "trace", skip_all)]
fn cached_dh(&self, key: &PublicKey, secret: &SecretKey) -> VeilidAPIResult<SharedSecret> {
self.crypto
self.crypto()
.cached_dh_internal::<CryptoSystemVLD0>(self, key, secret)
}

View File

@ -1,12 +1,12 @@
use super::*;
#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
mod wasm;
#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub use wasm::*;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod native;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub use native::*;
pub static KNOWN_PROTECTED_STORE_KEYS: [&str; 2] = ["device_encryption_key", "_test_key"];
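The module selection above, like the cfg_attr changes elsewhere in this commit, narrows the wasm gate from target_arch = "wasm32" (which also matches wasm32-wasi) to browser-style wasm32-unknown-unknown. A toy illustration of the resulting three-way split, not taken from this codebase:

```rust
// Toy illustration: `target_arch = "wasm32"` alone also matches wasm32-wasi,
// so the extra `target_os = "unknown"` check limits browser-only code paths
// to wasm32-unknown-unknown.
cfg_if::cfg_if! {
    if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
        pub fn storage_backend() -> &'static str { "browser local/session storage" }
    } else if #[cfg(target_arch = "wasm32")] {
        pub fn storage_backend() -> &'static str { "non-browser wasm (e.g. wasm32-wasi)" }
    } else {
        pub fn storage_backend() -> &'static str { "native keyring / filesystem" }
    }
}
```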

View File

@ -4,31 +4,37 @@ struct BlockStoreInner {
//
}
#[derive(Clone)]
pub struct BlockStore {
event_bus: EventBus,
config: VeilidConfig,
inner: Arc<Mutex<BlockStoreInner>>,
impl fmt::Debug for BlockStoreInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BlockStoreInner").finish()
}
}
#[derive(Debug)]
pub struct BlockStore {
registry: VeilidComponentRegistry,
inner: Mutex<BlockStoreInner>,
}
impl_veilid_component!(BlockStore);
impl BlockStore {
fn new_inner() -> BlockStoreInner {
BlockStoreInner {}
}
pub fn new(event_bus: EventBus, config: VeilidConfig) -> Self {
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self {
event_bus,
config,
inner: Arc::new(Mutex::new(Self::new_inner())),
registry,
inner: Mutex::new(Self::new_inner()),
}
}
pub async fn init(&self) -> EyreResult<()> {
async fn init_async(&self) -> EyreResult<()> {
// Ensure permissions are correct
// ensure_file_private_owner(&dbpath)?;
Ok(())
}
pub async fn terminate(&self) {}
async fn terminate_async(&self) {}
}

View File

@ -6,14 +6,20 @@ use std::path::Path;
pub struct ProtectedStoreInner {
keyring_manager: Option<KeyringManager>,
}
#[derive(Clone)]
pub struct ProtectedStore {
_event_bus: EventBus,
config: VeilidConfig,
inner: Arc<Mutex<ProtectedStoreInner>>,
impl fmt::Debug for ProtectedStoreInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ProtectedStoreInner").finish()
}
}
#[derive(Debug)]
pub struct ProtectedStore {
registry: VeilidComponentRegistry,
inner: Mutex<ProtectedStoreInner>,
}
impl_veilid_component!(ProtectedStore);
impl ProtectedStore {
fn new_inner() -> ProtectedStoreInner {
ProtectedStoreInner {
@ -21,11 +27,10 @@ impl ProtectedStore {
}
}
pub fn new(event_bus: EventBus, config: VeilidConfig) -> Self {
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self {
_event_bus: event_bus,
config,
inner: Arc::new(Mutex::new(Self::new_inner())),
registry,
inner: Mutex::new(Self::new_inner()),
}
}
@ -42,9 +47,10 @@ impl ProtectedStore {
}
#[instrument(level = "debug", skip(self), err)]
pub async fn init(&self) -> EyreResult<()> {
async fn init_async(&self) -> EyreResult<()> {
let delete = {
let c = self.config.get();
let config = self.config();
let c = config.get();
let mut inner = self.inner.lock();
if !c.protected_store.always_use_insecure_storage {
// Attempt to open the secure keyring
@ -101,13 +107,22 @@ impl ProtectedStore {
Ok(())
}
#[instrument(level = "debug", skip(self), err)]
async fn post_init_async(&self) -> EyreResult<()> {
Ok(())
}
#[instrument(level = "debug", skip(self))]
pub async fn terminate(&self) {
async fn pre_terminate_async(&self) {}
#[instrument(level = "debug", skip(self))]
async fn terminate_async(&self) {
*self.inner.lock() = Self::new_inner();
}
fn service_name(&self) -> String {
let c = self.config.get();
let config = self.config();
let c = config.get();
if c.namespace.is_empty() {
"veilid_protected_store".to_owned()
} else {

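On the repeated `let config = self.config(); let c = config.get();` change above: the accessor now appears to return a guard-like handle, so it is bound to a local first and the inner config is borrowed from that binding. A small sketch of why the temporary binding is needed, using simplified stand-in types:

```rust
use std::sync::{Arc, RwLock, RwLockReadGuard};

// Simplified stand-ins for the config accessor and its guard.
struct ConfigInner {
    namespace: String,
}

#[derive(Clone)]
struct Config {
    inner: Arc<RwLock<ConfigInner>>,
}

impl Config {
    fn get(&self) -> RwLockReadGuard<'_, ConfigInner> {
        self.inner.read().unwrap()
    }
}

struct StoreSketch {
    config: Config,
}

impl StoreSketch {
    fn config(&self) -> Config {
        self.config.clone()
    }

    fn service_name(&self) -> String {
        // Bind the handle first so the read guard borrowed from it lives long
        // enough; `self.config().get()` alone would borrow from a temporary.
        let config = self.config();
        let c = config.get();
        if c.namespace.is_empty() {
            "veilid_protected_store".to_owned()
        } else {
            // hypothetical namespaced form for the sketch
            format!("veilid_protected_store_{}", c.namespace)
        }
    }
}

fn main() {
    let store = StoreSketch {
        config: Config {
            inner: Arc::new(RwLock::new(ConfigInner {
                namespace: "test".to_owned(),
            })),
        },
    };
    println!("{}", store.service_name());
}
```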
View File

@ -4,28 +4,34 @@ struct BlockStoreInner {
//
}
#[derive(Clone)]
pub struct BlockStore {
event_bus: EventBus,
config: VeilidConfig,
inner: Arc<Mutex<BlockStoreInner>>,
impl fmt::Debug for BlockStoreInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BlockStoreInner").finish()
}
}
#[derive(Debug)]
pub struct BlockStore {
registry: VeilidComponentRegistry,
inner: Mutex<BlockStoreInner>,
}
impl_veilid_component!(BlockStore);
impl BlockStore {
fn new_inner() -> BlockStoreInner {
BlockStoreInner {}
}
pub fn new(event_bus: EventBus, config: VeilidConfig) -> Self {
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self {
event_bus,
config,
inner: Arc::new(Mutex::new(Self::new_inner())),
registry,
inner: Mutex::new(Self::new_inner()),
}
}
pub async fn init(&self) -> EyreResult<()> {
async fn init_async(&self) -> EyreResult<()> {
Ok(())
}
pub async fn terminate(&self) {}
async fn terminate_async(&self) {}
}

View File

@ -3,18 +3,16 @@ use data_encoding::BASE64URL_NOPAD;
use web_sys::*;
#[derive(Clone)]
#[derive(Debug)]
pub struct ProtectedStore {
_event_bus: EventBus,
config: VeilidConfig,
registry: VeilidComponentRegistry,
}
impl_veilid_component!(ProtectedStore);
impl ProtectedStore {
pub fn new(event_bus: EventBus, config: VeilidConfig) -> Self {
Self {
_event_bus: event_bus,
config,
}
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { registry }
}
#[instrument(level = "trace", skip(self), err)]
@ -30,15 +28,24 @@ impl ProtectedStore {
}
#[instrument(level = "debug", skip(self), err)]
pub async fn init(&self) -> EyreResult<()> {
pub(crate) async fn init_async(&self) -> EyreResult<()> {
Ok(())
}
#[instrument(level = "debug", skip(self), err)]
pub(crate) async fn post_init_async(&self) -> EyreResult<()> {
Ok(())
}
#[instrument(level = "debug", skip(self))]
pub async fn terminate(&self) {}
pub(crate) async fn pre_terminate_async(&self) {}
#[instrument(level = "debug", skip(self))]
pub(crate) async fn terminate_async(&self) {}
fn browser_key_name(&self, key: &str) -> String {
let c = self.config.get();
let config = self.config();
let c = config.get();
if c.namespace.is_empty() {
format!("__veilid_protected_store_{}", key)
} else {

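A short sketch of the namespaced key naming shown in `browser_key_name` above, assuming the intent is simply to prefix keys with the configured namespace when one is set; the separator and the exact format of the namespaced branch are hypothetical, since that branch is not shown in this hunk.

```rust
// Hypothetical sketch of namespace-prefixed key naming; the namespaced
// branch's format string is an assumption, not the real code.
fn browser_key_name(namespace: &str, key: &str) -> String {
    if namespace.is_empty() {
        format!("__veilid_protected_store_{}", key)
    } else {
        format!("__veilid_protected_store_{}_{}", namespace, key)
    }
}

fn main() {
    assert_eq!(
        browser_key_name("", "device_encryption_key"),
        "__veilid_protected_store_device_encryption_key"
    );
    println!("{}", browser_key_name("testnet", "device_encryption_key"));
}
```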
View File

@ -28,7 +28,7 @@
#![recursion_limit = "256"]
cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] {
if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
#[cfg(any(feature = "rt-async-std", feature = "rt-tokio"))]
compile_error!("features \"rt-async-std\" and \"rt-tokio\" can not be specified for WASM");
} else {
@ -45,6 +45,7 @@ cfg_if::cfg_if! {
extern crate alloc;
mod attachment_manager;
mod component;
mod core_context;
mod crypto;
mod intf;
@ -58,6 +59,8 @@ mod veilid_api;
mod veilid_config;
mod wasm_helpers;
pub(crate) use self::component::*;
pub(crate) use self::core_context::RegisteredComponents;
pub use self::core_context::{api_startup, api_startup_config, api_startup_json, UpdateCallback};
pub use self::logging::{
ApiTracingLayer, VeilidLayerFilter, DEFAULT_LOG_FACILITIES_ENABLED_LIST,

View File

@ -2,7 +2,6 @@ use crate::core_context::*;
use crate::veilid_api::*;
use crate::*;
use core::fmt::Write;
use once_cell::sync::OnceCell;
use tracing_subscriber::*;
struct ApiTracingLayerInner {
@ -21,11 +20,10 @@ struct ApiTracingLayerInner {
/// with many copies of Veilid running.
#[derive(Clone)]
pub struct ApiTracingLayer {
inner: Arc<Mutex<Option<ApiTracingLayerInner>>>,
}
pub struct ApiTracingLayer {}
static API_LOGGER: OnceCell<ApiTracingLayer> = OnceCell::new();
static API_LOGGER_INNER: Mutex<Option<ApiTracingLayerInner>> = Mutex::new(None);
static API_LOGGER_ENABLED: AtomicBool = AtomicBool::new(false);
impl ApiTracingLayer {
/// Initialize an ApiTracingLayer singleton
@ -33,11 +31,7 @@ impl ApiTracingLayer {
/// This must be inserted into your tracing subscriber before you
/// call api_startup() or api_startup_json() if you are going to use api tracing.
pub fn init() -> ApiTracingLayer {
API_LOGGER
.get_or_init(|| ApiTracingLayer {
inner: Arc::new(Mutex::new(None)),
})
.clone()
ApiTracingLayer {}
}
fn new_inner() -> ApiTracingLayerInner {
@ -52,12 +46,7 @@ impl ApiTracingLayer {
namespace: String,
update_callback: UpdateCallback,
) -> VeilidAPIResult<()> {
let Some(api_logger) = API_LOGGER.get() else {
// Did not init, so skip this
return Ok(());
};
let mut inner = api_logger.inner.lock();
let mut inner = API_LOGGER_INNER.lock();
if inner.is_none() {
*inner = Some(Self::new_inner());
}
@ -70,6 +59,9 @@ impl ApiTracingLayer {
.unwrap()
.update_callbacks
.insert(key, update_callback);
API_LOGGER_ENABLED.store(true, Ordering::Release);
return Ok(());
}
@ -79,28 +71,29 @@ impl ApiTracingLayer {
namespace: String,
) -> VeilidAPIResult<()> {
let key = (program_name, namespace);
if let Some(api_logger) = API_LOGGER.get() {
let mut inner = api_logger.inner.lock();
if inner.is_none() {
apibail_not_initialized!();
}
if inner
.as_mut()
.unwrap()
.update_callbacks
.remove(&key)
.is_none()
{
apibail_not_initialized!();
}
if inner.as_mut().unwrap().update_callbacks.is_empty() {
*inner = None;
}
let mut inner = API_LOGGER_INNER.lock();
if inner.is_none() {
apibail_not_initialized!();
}
if inner
.as_mut()
.unwrap()
.update_callbacks
.remove(&key)
.is_none()
{
apibail_not_initialized!();
}
if inner.as_mut().unwrap().update_callbacks.is_empty() {
*inner = None;
API_LOGGER_ENABLED.store(false, Ordering::Release);
}
Ok(())
}
fn emit_log(&self, inner: &mut ApiTracingLayerInner, meta: &Metadata<'_>, message: String) {
fn emit_log(&self, meta: &'static Metadata<'static>, message: String) {
let level = *meta.level();
let target = meta.target();
let log_level = VeilidLogLevel::from_tracing_level(level);
@ -148,8 +141,10 @@ impl ApiTracingLayer {
backtrace,
}));
for cb in inner.update_callbacks.values() {
(cb)(log_update.clone());
if let Some(inner) = &mut *API_LOGGER_INNER.lock() {
for cb in inner.update_callbacks.values() {
(cb)(log_update.clone());
}
}
}
}
@ -159,17 +154,23 @@ pub struct SpanDuration {
end: Timestamp,
}
fn simplify_file(file: &str) -> String {
let path = std::path::Path::new(file);
let path_component_count = path.iter().count();
if path.ends_with("mod.rs") && path_component_count >= 2 {
let outpath: std::path::PathBuf = path.iter().skip(path_component_count - 2).collect();
outpath.to_string_lossy().to_string()
} else if let Some(filename) = path.file_name() {
filename.to_string_lossy().to_string()
} else {
file.to_string()
}
fn simplify_file(file: &'static str) -> &'static str {
file.static_transform(|file| {
let out = {
let path = std::path::Path::new(file);
let path_component_count = path.iter().count();
if path.ends_with("mod.rs") && path_component_count >= 2 {
let outpath: std::path::PathBuf =
path.iter().skip(path_component_count - 2).collect();
outpath.to_string_lossy().to_string()
} else if let Some(filename) = path.file_name() {
filename.to_string_lossy().to_string()
} else {
file.to_string()
}
};
out.to_static_str()
})
}
impl<S: Subscriber + for<'a> registry::LookupSpan<'a>> Layer<S> for ApiTracingLayer {
@ -179,47 +180,51 @@ impl<S: Subscriber + for<'a> registry::LookupSpan<'a>> Layer<S> for ApiTracingLa
id: &tracing::Id,
ctx: layer::Context<'_, S>,
) {
if let Some(_inner) = &mut *self.inner.lock() {
let mut new_debug_record = StringRecorder::new();
attrs.record(&mut new_debug_record);
if !API_LOGGER_ENABLED.load(Ordering::Acquire) {
// Optimization if api logger has no callbacks
return;
}
if let Some(span_ref) = ctx.span(id) {
let mut new_debug_record = StringRecorder::new();
attrs.record(&mut new_debug_record);
if let Some(span_ref) = ctx.span(id) {
span_ref
.extensions_mut()
.insert::<StringRecorder>(new_debug_record);
if crate::DURATION_LOG_FACILITIES.contains(&attrs.metadata().target()) {
span_ref
.extensions_mut()
.insert::<StringRecorder>(new_debug_record);
if crate::DURATION_LOG_FACILITIES.contains(&attrs.metadata().target()) {
span_ref
.extensions_mut()
.insert::<SpanDuration>(SpanDuration {
start: Timestamp::now(),
end: Timestamp::default(),
});
}
.insert::<SpanDuration>(SpanDuration {
start: Timestamp::now(),
end: Timestamp::default(),
});
}
}
}
fn on_close(&self, id: span::Id, ctx: layer::Context<'_, S>) {
if let Some(inner) = &mut *self.inner.lock() {
if let Some(span_ref) = ctx.span(&id) {
if let Some(span_duration) = span_ref.extensions_mut().get_mut::<SpanDuration>() {
span_duration.end = Timestamp::now();
let duration = span_duration.end.saturating_sub(span_duration.start);
let meta = span_ref.metadata();
self.emit_log(
inner,
meta,
format!(
" {}{}: duration={}",
span_ref
.parent()
.map(|p| format!("{}::", p.name()))
.unwrap_or_default(),
span_ref.name(),
format_opt_ts(Some(duration))
),
);
}
if !API_LOGGER_ENABLED.load(Ordering::Acquire) {
// Optimization if api logger has no callbacks
return;
}
if let Some(span_ref) = ctx.span(&id) {
if let Some(span_duration) = span_ref.extensions_mut().get_mut::<SpanDuration>() {
span_duration.end = Timestamp::now();
let duration = span_duration.end.saturating_sub(span_duration.start);
let meta = span_ref.metadata();
self.emit_log(
meta,
format!(
" {}{}: duration={}",
span_ref
.parent()
.map(|p| format!("{}::", p.name()))
.unwrap_or_default(),
span_ref.name(),
format_opt_ts(Some(duration))
),
);
}
}
}
@ -230,22 +235,26 @@ impl<S: Subscriber + for<'a> registry::LookupSpan<'a>> Layer<S> for ApiTracingLa
values: &tracing::span::Record<'_>,
ctx: layer::Context<'_, S>,
) {
if let Some(_inner) = &mut *self.inner.lock() {
if let Some(span_ref) = ctx.span(id) {
if let Some(debug_record) = span_ref.extensions_mut().get_mut::<StringRecorder>() {
values.record(debug_record);
}
if !API_LOGGER_ENABLED.load(Ordering::Acquire) {
// Optimization if api logger has no callbacks
return;
}
if let Some(span_ref) = ctx.span(id) {
if let Some(debug_record) = span_ref.extensions_mut().get_mut::<StringRecorder>() {
values.record(debug_record);
}
}
}
fn on_event(&self, event: &tracing::Event<'_>, _ctx: layer::Context<'_, S>) {
if let Some(inner) = &mut *self.inner.lock() {
let mut recorder = StringRecorder::new();
event.record(&mut recorder);
let meta = event.metadata();
self.emit_log(inner, meta, recorder.to_string());
if !API_LOGGER_ENABLED.load(Ordering::Acquire) {
// Optimization if api logger has no callbacks
return;
}
let mut recorder = StringRecorder::new();
event.record(&mut recorder);
let meta = event.metadata();
self.emit_log(meta, recorder.to_string());
}
}

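The tracing-layer rewrite above replaces the per-instance `Mutex<Option<...>>` with module statics plus an `AtomicBool`, so the hot `on_event`/`on_new_span` callbacks can bail out without taking a lock when nobody is listening. A condensed, self-contained sketch of that fast-path idea (names and callback types simplified):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;

// Registered callbacks live in a static; the flag mirrors "is anyone listening?".
static CALLBACKS: Mutex<Vec<fn(&str)>> = Mutex::new(Vec::new());
static ENABLED: AtomicBool = AtomicBool::new(false);

fn add_callback(cb: fn(&str)) {
    CALLBACKS.lock().unwrap().push(cb);
    ENABLED.store(true, Ordering::Release);
}

fn remove_all_callbacks() {
    CALLBACKS.lock().unwrap().clear();
    ENABLED.store(false, Ordering::Release);
}

// Called on every event: check the atomic first so the common
// "no listeners" case never touches the mutex.
fn on_event(message: &str) {
    if !ENABLED.load(Ordering::Acquire) {
        return;
    }
    for cb in CALLBACKS.lock().unwrap().iter() {
        cb(message);
    }
}

fn main() {
    on_event("dropped: nobody is listening yet");
    add_callback(|m| println!("got event: {m}"));
    on_event("delivered to one callback");
    remove_all_callbacks();
    on_event("dropped again");
}
```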
View File

@ -151,42 +151,6 @@ macro_rules! log_client_api {
}
}
#[macro_export]
macro_rules! log_network_result {
(error $text:expr) => {error!(
target: "network_result",
"{}",
$text,
)};
(error $fmt:literal, $($arg:expr),+) => {
error!(target: "network_result", $fmt, $($arg),+);
};
(warn $text:expr) => {warn!(
target: "network_result",
"{}",
$text,
)};
(warn $fmt:literal, $($arg:expr),+) => {
warn!(target:"network_result", $fmt, $($arg),+);
};
(debug $text:expr) => {debug!(
target: "network_result",
"{}",
$text,
)};
(debug $fmt:literal, $($arg:expr),+) => {
debug!(target:"network_result", $fmt, $($arg),+);
};
($text:expr) => {trace!(
target: "network_result",
"{}",
$text,
)};
($fmt:literal, $($arg:expr),+) => {
trace!(target:"network_result", $fmt, $($arg),+);
}
}
#[macro_export]
macro_rules! log_rpc {
(error $text:expr) => { error!(
@ -421,6 +385,14 @@ macro_rules! log_crypto {
(warn $fmt:literal, $($arg:expr),+) => {
warn!(target:"crypto", $fmt, $($arg),+);
};
(debug $text:expr) => { debug!(
target: "crypto",
"{}",
$text,
)};
(debug $fmt:literal, $($arg:expr),+) => {
debug!(target:"crypto", $fmt, $($arg),+);
};
($text:expr) => {trace!(
target: "crypto",
"{}",

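On the macro changes above (removing `log_network_result!` and giving `log_crypto!` a `debug` arm): a minimal sketch of the per-facility logging macro shape, using `println!` instead of the real `tracing` macros so it stands alone.

```rust
// Sketch of a facility-scoped logging macro; the real code dispatches to
// tracing's error!/warn!/debug!/trace! with target: "crypto".
macro_rules! log_crypto_sketch {
    (error $text:expr) => {
        println!("[crypto][error] {}", $text)
    };
    (warn $text:expr) => {
        println!("[crypto][warn] {}", $text)
    };
    (debug $text:expr) => {
        println!("[crypto][debug] {}", $text)
    };
    ($text:expr) => {
        println!("[crypto][trace] {}", $text)
    };
    ($fmt:literal, $($arg:expr),+) => {
        println!("[crypto][trace] {}", format!($fmt, $($arg),+))
    };
}

fn main() {
    log_crypto_sketch!(debug "verifying envelope signature");
    log_crypto_sketch!("dh cache hit for {}", "peer-key");
}
```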
View File

@ -23,6 +23,7 @@ pub const ADDRESS_CHECK_CACHE_SIZE: usize = 10;
// TimestampDuration::new(3_600_000_000_u64); // 60 minutes
/// Address checker config
#[derive(Debug)]
pub struct AddressCheckConfig {
pub detect_address_changes: bool,
pub ip6_prefix_size: usize,
@ -44,6 +45,22 @@ pub struct AddressCheck {
address_consistency_table: BTreeMap<AddressCheckCacheKey, LruCache<IpAddr, SocketAddress>>,
}
impl fmt::Debug for AddressCheck {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AddressCheck")
.field("config", &self.config)
//.field("net", &self.net)
.field("current_network_class", &self.current_network_class)
.field("current_addresses", &self.current_addresses)
.field(
"address_inconsistency_table",
&self.address_inconsistency_table,
)
.field("address_consistency_table", &self.address_consistency_table)
.finish()
}
}
impl AddressCheck {
pub fn new(config: AddressCheckConfig, net: Network) -> Self {
Self {

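On the hand-written `Debug` impls appearing above (e.g. for `AddressCheck`): the pattern is to implement `fmt::Debug` manually and list only the fields that are themselves `Debug`, commenting out handles like `net`. A tiny sketch:

```rust
use std::fmt;

// A handle type that intentionally has no Debug impl.
struct NetHandle;

struct AddressCheckSketch {
    detect_address_changes: bool,
    _net: NetHandle,
}

impl fmt::Debug for AddressCheckSketch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AddressCheckSketch")
            .field("detect_address_changes", &self.detect_address_changes)
            // .field("net", ...) is skipped: the handle is not Debug
            .finish()
    }
}

fn main() {
    let ac = AddressCheckSketch {
        detect_address_changes: true,
        _net: NetHandle,
    };
    println!("{:?}", ac);
}
```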
View File

@ -32,63 +32,27 @@ struct AddressFilterInner {
dial_info_failures: BTreeMap<DialInfo, Timestamp>,
}
struct AddressFilterUnlockedInner {
#[derive(Debug)]
pub(crate) struct AddressFilter {
registry: VeilidComponentRegistry,
inner: Mutex<AddressFilterInner>,
max_connections_per_ip4: usize,
max_connections_per_ip6_prefix: usize,
max_connections_per_ip6_prefix_size: usize,
max_connection_frequency_per_min: usize,
punishment_duration_min: usize,
dial_info_failure_duration_min: usize,
routing_table: RoutingTable,
}
impl fmt::Debug for AddressFilterUnlockedInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AddressFilterUnlockedInner")
.field("max_connections_per_ip4", &self.max_connections_per_ip4)
.field(
"max_connections_per_ip6_prefix",
&self.max_connections_per_ip6_prefix,
)
.field(
"max_connections_per_ip6_prefix_size",
&self.max_connections_per_ip6_prefix_size,
)
.field(
"max_connection_frequency_per_min",
&self.max_connection_frequency_per_min,
)
.field("punishment_duration_min", &self.punishment_duration_min)
.field(
"dial_info_failure_duration_min",
&self.dial_info_failure_duration_min,
)
.finish()
}
}
#[derive(Clone, Debug)]
pub(crate) struct AddressFilter {
unlocked_inner: Arc<AddressFilterUnlockedInner>,
inner: Arc<Mutex<AddressFilterInner>>,
}
impl_veilid_component_registry_accessor!(AddressFilter);
impl AddressFilter {
pub fn new(config: VeilidConfig, routing_table: RoutingTable) -> Self {
pub fn new(registry: VeilidComponentRegistry) -> Self {
let config = registry.config();
let c = config.get();
Self {
unlocked_inner: Arc::new(AddressFilterUnlockedInner {
max_connections_per_ip4: c.network.max_connections_per_ip4 as usize,
max_connections_per_ip6_prefix: c.network.max_connections_per_ip6_prefix as usize,
max_connections_per_ip6_prefix_size: c.network.max_connections_per_ip6_prefix_size
as usize,
max_connection_frequency_per_min: c.network.max_connection_frequency_per_min
as usize,
punishment_duration_min: PUNISHMENT_DURATION_MIN,
dial_info_failure_duration_min: DIAL_INFO_FAILURE_DURATION_MIN,
routing_table,
}),
inner: Arc::new(Mutex::new(AddressFilterInner {
registry,
inner: Mutex::new(AddressFilterInner {
conn_count_by_ip4: BTreeMap::new(),
conn_count_by_ip6_prefix: BTreeMap::new(),
conn_timestamps_by_ip4: BTreeMap::new(),
@ -97,7 +61,14 @@ impl AddressFilter {
punishments_by_ip6_prefix: BTreeMap::new(),
punishments_by_node_id: BTreeMap::new(),
dial_info_failures: BTreeMap::new(),
})),
}),
max_connections_per_ip4: c.network.max_connections_per_ip4 as usize,
max_connections_per_ip6_prefix: c.network.max_connections_per_ip6_prefix as usize,
max_connections_per_ip6_prefix_size: c.network.max_connections_per_ip6_prefix_size
as usize,
max_connection_frequency_per_min: c.network.max_connection_frequency_per_min as usize,
punishment_duration_min: PUNISHMENT_DURATION_MIN,
dial_info_failure_duration_min: DIAL_INFO_FAILURE_DURATION_MIN,
}
}
@ -109,7 +80,7 @@ impl AddressFilter {
inner.dial_info_failures.clear();
}
fn purge_old_timestamps(&self, inner: &mut AddressFilterInner, cur_ts: Timestamp) {
fn purge_old_timestamps_inner(&self, inner: &mut AddressFilterInner, cur_ts: Timestamp) {
// v4
{
let mut dead_keys = Vec::<Ipv4Addr>::new();
@ -151,7 +122,7 @@ impl AddressFilter {
for (key, value) in &mut inner.punishments_by_ip4 {
// Drop punishments older than the punishment duration
if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64())
> self.unlocked_inner.punishment_duration_min as u64 * 60_000_000u64
> self.punishment_duration_min as u64 * 60_000_000u64
{
dead_keys.push(*key);
}
@ -167,7 +138,7 @@ impl AddressFilter {
for (key, value) in &mut inner.punishments_by_ip6_prefix {
// Drop punishments older than the punishment duration
if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64())
> self.unlocked_inner.punishment_duration_min as u64 * 60_000_000u64
> self.punishment_duration_min as u64 * 60_000_000u64
{
dead_keys.push(*key);
}
@ -183,7 +154,7 @@ impl AddressFilter {
for (key, value) in &mut inner.punishments_by_node_id {
// Drop punishments older than the punishment duration
if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64())
> self.unlocked_inner.punishment_duration_min as u64 * 60_000_000u64
> self.punishment_duration_min as u64 * 60_000_000u64
{
dead_keys.push(*key);
}
@ -192,7 +163,7 @@ impl AddressFilter {
warn!("Forgiving: {}", key);
inner.punishments_by_node_id.remove(&key);
// make the entry alive again if it's still here
if let Ok(Some(nr)) = self.unlocked_inner.routing_table.lookup_node_ref(key) {
if let Ok(Some(nr)) = self.routing_table().lookup_node_ref(key) {
nr.operate_mut(|_rti, e| e.set_punished(None));
}
}
@ -203,7 +174,7 @@ impl AddressFilter {
for (key, value) in &mut inner.dial_info_failures {
// Drop failures older than the failure duration
if cur_ts.as_u64().saturating_sub(value.as_u64())
> self.unlocked_inner.dial_info_failure_duration_min as u64 * 60_000_000u64
> self.dial_info_failure_duration_min as u64 * 60_000_000u64
{
dead_keys.push(key.clone());
}
@ -241,10 +212,7 @@ impl AddressFilter {
pub fn is_ip_addr_punished(&self, addr: IpAddr) -> bool {
let inner = self.inner.lock();
let ipblock = ip_to_ipblock(
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
self.is_ip_addr_punished_inner(&inner, ipblock)
}
@ -273,8 +241,9 @@ impl AddressFilter {
let mut inner = self.inner.lock();
inner.punishments_by_ip4.clear();
inner.punishments_by_ip6_prefix.clear();
self.unlocked_inner.routing_table.clear_punishments();
inner.punishments_by_node_id.clear();
self.routing_table().clear_punishments();
}
pub fn punish_ip_addr(&self, addr: IpAddr, reason: PunishmentReason) {
@ -282,10 +251,7 @@ impl AddressFilter {
let timestamp = Timestamp::now();
let punishment = Punishment { reason, timestamp };
let ipblock = ip_to_ipblock(
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
let mut inner = self.inner.lock();
match ipblock {
@ -315,7 +281,7 @@ impl AddressFilter {
}
pub fn punish_node_id(&self, node_id: TypedKey, reason: PunishmentReason) {
if let Ok(Some(nr)) = self.unlocked_inner.routing_table.lookup_node_ref(node_id) {
if let Ok(Some(nr)) = self.routing_table().lookup_node_ref(node_id) {
// make the entry dead if it's punished
nr.operate_mut(|_rti, e| e.set_punished(Some(reason)));
}
@ -338,14 +304,14 @@ impl AddressFilter {
#[instrument(parent = None, level = "trace", skip_all, err)]
pub async fn address_filter_task_routine(
self,
&self,
_stop_token: StopToken,
_last_ts: Timestamp,
cur_ts: Timestamp,
) -> EyreResult<()> {
//
let mut inner = self.inner.lock();
self.purge_old_timestamps(&mut inner, cur_ts);
self.purge_old_timestamps_inner(&mut inner, cur_ts);
self.purge_old_punishments(&mut inner, cur_ts);
Ok(())
@ -354,23 +320,20 @@ impl AddressFilter {
pub fn add_connection(&self, addr: IpAddr) -> Result<(), AddressFilterError> {
let inner = &mut *self.inner.lock();
let ipblock = ip_to_ipblock(
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
if self.is_ip_addr_punished_inner(inner, ipblock) {
return Err(AddressFilterError::Punished);
}
let ts = Timestamp::now();
self.purge_old_timestamps(inner, ts);
self.purge_old_timestamps_inner(inner, ts);
match ipblock {
IpAddr::V4(v4) => {
// See if we have too many connections from this ip block
let cnt = inner.conn_count_by_ip4.entry(v4).or_default();
assert!(*cnt <= self.unlocked_inner.max_connections_per_ip4);
if *cnt == self.unlocked_inner.max_connections_per_ip4 {
assert!(*cnt <= self.max_connections_per_ip4);
if *cnt == self.max_connections_per_ip4 {
warn!("Address filter count exceeded: {:?}", v4);
return Err(AddressFilterError::CountExceeded);
}
@ -380,8 +343,8 @@ impl AddressFilter {
// keep timestamps that are less than a minute away
ts.saturating_sub(*v) < TimestampDuration::new(60_000_000u64)
});
assert!(tstamps.len() <= self.unlocked_inner.max_connection_frequency_per_min);
if tstamps.len() == self.unlocked_inner.max_connection_frequency_per_min {
assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.max_connection_frequency_per_min {
warn!("Address filter rate exceeded: {:?}", v4);
return Err(AddressFilterError::RateExceeded);
}
@ -393,15 +356,15 @@ impl AddressFilter {
IpAddr::V6(v6) => {
// See if we have too many connections from this ip block
let cnt = inner.conn_count_by_ip6_prefix.entry(v6).or_default();
assert!(*cnt <= self.unlocked_inner.max_connections_per_ip6_prefix);
if *cnt == self.unlocked_inner.max_connections_per_ip6_prefix {
assert!(*cnt <= self.max_connections_per_ip6_prefix);
if *cnt == self.max_connections_per_ip6_prefix {
warn!("Address filter count exceeded: {:?}", v6);
return Err(AddressFilterError::CountExceeded);
}
// See if this ip block has connected too frequently
let tstamps = inner.conn_timestamps_by_ip6_prefix.entry(v6).or_default();
assert!(tstamps.len() <= self.unlocked_inner.max_connection_frequency_per_min);
if tstamps.len() == self.unlocked_inner.max_connection_frequency_per_min {
assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.max_connection_frequency_per_min {
warn!("Address filter rate exceeded: {:?}", v6);
return Err(AddressFilterError::RateExceeded);
}
@ -414,16 +377,13 @@ impl AddressFilter {
Ok(())
}
pub fn remove_connection(&mut self, addr: IpAddr) -> Result<(), AddressNotInTableError> {
pub fn remove_connection(&self, addr: IpAddr) -> Result<(), AddressNotInTableError> {
let mut inner = self.inner.lock();
let ipblock = ip_to_ipblock(
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
let ts = Timestamp::now();
self.purge_old_timestamps(&mut inner, ts);
self.purge_old_timestamps_inner(&mut inner, ts);
match ipblock {
IpAddr::V4(v4) => {

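For the purge logic shown above (timestamps in microseconds, punishment duration in minutes), here is a self-contained sketch of the "drop entries older than N minutes" pass over a `BTreeMap`, using the same `minutes * 60_000_000` conversion; the key type is simplified to a string.

```rust
use std::collections::BTreeMap;

const PUNISHMENT_DURATION_MIN: u64 = 60;

// Timestamps are microseconds since some epoch, as in the surrounding code.
fn purge_old_punishments(punishments: &mut BTreeMap<String, u64>, cur_ts_us: u64) {
    let mut dead_keys = Vec::new();
    for (key, ts_us) in punishments.iter() {
        // Drop punishments older than the punishment duration
        if cur_ts_us.saturating_sub(*ts_us) > PUNISHMENT_DURATION_MIN * 60_000_000 {
            dead_keys.push(key.clone());
        }
    }
    for key in dead_keys {
        println!("Forgiving: {}", key);
        punishments.remove(&key);
    }
}

fn main() {
    let mut punishments = BTreeMap::new();
    punishments.insert("10.0.0.1".to_string(), 0u64); // very old
    punishments.insert("10.0.0.2".to_string(), 3_500_000_000u64); // recent
    purge_old_punishments(&mut punishments, 3_600_000_001u64);
    println!("remaining: {:?}", punishments.keys().collect::<Vec<_>>());
}
```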
View File

@ -57,17 +57,16 @@ struct ConnectionManagerInner {
async_processor_jh: Option<MustJoinHandle<()>>,
stop_source: Option<StopSource>,
protected_addresses: HashMap<SocketAddress, ProtectedAddress>,
reconnection_processor: DeferredStreamProcessor,
}
struct ConnectionManagerArc {
network_manager: NetworkManager,
connection_initial_timeout_ms: u32,
connection_inactivity_timeout_ms: u32,
connection_table: ConnectionTable,
address_lock_table: AsyncTagLockTable<SocketAddr>,
startup_lock: StartupLock,
inner: Mutex<Option<ConnectionManagerInner>>,
reconnection_processor: DeferredStreamProcessor,
}
impl core::fmt::Debug for ConnectionManagerArc {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
@ -79,15 +78,17 @@ impl core::fmt::Debug for ConnectionManagerArc {
#[derive(Debug, Clone)]
pub struct ConnectionManager {
registry: VeilidComponentRegistry,
arc: Arc<ConnectionManagerArc>,
}
impl_veilid_component_registry_accessor!(ConnectionManager);
impl ConnectionManager {
fn new_inner(
stop_source: StopSource,
sender: flume::Sender<ConnectionManagerEvent>,
async_processor_jh: MustJoinHandle<()>,
reconnection_processor: DeferredStreamProcessor,
) -> ConnectionManagerInner {
ConnectionManagerInner {
next_id: 0.into(),
@ -95,11 +96,10 @@ impl ConnectionManager {
sender,
async_processor_jh: Some(async_processor_jh),
protected_addresses: HashMap::new(),
reconnection_processor,
}
}
fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc {
let config = network_manager.config();
fn new_arc(registry: VeilidComponentRegistry) -> ConnectionManagerArc {
let config = registry.config();
let (connection_initial_timeout_ms, connection_inactivity_timeout_ms) = {
let c = config.get();
(
@ -107,28 +107,24 @@ impl ConnectionManager {
c.network.connection_inactivity_timeout_ms,
)
};
let address_filter = network_manager.address_filter();
ConnectionManagerArc {
network_manager,
reconnection_processor: DeferredStreamProcessor::new(),
connection_initial_timeout_ms,
connection_inactivity_timeout_ms,
connection_table: ConnectionTable::new(config, address_filter),
connection_table: ConnectionTable::new(registry),
address_lock_table: AsyncTagLockTable::new(),
startup_lock: StartupLock::new(),
inner: Mutex::new(None),
}
}
pub fn new(network_manager: NetworkManager) -> Self {
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self {
arc: Arc::new(Self::new_arc(network_manager)),
arc: Arc::new(Self::new_arc(registry.clone())),
registry,
}
}
pub fn network_manager(&self) -> NetworkManager {
self.arc.network_manager.clone()
}
pub fn connection_inactivity_timeout_ms(&self) -> u32 {
self.arc.connection_inactivity_timeout_ms
}
@ -150,21 +146,17 @@ impl ConnectionManager {
self.clone().async_processor(stop_source.token(), receiver),
);
// Spawn the reconnection processor
let mut reconnection_processor = DeferredStreamProcessor::new();
reconnection_processor.init().await;
// Store in the inner object
let mut inner = self.arc.inner.lock();
if inner.is_some() {
panic!("shouldn't start connection manager twice without shutting it down first");
{
let mut inner = self.arc.inner.lock();
if inner.is_some() {
panic!("shouldn't start connection manager twice without shutting it down first");
}
*inner = Some(Self::new_inner(stop_source, sender, async_processor));
}
*inner = Some(Self::new_inner(
stop_source,
sender,
async_processor,
reconnection_processor,
));
// Spawn the reconnection processor
self.arc.reconnection_processor.init().await;
guard.success();
@ -178,6 +170,10 @@ impl ConnectionManager {
return;
};
// Stop the reconnection processor
log_net!(debug "stopping reconnection processor task");
self.arc.reconnection_processor.terminate().await;
// Remove the inner from the lock
let mut inner = {
let mut inner_lock = self.arc.inner.lock();
@ -188,9 +184,6 @@ impl ConnectionManager {
}
}
};
// Stop the reconnection processor
log_net!(debug "stopping reconnection processor task");
inner.reconnection_processor.terminate().await;
// Stop all the connections and the async processor
log_net!(debug "stopping async processor task");
drop(inner.stop_source.take());
@ -452,13 +445,14 @@ impl ConnectionManager {
// Attempt new connection
let mut retry_count = NEW_CONNECTION_RETRY_COUNT;
let network_manager = self.network_manager();
let prot_conn = network_result_try!(loop {
let result_net_res = ProtocolNetworkConnection::connect(
preferred_local_address,
&dial_info,
self.arc.connection_initial_timeout_ms,
self.network_manager().address_filter(),
network_manager.address_filter(),
)
.await;
match result_net_res {
@ -574,7 +568,7 @@ impl ConnectionManager {
// Called by low-level network when any connection-oriented protocol connection appears
// either from incoming connections.
#[cfg_attr(target_arch = "wasm32", expect(dead_code))]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub(super) async fn on_accepted_protocol_network_connection(
&self,
protocol_connection: ProtocolNetworkConnection,
@ -660,7 +654,7 @@ impl ConnectionManager {
// Reconnect the protected connection immediately
if reconnect {
if let Some(dial_info) = conn.dial_info() {
self.spawn_reconnector_inner(inner, dial_info);
self.spawn_reconnector(dial_info);
} else {
log_net!(debug "Can't reconnect to accepted protected connection: {} -> {} for node {}", conn.connection_id(), conn.debug_print(Timestamp::now()), protect_nr);
}
@ -675,9 +669,9 @@ impl ConnectionManager {
}
}
fn spawn_reconnector_inner(&self, inner: &mut ConnectionManagerInner, dial_info: DialInfo) {
fn spawn_reconnector(&self, dial_info: DialInfo) {
let this = self.clone();
inner.reconnection_processor.add(
self.arc.reconnection_processor.add(
Box::pin(futures_util::stream::once(async { dial_info })),
move |dial_info| {
let this = this.clone();

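On moving `reconnection_processor` out of `ConnectionManagerInner` and into the long-lived arc above: the effect is that queuing a reconnect no longer needs the inner mutex. A rough sketch of that shape, with a plain channel and thread standing in for the deferred stream processor (simplified, not the real `DeferredStreamProcessor` API):

```rust
use std::sync::mpsc::{channel, Sender};
use std::sync::Arc;
use std::thread;

// Long-lived state shared by all clones of the manager.
struct ManagerArc {
    reconnect_tx: Sender<String>, // stands in for the deferred stream processor
}

#[derive(Clone)]
struct ConnectionManagerSketch {
    arc: Arc<ManagerArc>,
}

impl ConnectionManagerSketch {
    // No inner lock needed: the queue lives at the arc level.
    fn spawn_reconnector(&self, dial_info: String) {
        self.arc.reconnect_tx.send(dial_info).ok();
    }
}

fn main() {
    let (tx, rx) = channel::<String>();
    let worker = thread::spawn(move || {
        for dial_info in rx {
            println!("reconnecting to {dial_info}");
        }
    });

    let cm = ConnectionManagerSketch {
        arc: Arc::new(ManagerArc { reconnect_tx: tx }),
    };
    cm.spawn_reconnector("udp/203.0.113.7:5150".to_string());

    drop(cm); // dropping the last sender ends the worker loop
    worker.join().unwrap();
}
```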
View File

@ -44,17 +44,20 @@ struct ConnectionTableInner {
protocol_index_by_id: BTreeMap<NetworkConnectionId, usize>,
id_by_flow: BTreeMap<Flow, NetworkConnectionId>,
ids_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnectionId>>,
address_filter: AddressFilter,
priority_flows: Vec<LruCache<Flow, ()>>,
}
#[derive(Debug)]
pub struct ConnectionTable {
inner: Arc<Mutex<ConnectionTableInner>>,
registry: VeilidComponentRegistry,
inner: Mutex<ConnectionTableInner>,
}
impl_veilid_component_registry_accessor!(ConnectionTable);
impl ConnectionTable {
pub fn new(config: VeilidConfig, address_filter: AddressFilter) -> Self {
pub fn new(registry: VeilidComponentRegistry) -> Self {
let config = registry.config();
let max_connections = {
let c = config.get();
vec![
@ -64,7 +67,8 @@ impl ConnectionTable {
]
};
Self {
inner: Arc::new(Mutex::new(ConnectionTableInner {
registry,
inner: Mutex::new(ConnectionTableInner {
conn_by_id: max_connections
.iter()
.map(|_| LruCache::new_unbounded())
@ -72,13 +76,12 @@ impl ConnectionTable {
protocol_index_by_id: BTreeMap::new(),
id_by_flow: BTreeMap::new(),
ids_by_remote: BTreeMap::new(),
address_filter,
priority_flows: max_connections
.iter()
.map(|x| LruCache::new(x * PRIORITY_FLOW_PERCENTAGE / 100))
.collect(),
max_connections,
})),
}),
}
}
@ -168,6 +171,7 @@ impl ConnectionTable {
/// when it is getting full while adding a new connection.
/// Factored out into its own function for clarity.
fn lru_out_connection_inner(
&self,
inner: &mut ConnectionTableInner,
protocol_index: usize,
) -> Result<Option<NetworkConnection>, ()> {
@ -198,7 +202,7 @@ impl ConnectionTable {
lruk
};
let dead_conn = Self::remove_connection_records(inner, dead_k);
let dead_conn = self.remove_connection_records_inner(inner, dead_k);
Ok(Some(dead_conn))
}
@ -235,20 +239,21 @@ impl ConnectionTable {
// Filter by ip for connection limits
let ip_addr = flow.remote_address().ip_addr();
match inner.address_filter.add_connection(ip_addr) {
Ok(()) => {}
Err(e) => {
// Return the connection in the error to be disposed of
return Err(ConnectionTableAddError::address_filter(
network_connection,
e,
));
}
};
if let Err(e) = self
.network_manager()
.address_filter()
.add_connection(ip_addr)
{
// Return the connection in the error to be disposed of
return Err(ConnectionTableAddError::address_filter(
network_connection,
e,
));
}
// if we have reached the maximum number of connections per protocol type
// then drop the least recently used connection that is not protected or referenced
let out_conn = match Self::lru_out_connection_inner(&mut inner, protocol_index) {
let out_conn = match self.lru_out_connection_inner(&mut inner, protocol_index) {
Ok(v) => v,
Err(()) => {
return Err(ConnectionTableAddError::table_full(network_connection));
@ -437,7 +442,8 @@ impl ConnectionTable {
}
#[instrument(level = "trace", skip(inner), ret)]
fn remove_connection_records(
fn remove_connection_records_inner(
&self,
inner: &mut ConnectionTableInner,
id: NetworkConnectionId,
) -> NetworkConnection {
@ -462,8 +468,8 @@ impl ConnectionTable {
}
// address_filter
let ip_addr = remote.socket_addr().ip();
inner
.address_filter
self.network_manager()
.address_filter()
.remove_connection(ip_addr)
.expect("Inconsistency in connection table");
conn
@ -477,7 +483,7 @@ impl ConnectionTable {
if !inner.conn_by_id[protocol_index].contains_key(&id) {
return None;
}
let conn = Self::remove_connection_records(&mut inner, id);
let conn = self.remove_connection_records_inner(&mut inner, id);
Some(conn)
}

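For the connection-table hunks above that route per-IP accounting through the address filter, here is a compact sketch of the count-then-reject logic, with the limits and key types simplified:

```rust
use std::collections::BTreeMap;
use std::net::Ipv4Addr;

const MAX_CONNECTIONS_PER_IP4: usize = 2;

#[derive(Debug)]
enum AddressFilterError {
    CountExceeded,
}

#[derive(Default)]
struct AddressFilterSketch {
    conn_count_by_ip4: BTreeMap<Ipv4Addr, usize>,
}

impl AddressFilterSketch {
    fn add_connection(&mut self, addr: Ipv4Addr) -> Result<(), AddressFilterError> {
        let cnt = self.conn_count_by_ip4.entry(addr).or_default();
        if *cnt == MAX_CONNECTIONS_PER_IP4 {
            // Caller returns the rejected connection so it can be disposed of.
            return Err(AddressFilterError::CountExceeded);
        }
        *cnt += 1;
        Ok(())
    }

    fn remove_connection(&mut self, addr: Ipv4Addr) {
        if let Some(cnt) = self.conn_count_by_ip4.get_mut(&addr) {
            *cnt = cnt.saturating_sub(1);
        }
    }
}

fn main() {
    let mut filter = AddressFilterSketch::default();
    let ip = Ipv4Addr::new(192, 0, 2, 10);
    assert!(filter.add_connection(ip).is_ok());
    assert!(filter.add_connection(ip).is_ok());
    assert!(filter.add_connection(ip).is_err()); // third connection rejected
    filter.remove_connection(ip);
    assert!(filter.add_connection(ip).is_ok());
    println!("final counts: {:?}", filter.conn_count_by_ip4);
}
```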
View File

@ -35,7 +35,7 @@ impl NetworkManager {
// Direct bootstrap request
#[instrument(level = "trace", target = "net", err, skip(self))]
pub async fn boot_request(&self, dial_info: DialInfo) -> EyreResult<Vec<Arc<PeerInfo>>> {
let timeout_ms = self.with_config(|c| c.network.rpc.timeout_ms);
let timeout_ms = self.config().with(|c| c.network.rpc.timeout_ms);
// Send boot magic to requested peer address
let data = BOOT_MAGIC.to_vec();

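The one-line change above moves from `self.with_config(|c| ...)` to `self.config().with(|c| ...)`. A tiny sketch of a `with` accessor that runs a closure against the locked config and returns the result (types simplified):

```rust
use std::sync::{Arc, RwLock};

struct ConfigInner {
    rpc_timeout_ms: u32,
}

#[derive(Clone)]
struct Config {
    inner: Arc<RwLock<ConfigInner>>,
}

impl Config {
    // Run a closure against the locked config; the guard never escapes.
    fn with<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&ConfigInner) -> R,
    {
        f(&self.inner.read().unwrap())
    }
}

fn main() {
    let config = Config {
        inner: Arc::new(RwLock::new(ConfigInner { rpc_timeout_ms: 5_000 })),
    };
    let timeout_ms = config.with(|c| c.rpc_timeout_ms);
    println!("rpc timeout: {timeout_ms} ms");
}
```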
View File

@ -1,8 +1,8 @@
use crate::*;
use super::*;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod native;
#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
mod wasm;
mod address_check;
@ -36,16 +36,15 @@ use connection_handle::*;
use crypto::*;
use futures_util::stream::FuturesUnordered;
use hashlink::LruCache;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
use native::*;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub use native::{MAX_CAPABILITIES, PUBLIC_INTERNET_CAPABILITIES};
use routing_table::*;
use rpc_processor::*;
use storage_manager::*;
#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
use wasm::*;
#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub use wasm::{/* LOCAL_NETWORK_CAPABILITIES, */ MAX_CAPABILITIES, PUBLIC_INTERNET_CAPABILITIES,};
////////////////////////////////////////////////////////////////////////////////////////
@ -65,7 +64,6 @@ pub const HOLE_PUNCH_DELAY_MS: u32 = 100;
struct NetworkComponents {
net: Network,
connection_manager: ConnectionManager,
rpc_processor: RPCProcessor,
receipt_manager: ReceiptManager,
}
@ -119,45 +117,74 @@ enum SendDataToExistingFlowResult {
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum StartupDisposition {
Success,
#[cfg_attr(target_arch = "wasm32", expect(dead_code))]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
BindRetry,
}
#[derive(Debug, Clone)]
pub struct NetworkManagerStartupContext {
pub startup_lock: Arc<StartupLock>,
}
impl NetworkManagerStartupContext {
pub fn new() -> Self {
Self {
startup_lock: Arc::new(StartupLock::new()),
}
}
}
impl Default for NetworkManagerStartupContext {
fn default() -> Self {
Self::new()
}
}
// The mutable state of the network manager
#[derive(Debug)]
struct NetworkManagerInner {
stats: NetworkManagerStats,
client_allowlist: LruCache<TypedKey, ClientAllowlistEntry>,
node_contact_method_cache: LruCache<NodeContactMethodCacheKey, NodeContactMethod>,
address_check: Option<AddressCheck>,
peer_info_change_subscription: Option<EventBusSubscription>,
socket_address_change_subscription: Option<EventBusSubscription>,
}
struct NetworkManagerUnlockedInner {
// Handles
event_bus: EventBus,
config: VeilidConfig,
storage_manager: StorageManager,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")]
block_store: BlockStore,
crypto: Crypto,
pub(crate) struct NetworkManager {
registry: VeilidComponentRegistry,
inner: Mutex<NetworkManagerInner>,
// Address filter
address_filter: AddressFilter,
// Accessors
routing_table: RwLock<Option<RoutingTable>>,
address_filter: RwLock<Option<AddressFilter>>,
components: RwLock<Option<NetworkComponents>>,
update_callback: RwLock<Option<UpdateCallback>>,
// Background processes
rolling_transfers_task: TickTask<EyreReport>,
address_filter_task: TickTask<EyreReport>,
// Network Key
// Network key
network_key: Option<SharedSecret>,
// Startup Lock
startup_lock: StartupLock,
// Startup context
startup_context: NetworkManagerStartupContext,
}
#[derive(Clone)]
pub(crate) struct NetworkManager {
inner: Arc<Mutex<NetworkManagerInner>>,
unlocked_inner: Arc<NetworkManagerUnlockedInner>,
impl_veilid_component!(NetworkManager);
impl fmt::Debug for NetworkManager {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NetworkManager")
//.field("registry", &self.registry)
.field("inner", &self.inner)
.field("address_filter", &self.address_filter)
// .field("components", &self.components)
// .field("rolling_transfers_task", &self.rolling_transfers_task)
// .field("address_filter_task", &self.address_filter_task)
.field("network_key", &self.network_key)
.field("startup_context", &self.startup_context)
.finish()
}
}
impl NetworkManager {
@ -167,52 +194,20 @@ impl NetworkManager {
client_allowlist: LruCache::new_unbounded(),
node_contact_method_cache: LruCache::new(NODE_CONTACT_METHOD_CACHE_SIZE),
address_check: None,
}
}
fn new_unlocked_inner(
event_bus: EventBus,
config: VeilidConfig,
storage_manager: StorageManager,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
network_key: Option<SharedSecret>,
) -> NetworkManagerUnlockedInner {
NetworkManagerUnlockedInner {
event_bus,
config: config.clone(),
storage_manager,
table_store,
#[cfg(feature = "unstable-blockstore")]
block_store,
crypto,
address_filter: RwLock::new(None),
routing_table: RwLock::new(None),
components: RwLock::new(None),
update_callback: RwLock::new(None),
rolling_transfers_task: TickTask::new(
"rolling_transfers_task",
ROLLING_TRANSFERS_INTERVAL_SECS,
),
address_filter_task: TickTask::new(
"address_filter_task",
ADDRESS_FILTER_TASK_INTERVAL_SECS,
),
network_key,
startup_lock: StartupLock::new(),
peer_info_change_subscription: None,
socket_address_change_subscription: None,
}
}
pub fn new(
event_bus: EventBus,
config: VeilidConfig,
storage_manager: StorageManager,
table_store: TableStore,
#[cfg(feature = "unstable-blockstore")] block_store: BlockStore,
crypto: Crypto,
registry: VeilidComponentRegistry,
startup_context: NetworkManagerStartupContext,
) -> Self {
// Make the network key
let network_key = {
let config = registry.config();
let crypto = registry.crypto();
let c = config.get();
let network_key_password = c.network.network_key_password.clone();
let network_key = if let Some(network_key_password) = network_key_password {
@ -238,110 +233,52 @@ impl NetworkManager {
network_key
};
let inner = Self::new_inner();
let address_filter = AddressFilter::new(registry.clone());
let this = Self {
inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner(
event_bus,
config,
storage_manager,
table_store,
#[cfg(feature = "unstable-blockstore")]
block_store,
crypto,
network_key,
)),
registry,
inner: Mutex::new(inner),
address_filter,
components: RwLock::new(None),
rolling_transfers_task: TickTask::new(
"rolling_transfers_task",
ROLLING_TRANSFERS_INTERVAL_SECS,
),
address_filter_task: TickTask::new(
"address_filter_task",
ADDRESS_FILTER_TASK_INTERVAL_SECS,
),
network_key,
startup_context,
};
this.setup_tasks();
this
}
pub fn event_bus(&self) -> EventBus {
self.unlocked_inner.event_bus.clone()
}
pub fn config(&self) -> VeilidConfig {
self.unlocked_inner.config.clone()
}
pub fn with_config<F, R>(&self, f: F) -> R
where
F: FnOnce(&VeilidConfigInner) -> R,
{
f(&self.unlocked_inner.config.get())
}
pub fn storage_manager(&self) -> StorageManager {
self.unlocked_inner.storage_manager.clone()
}
pub fn table_store(&self) -> TableStore {
self.unlocked_inner.table_store.clone()
}
#[cfg(feature = "unstable-blockstore")]
pub fn block_store(&self) -> BlockStore {
self.unlocked_inner.block_store.clone()
}
pub fn crypto(&self) -> Crypto {
self.unlocked_inner.crypto.clone()
}
pub fn address_filter(&self) -> AddressFilter {
self.unlocked_inner
.address_filter
.read()
.as_ref()
.unwrap()
.clone()
}
pub fn routing_table(&self) -> RoutingTable {
self.unlocked_inner
.routing_table
.read()
.as_ref()
.unwrap()
.clone()
pub fn address_filter(&self) -> &AddressFilter {
&self.address_filter
}
fn net(&self) -> Network {
self.unlocked_inner
.components
.read()
.as_ref()
.unwrap()
.net
.clone()
self.components.read().as_ref().unwrap().net.clone()
}
fn opt_net(&self) -> Option<Network> {
self.unlocked_inner
.components
.read()
.as_ref()
.map(|x| x.net.clone())
self.components.read().as_ref().map(|x| x.net.clone())
}
fn receipt_manager(&self) -> ReceiptManager {
self.unlocked_inner
.components
self.components
.read()
.as_ref()
.unwrap()
.receipt_manager
.clone()
}
pub fn rpc_processor(&self) -> RPCProcessor {
self.unlocked_inner
.components
.read()
.as_ref()
.unwrap()
.rpc_processor
.clone()
}
pub fn opt_rpc_processor(&self) -> Option<RPCProcessor> {
self.unlocked_inner
.components
.read()
.as_ref()
.map(|x| x.rpc_processor.clone())
}
pub fn connection_manager(&self) -> ConnectionManager {
self.unlocked_inner
.components
self.components
.read()
.as_ref()
.unwrap()
@ -349,103 +286,48 @@ impl NetworkManager {
.clone()
}
pub fn opt_connection_manager(&self) -> Option<ConnectionManager> {
self.unlocked_inner
.components
self.components
.read()
.as_ref()
.map(|x| x.connection_manager.clone())
}
pub fn update_callback(&self) -> UpdateCallback {
self.unlocked_inner
.update_callback
.read()
.as_ref()
.unwrap()
.clone()
}
#[instrument(level = "debug", skip_all, err)]
pub async fn init(&self, update_callback: UpdateCallback) -> EyreResult<()> {
let routing_table = RoutingTable::new(self.clone());
routing_table.init().await?;
let address_filter = AddressFilter::new(self.config(), routing_table.clone());
*self.unlocked_inner.routing_table.write() = Some(routing_table.clone());
*self.unlocked_inner.address_filter.write() = Some(address_filter);
*self.unlocked_inner.update_callback.write() = Some(update_callback);
// Register event handlers
let this = self.clone();
self.event_bus().subscribe(move |evt| {
let this = this.clone();
Box::pin(async move {
this.peer_info_change_event_handler(evt);
})
});
let this = self.clone();
self.event_bus().subscribe(move |evt| {
let this = this.clone();
Box::pin(async move {
this.socket_address_change_event_handler(evt);
})
});
async fn init_async(&self) -> EyreResult<()> {
Ok(())
}
#[instrument(level = "debug", skip_all)]
pub async fn terminate(&self) {
let routing_table = self.unlocked_inner.routing_table.write().take();
if let Some(routing_table) = routing_table {
routing_table.terminate().await;
}
*self.unlocked_inner.update_callback.write() = None;
async fn post_init_async(&self) -> EyreResult<()> {
Ok(())
}
async fn pre_terminate_async(&self) {}
#[instrument(level = "debug", skip_all)]
async fn terminate_async(&self) {}
#[instrument(level = "debug", skip_all, err)]
pub async fn internal_startup(&self) -> EyreResult<StartupDisposition> {
if self.unlocked_inner.components.read().is_some() {
if self.components.read().is_some() {
log_net!(debug "NetworkManager::internal_startup already started");
return Ok(StartupDisposition::Success);
}
// Clean address filter for things that should not be persistent
self.address_filter().restart();
self.address_filter.restart();
// Create network components
let connection_manager = ConnectionManager::new(self.clone());
let net = Network::new(
self.clone(),
self.routing_table(),
connection_manager.clone(),
);
let rpc_processor = RPCProcessor::new(
self.clone(),
self.unlocked_inner
.update_callback
.read()
.as_ref()
.unwrap()
.clone(),
);
let connection_manager = ConnectionManager::new(self.registry());
let net = Network::new(self.registry());
let receipt_manager = ReceiptManager::new();
*self.unlocked_inner.components.write() = Some(NetworkComponents {
*self.components.write() = Some(NetworkComponents {
net: net.clone(),
connection_manager: connection_manager.clone(),
rpc_processor: rpc_processor.clone(),
receipt_manager: receipt_manager.clone(),
});
// Start network components
connection_manager.startup().await?;
match net.startup().await? {
StartupDisposition::Success => {}
StartupDisposition::BindRetry => {
return Ok(StartupDisposition::BindRetry);
}
}
let (detect_address_changes, ip6_prefix_size) = self.with_config(|c| {
let (detect_address_changes, ip6_prefix_size) = self.config().with(|c| {
(
c.network.detect_address_changes,
c.network.max_connections_per_ip6_prefix_size as usize,
@ -456,9 +338,30 @@ impl NetworkManager {
ip6_prefix_size,
};
let address_check = AddressCheck::new(address_check_config, net.clone());
self.inner.lock().address_check = Some(address_check);
rpc_processor.startup().await?;
// Register event handlers
let peer_info_change_subscription =
impl_subscribe_event_bus!(self, Self, peer_info_change_event_handler);
let socket_address_change_subscription =
impl_subscribe_event_bus!(self, Self, socket_address_change_event_handler);
{
let mut inner = self.inner.lock();
inner.address_check = Some(address_check);
inner.peer_info_change_subscription = Some(peer_info_change_subscription);
inner.socket_address_change_subscription = Some(socket_address_change_subscription);
}
// Start network components
connection_manager.startup().await?;
match net.startup().await? {
StartupDisposition::Success => {}
StartupDisposition::BindRetry => {
return Ok(StartupDisposition::BindRetry);
}
}
receipt_manager.startup().await?;
log_net!("NetworkManager::internal_startup end");
@ -468,15 +371,11 @@ impl NetworkManager {
#[instrument(level = "debug", skip_all, err)]
pub async fn startup(&self) -> EyreResult<StartupDisposition> {
let guard = self.unlocked_inner.startup_lock.startup()?;
let guard = self.startup_context.startup_lock.startup()?;
match self.internal_startup().await {
Ok(StartupDisposition::Success) => {
guard.success();
// Inform api clients that things have changed
self.send_network_update();
Ok(StartupDisposition::Success)
}
Ok(StartupDisposition::BindRetry) => {
@ -492,25 +391,30 @@ impl NetworkManager {
#[instrument(level = "debug", skip_all)]
async fn shutdown_internal(&self) {
// Cancel all tasks
self.cancel_tasks().await;
// Shutdown address check
self.inner.lock().address_check = Option::<AddressCheck>::None;
// Shutdown event bus subscriptions and address check
{
let mut inner = self.inner.lock();
if let Some(sub) = inner.socket_address_change_subscription.take() {
self.event_bus().unsubscribe(sub);
}
if let Some(sub) = inner.peer_info_change_subscription.take() {
self.event_bus().unsubscribe(sub);
}
inner.address_check = None;
}
// Shutdown network components if they started up
log_net!(debug "shutting down network components");
{
let components = self.unlocked_inner.components.read().clone();
let components = self.components.read().clone();
if let Some(components) = components {
components.net.shutdown().await;
components.rpc_processor.shutdown().await;
components.receipt_manager.shutdown().await;
components.connection_manager.shutdown().await;
}
}
*self.unlocked_inner.components.write() = None;
*self.components.write() = None;
// reset the state
log_net!(debug "resetting network manager state");
@ -521,21 +425,22 @@ impl NetworkManager {
#[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) {
log_net!(debug "starting network manager shutdown");
// Cancel all tasks
log_net!(debug "stopping network manager tasks");
self.cancel_tasks().await;
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else {
log_net!(debug "network manager is already shut down");
return;
};
// Proceed with shutdown
log_net!(debug "starting network manager shutdown");
let guard = self
.startup_context
.startup_lock
.shutdown()
.await
.expect("should be started up");
self.shutdown_internal().await;
guard.success();
// send update
log_net!(debug "sending network state update to api clients");
self.send_network_update();
log_net!(debug "finished network manager shutdown");
}
@ -568,7 +473,9 @@ impl NetworkManager {
}
pub fn purge_client_allowlist(&self) {
let timeout_ms = self.with_config(|c| c.network.client_allowlist_timeout_ms);
let timeout_ms = self
.config()
.with(|c| c.network.client_allowlist_timeout_ms);
let mut inner = self.inner.lock();
let cutoff_timestamp =
Timestamp::now() - TimestampDuration::new((timeout_ms as u64) * 1000u64);
@ -607,14 +514,15 @@ impl NetworkManager {
extra_data: D,
callback: impl ReceiptCallback,
) -> EyreResult<Vec<u8>> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
bail!("network is not started");
};
let receipt_manager = self.receipt_manager();
let routing_table = self.routing_table();
let crypto = self.crypto();
// Generate receipt and serialized form to return
let vcrypto = self.crypto().best();
let vcrypto = crypto.best();
let nonce = vcrypto.random_nonce();
let node_id = routing_table.node_id(vcrypto.kind());
@ -628,7 +536,7 @@ impl NetworkManager {
extra_data,
)?;
let out = receipt
.to_signed_data(self.crypto(), &node_id_secret)
.to_signed_data(&crypto, &node_id_secret)
.wrap_err("failed to generate signed receipt")?;
// Record the receipt for later
@ -645,15 +553,16 @@ impl NetworkManager {
expiration_us: TimestampDuration,
extra_data: D,
) -> EyreResult<(Vec<u8>, EventualValueFuture<ReceiptEvent>)> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
bail!("network is not started");
};
let receipt_manager = self.receipt_manager();
let routing_table = self.routing_table();
let crypto = self.crypto();
// Generate receipt and serialized form to return
let vcrypto = self.crypto().best();
let vcrypto = crypto.best();
let nonce = vcrypto.random_nonce();
let node_id = routing_table.node_id(vcrypto.kind());
@ -667,7 +576,7 @@ impl NetworkManager {
extra_data,
)?;
let out = receipt
.to_signed_data(self.crypto(), &node_id_secret)
.to_signed_data(&crypto, &node_id_secret)
.wrap_err("failed to generate signed receipt")?;
// Record the receipt for later
@ -685,13 +594,14 @@ impl NetworkManager {
&self,
receipt_data: R,
) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started");
};
let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) {
let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => {
return NetworkResult::invalid_message(e.to_string());
}
@ -710,13 +620,14 @@ impl NetworkManager {
receipt_data: R,
inbound_noderef: FilteredNodeRef,
) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started");
};
let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) {
let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => {
return NetworkResult::invalid_message(e.to_string());
}
@ -734,13 +645,14 @@ impl NetworkManager {
&self,
receipt_data: R,
) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started");
};
let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) {
let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => {
return NetworkResult::invalid_message(e.to_string());
}
@ -759,13 +671,14 @@ impl NetworkManager {
receipt_data: R,
private_route: PublicKey,
) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started");
};
let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) {
let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => {
return NetworkResult::invalid_message(e.to_string());
}
@ -784,7 +697,7 @@ impl NetworkManager {
signal_flow: Flow,
signal_info: SignalInfo,
) -> EyreResult<NetworkResult<()>> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return Ok(NetworkResult::service_unavailable("network is not started"));
};
@ -884,7 +797,8 @@ impl NetworkManager {
) -> EyreResult<Vec<u8>> {
// DH to get encryption key
let routing_table = self.routing_table();
let Some(vcrypto) = self.crypto().get(dest_node_id.kind) else {
let crypto = self.crypto();
let Some(vcrypto) = crypto.get(dest_node_id.kind) else {
bail!("should not have a destination with incompatible crypto here");
};
@ -905,12 +819,7 @@ impl NetworkManager {
dest_node_id.value,
);
envelope
.to_encrypted_data(
self.crypto(),
body.as_ref(),
&node_id_secret,
&self.unlocked_inner.network_key,
)
.to_encrypted_data(&crypto, body.as_ref(), &node_id_secret, &self.network_key)
.wrap_err("envelope failed to encode")
}
@ -925,7 +834,7 @@ impl NetworkManager {
destination_node_ref: Option<NodeRef>,
body: B,
) -> EyreResult<NetworkResult<SendDataMethod>> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return Ok(NetworkResult::no_connection_other("network is not started"));
};
@ -966,7 +875,7 @@ impl NetworkManager {
dial_info: DialInfo,
rcpt_data: Vec<u8>,
) -> EyreResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
log_net!(debug "not sending out-of-band receipt to {} because network is stopped", dial_info);
return Ok(());
};
@ -993,7 +902,7 @@ impl NetworkManager {
// and passes it to the RPC handler
#[instrument(level = "trace", target = "net", skip_all)]
async fn on_recv_envelope(&self, data: &mut [u8], flow: Flow) -> EyreResult<bool> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_context.startup_lock.enter() else {
return Ok(false);
};
@ -1043,21 +952,20 @@ impl NetworkManager {
}
// Decode envelope header (may fail signature validation)
let envelope =
match Envelope::from_signed_data(self.crypto(), data, &self.unlocked_inner.network_key)
{
Ok(v) => v,
Err(e) => {
log_net!(debug "envelope failed to decode: {}", e);
// safe to punish here because relays also check here to ensure they aren't forwarding things that don't decode
self.address_filter()
.punish_ip_addr(remote_addr, PunishmentReason::FailedToDecodeEnvelope);
return Ok(false);
}
};
let crypto = self.crypto();
let envelope = match Envelope::from_signed_data(&crypto, data, &self.network_key) {
Ok(v) => v,
Err(e) => {
log_net!(debug "envelope failed to decode: {}", e);
// safe to punish here because relays also check here to ensure they aren't forwarding things that don't decode
self.address_filter()
.punish_ip_addr(remote_addr, PunishmentReason::FailedToDecodeEnvelope);
return Ok(false);
}
};
// Get timestamp range
let (tsbehind, tsahead) = self.with_config(|c| {
let (tsbehind, tsahead) = self.config().with(|c| {
(
c.network
.rpc
@ -1136,7 +1044,10 @@ impl NetworkManager {
// which only performs a lightweight lookup before passing the packet back out
// If our node has the relay capability disabled, we should not be asked to relay
if self.with_config(|c| c.capabilities.disable.contains(&CAP_RELAY)) {
if self
.config()
.with(|c| c.capabilities.disable.contains(&CAP_RELAY))
{
log_net!(debug "node has relay capability disabled, dropping relayed envelope from {} to {}", sender_id, recipient_id);
return Ok(false);
}
@ -1191,12 +1102,8 @@ impl NetworkManager {
let node_id_secret = routing_table.node_id_secret_key(envelope.get_crypto_kind());
// Decrypt the envelope body
let body = match envelope.decrypt_body(
self.crypto(),
data,
&node_id_secret,
&self.unlocked_inner.network_key,
) {
let crypto = self.crypto();
let body = match envelope.decrypt_body(&crypto, data, &node_id_secret, &self.network_key) {
Ok(v) => v,
Err(e) => {
log_net!(debug "failed to decrypt envelope body: {}", e);

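The network-manager rewrite above threads a shared `NetworkManagerStartupContext` with a `startup_lock` through every entry point, so each public call can cheaply refuse work while the network is down. A hedged sketch of that guard idea using an `AtomicBool`; the real `StartupLock` is certainly richer than this (e.g. it also coordinates in-flight calls with shutdown).

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Stand-in for StartupLock: just tracks whether startup completed.
#[derive(Default)]
struct StartupLockSketch {
    started: AtomicBool,
}

impl StartupLockSketch {
    fn startup(&self) {
        self.started.store(true, Ordering::Release);
    }
    fn shutdown(&self) {
        self.started.store(false, Ordering::Release);
    }
    // Entry points call this first and bail out if not started.
    fn enter(&self) -> Result<(), &'static str> {
        if self.started.load(Ordering::Acquire) {
            Ok(())
        } else {
            Err("network is not started")
        }
    }
}

#[derive(Clone, Default)]
struct StartupContextSketch {
    startup_lock: Arc<StartupLockSketch>,
}

fn send_envelope(ctx: &StartupContextSketch, body: &str) -> Result<(), &'static str> {
    ctx.startup_lock.enter()?;
    println!("sending envelope: {body}");
    Ok(())
}

fn main() {
    let ctx = StartupContextSketch::default();
    assert!(send_envelope(&ctx, "hello").is_err());
    ctx.startup_lock.startup();
    assert!(send_envelope(&ctx, "hello").is_ok());
    ctx.startup_lock.shutdown();
    assert!(send_envelope(&ctx, "hello").is_err());
}
```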
View File

@ -2,6 +2,7 @@
/// Also performs UPNP/IGD mapping if enabled and possible
use super::*;
use futures_util::stream::FuturesUnordered;
use igd_manager::{IGDAddressType, IGDProtocolType};
const PORT_MAP_VALIDATE_TRY_COUNT: usize = 3;
const PORT_MAP_VALIDATE_DELAY_MS: u32 = 500;
@ -42,9 +43,7 @@ struct DiscoveryContextInner {
external_info: Vec<ExternalInfo>,
}
struct DiscoveryContextUnlockedInner {
routing_table: RoutingTable,
net: Network,
pub(super) struct DiscoveryContextUnlockedInner {
config: DiscoveryContextConfig,
// per-protocol
@ -53,25 +52,30 @@ struct DiscoveryContextUnlockedInner {
#[derive(Clone)]
pub(super) struct DiscoveryContext {
registry: VeilidComponentRegistry,
unlocked_inner: Arc<DiscoveryContextUnlockedInner>,
inner: Arc<Mutex<DiscoveryContextInner>>,
}
impl_veilid_component_registry_accessor!(DiscoveryContext);
impl core::ops::Deref for DiscoveryContext {
type Target = DiscoveryContextUnlockedInner;
fn deref(&self) -> &Self::Target {
&self.unlocked_inner
}
}
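// Stand-alone sketch of the handle shape used above (stand-in types, not the
// real DiscoveryContext): a cheaply cloned struct that keeps immutable state
// behind an Arc and exposes it through Deref, with mutable state in a Mutex.
// Field access like `self.config` then resolves through Deref, which is why
// the hunks below can drop the `self.unlocked_inner.` prefix.
use std::sync::{Arc, Mutex};

struct SketchUnlockedInner {
    port: u16, // immutable after construction
}

#[derive(Default)]
struct SketchInner {
    external_info: Vec<String>, // mutated at runtime, behind the Mutex
}

#[derive(Clone)]
struct SketchContext {
    unlocked_inner: Arc<SketchUnlockedInner>,
    inner: Arc<Mutex<SketchInner>>,
}

impl core::ops::Deref for SketchContext {
    type Target = SketchUnlockedInner;
    fn deref(&self) -> &Self::Target {
        &self.unlocked_inner
    }
}

impl SketchContext {
    fn new(port: u16) -> Self {
        Self {
            unlocked_inner: Arc::new(SketchUnlockedInner { port }),
            inner: Arc::new(Mutex::new(SketchInner::default())),
        }
    }

    fn port(&self) -> u16 {
        self.port // auto-derefs into SketchUnlockedInner
    }
}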
impl DiscoveryContext {
pub fn new(routing_table: RoutingTable, net: Network, config: DiscoveryContextConfig) -> Self {
let intf_addrs = Self::get_local_addresses(
routing_table.clone(),
config.protocol_type,
config.address_type,
);
pub fn new(registry: VeilidComponentRegistry, config: DiscoveryContextConfig) -> Self {
let routing_table = registry.routing_table();
let intf_addrs =
Self::get_local_addresses(&routing_table, config.protocol_type, config.address_type);
Self {
unlocked_inner: Arc::new(DiscoveryContextUnlockedInner {
routing_table,
net,
config,
intf_addrs,
}),
registry,
unlocked_inner: Arc::new(DiscoveryContextUnlockedInner { config, intf_addrs }),
inner: Arc::new(Mutex::new(DiscoveryContextInner {
external_info: Vec::new(),
})),
@ -84,7 +88,7 @@ impl DiscoveryContext {
// This pulls the already-detected local interface dial info from the routing table
#[instrument(level = "trace", skip(routing_table), ret)]
fn get_local_addresses(
routing_table: RoutingTable,
routing_table: &RoutingTable,
protocol_type: ProtocolType,
address_type: AddressType,
) -> Vec<SocketAddress> {
@ -108,7 +112,7 @@ impl DiscoveryContext {
// This is done over the normal port using RPC
#[instrument(level = "trace", skip(self), ret)]
async fn request_public_address(&self, node_ref: FilteredNodeRef) -> Option<SocketAddress> {
let rpc = self.unlocked_inner.routing_table.rpc_processor();
let rpc = self.rpc_processor();
let res = network_result_value_or_log!(match rpc.rpc_call_status(Destination::direct(node_ref.clone())).await {
Ok(v) => v,
@ -136,16 +140,14 @@ impl DiscoveryContext {
// This is done over the normal port using RPC
#[instrument(level = "trace", skip(self), ret)]
async fn discover_external_addresses(&self) -> bool {
let node_count = {
let config = self.unlocked_inner.routing_table.network_manager().config();
let c = config.get();
c.network.dht.max_find_node_count as usize
};
let node_count = self
.config()
.with(|c| c.network.dht.max_find_node_count as usize);
let routing_domain = RoutingDomain::PublicInternet;
let protocol_type = self.unlocked_inner.config.protocol_type;
let address_type = self.unlocked_inner.config.address_type;
let port = self.unlocked_inner.config.port;
let protocol_type = self.config.protocol_type;
let address_type = self.config.address_type;
let port = self.config.port;
// Build a filter that matches our protocol and address type
// and excludes relayed nodes so we can get an accurate external address
@ -187,10 +189,11 @@ impl DiscoveryContext {
]);
// Find public nodes matching this filter
let nodes = self
.unlocked_inner
.routing_table
.find_fast_non_local_nodes_filtered(routing_domain, node_count, filters);
let nodes = self.routing_table().find_fast_non_local_nodes_filtered(
routing_domain,
node_count,
filters,
);
if nodes.is_empty() {
log_net!(debug
"no external address detection peers of type {:?}:{:?}",
@ -212,8 +215,8 @@ impl DiscoveryContext {
async move {
if let Some(address) = this.request_public_address(node.clone()).await {
let dial_info = this
.unlocked_inner
.net
.network_manager()
.net()
.make_dial_info(address, protocol_type);
return Some(ExternalInfo {
dial_info,
@ -297,10 +300,9 @@ impl DiscoveryContext {
dial_info: DialInfo,
redirect: bool,
) -> bool {
let rpc_processor = self.unlocked_inner.routing_table.rpc_processor();
// ask the node to send us a dial info validation receipt
match rpc_processor
match self
.rpc_processor()
.rpc_call_validate_dial_info(node_ref.clone(), dial_info, redirect)
.await
{
@ -314,14 +316,22 @@ impl DiscoveryContext {
#[instrument(level = "trace", skip(self), ret)]
async fn try_upnp_port_mapping(&self) -> Option<DialInfo> {
let protocol_type = self.unlocked_inner.config.protocol_type;
let address_type = self.unlocked_inner.config.address_type;
let local_port = self.unlocked_inner.config.port;
let protocol_type = self.config.protocol_type;
let address_type = self.config.address_type;
let local_port = self.config.port;
let igd_protocol_type = match protocol_type.low_level_protocol_type() {
LowLevelProtocolType::UDP => IGDProtocolType::UDP,
LowLevelProtocolType::TCP => IGDProtocolType::TCP,
};
let igd_address_type = match address_type {
AddressType::IPV6 => IGDAddressType::IPV6,
AddressType::IPV4 => IGDAddressType::IPV4,
};
let low_level_protocol_type = protocol_type.low_level_protocol_type();
let external_1 = self.inner.lock().external_info.first().unwrap().clone();
let igd_manager = self.unlocked_inner.net.unlocked_inner.igd_manager.clone();
let igd_manager = self.network_manager().net().igd_manager.clone();
let mut tries = 0;
loop {
tries += 1;
@ -329,15 +339,15 @@ impl DiscoveryContext {
// Attempt a port mapping. If this doesn't succeed, port mapping isn't going to work at all
let mapped_external_address = igd_manager
.map_any_port(
low_level_protocol_type,
address_type,
igd_protocol_type,
igd_address_type,
local_port,
Some(external_1.address.ip_addr()),
)
.await?;
// Make dial info from the port mapping
let external_mapped_dial_info = self.unlocked_inner.net.make_dial_info(
let external_mapped_dial_info = self.network_manager().net().make_dial_info(
SocketAddress::from_socket_addr(mapped_external_address),
protocol_type,
);
@ -361,10 +371,7 @@ impl DiscoveryContext {
if validate_tries != PORT_MAP_VALIDATE_TRY_COUNT {
log_net!(debug "UPNP port mapping succeeded but port {}/{} is still unreachable.\nretrying\n",
local_port, match low_level_protocol_type {
LowLevelProtocolType::UDP => "udp",
LowLevelProtocolType::TCP => "tcp",
});
local_port, igd_protocol_type);
sleep(PORT_MAP_VALIDATE_DELAY_MS).await
} else {
break;
@ -374,18 +381,15 @@ impl DiscoveryContext {
// Release the mapping if we're still unreachable
let _ = igd_manager
.unmap_port(
low_level_protocol_type,
address_type,
igd_protocol_type,
igd_address_type,
external_1.address.port(),
)
.await;
if tries == PORT_MAP_TRY_COUNT {
warn!("UPNP port mapping succeeded but port {}/{} is still unreachable.\nYou may need to add a local firewall allowed port on this machine.\n",
local_port, match low_level_protocol_type {
LowLevelProtocolType::UDP => "udp",
LowLevelProtocolType::TCP => "tcp",
}
local_port, igd_protocol_type
);
break;
}
@ -413,7 +417,7 @@ impl DiscoveryContext {
{
// Add public dial info with Direct dialinfo class
Some(DetectionResult {
config: this.unlocked_inner.config,
config: this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: external_1.dial_info.clone(),
class: DialInfoClass::Direct,
@ -423,7 +427,7 @@ impl DiscoveryContext {
} else {
// Add public dial info with Blocked dialinfo class
Some(DetectionResult {
config: this.unlocked_inner.config,
config: this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: external_1.dial_info.clone(),
class: DialInfoClass::Blocked,
@ -445,7 +449,7 @@ impl DiscoveryContext {
let inner = self.inner.lock();
inner.external_info.clone()
};
let local_port = self.unlocked_inner.config.port;
let local_port = self.config.port;
// Get the external dial info histogram for our use here
let mut external_info_addr_port_hist = HashMap::<SocketAddress, usize>::new();
@ -501,7 +505,7 @@ impl DiscoveryContext {
let do_symmetric_nat_fut: SendPinBoxFuture<Option<DetectionResult>> =
Box::pin(async move {
Some(DetectionResult {
config: this.unlocked_inner.config,
config: this.config,
ddi: DetectedDialInfo::SymmetricNAT,
external_address_types,
})
@ -535,7 +539,7 @@ impl DiscoveryContext {
{
// Add public dial info with Direct dialinfo class
return Some(DetectionResult {
config: c_this.unlocked_inner.config,
config: c_this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: external_1_dial_info_with_local_port,
class: DialInfoClass::Direct,
@ -558,10 +562,7 @@ impl DiscoveryContext {
///////////
let this = self.clone();
let do_nat_detect_fut: SendPinBoxFuture<Option<DetectionResult>> = Box::pin(async move {
let mut retry_count = {
let c = this.unlocked_inner.net.config.get();
c.network.restricted_nat_retries
};
let mut retry_count = this.config().with(|c| c.network.restricted_nat_retries);
// Loop for restricted NAT retries
loop {
@ -585,7 +586,7 @@ impl DiscoveryContext {
// Add public dial info with full cone NAT network class
return Some(DetectionResult {
config: c_this.unlocked_inner.config,
config: c_this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: c_external_1.dial_info,
class: DialInfoClass::FullConeNAT,
@ -620,7 +621,7 @@ impl DiscoveryContext {
{
// Got a reply from a non-default port, which means we're only address restricted
return Some(DetectionResult {
config: c_this.unlocked_inner.config,
config: c_this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: c_external_1.dial_info.clone(),
class: DialInfoClass::AddressRestrictedNAT,
@ -632,7 +633,7 @@ impl DiscoveryContext {
}
// Didn't get a reply from a non-default port, which means we are also port restricted
Some(DetectionResult {
config: c_this.unlocked_inner.config,
config: c_this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: c_external_1.dial_info.clone(),
class: DialInfoClass::PortRestrictedNAT,
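// Approximate summary of the detection ladder implemented above, reduced to a
// pure function over already-gathered observations. Names and fields here are
// illustrative; the real code produces DetectionResult/DialInfoClass values and
// also handles the Blocked and relayed cases that this sketch omits.
#[derive(Debug, PartialEq, Eq)]
enum SketchNatClass {
    Direct,
    FullConeNat,
    AddressRestrictedNat,
    PortRestrictedNat,
    SymmetricNat,
}

struct SketchObservations {
    external_addrs_agree: bool,      // both probe nodes reported the same external addr:port
    directly_reachable: bool,        // inbound validation to our dial info succeeded as-is
    reachable_via_other_node: bool,  // a redirected probe from a different node got through
    reachable_from_other_port: bool, // the same node got through from a non-default port
}

fn classify_nat(o: &SketchObservations) -> SketchNatClass {
    if !o.external_addrs_agree {
        // Every destination sees a different external mapping
        return SketchNatClass::SymmetricNat;
    }
    if o.directly_reachable {
        return SketchNatClass::Direct;
    }
    if o.reachable_via_other_node {
        return SketchNatClass::FullConeNat;
    }
    if o.reachable_from_other_port {
        return SketchNatClass::AddressRestrictedNat;
    }
    SketchNatClass::PortRestrictedNat
}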
@ -678,10 +679,7 @@ impl DiscoveryContext {
&self,
unord: &mut FuturesUnordered<SendPinBoxFuture<Option<DetectionResult>>>,
) {
let enable_upnp = {
let c = self.unlocked_inner.net.config.get();
c.network.upnp
};
let enable_upnp = self.config().with(|c| c.network.upnp);
// Do this right away because it's fast and every detection is going to need it
// Get our external addresses from two fast nodes
@ -701,7 +699,7 @@ impl DiscoveryContext {
if let Some(external_mapped_dial_info) = this.try_upnp_port_mapping().await {
// Got a port mapping, let's use it
return Some(DetectionResult {
config: this.unlocked_inner.config,
config: this.config,
ddi: DetectedDialInfo::Detected(DialInfoDetail {
dial_info: external_mapped_dial_info.clone(),
class: DialInfoClass::Mapped,
@ -725,12 +723,7 @@ impl DiscoveryContext {
.lock()
.external_info
.iter()
.find_map(|ei| {
self.unlocked_inner
.intf_addrs
.contains(&ei.address)
.then_some(true)
})
.find_map(|ei| self.intf_addrs.contains(&ei.address).then_some(true))
.unwrap_or_default();
if local_address_in_external_info {

View File

@ -5,13 +5,12 @@ use std::net::UdpSocket;
const UPNP_GATEWAY_DETECT_TIMEOUT_MS: u32 = 5_000;
const UPNP_MAPPING_LIFETIME_MS: u32 = 120_000;
const UPNP_MAPPING_ATTEMPTS: u32 = 3;
const UPNP_MAPPING_LIFETIME_US: TimestampDuration =
TimestampDuration::new(UPNP_MAPPING_LIFETIME_MS as u64 * 1000u64);
const UPNP_MAPPING_LIFETIME_US: u64 = UPNP_MAPPING_LIFETIME_MS as u64 * 1000u64;
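// Units sketch for the constants above: this file now tracks time as raw u64
// microseconds instead of the Timestamp/TimestampDuration wrappers. The values
// below mirror the constants in this file for illustration only.
const SKETCH_LIFETIME_MS: u32 = 120_000;
const SKETCH_LIFETIME_US: u64 = SKETCH_LIFETIME_MS as u64 * 1000; // 120 s in microseconds
const SKETCH_LEASE_SECS: u32 = (SKETCH_LIFETIME_MS + 999) / 1000; // ceiling division, the lease asked of the gateway
const SKETCH_FIRST_RENEWAL_US: u64 = (SKETCH_LIFETIME_MS / 2) as u64 * 1000; // renew at half-life

fn _units_check() {
    assert_eq!(SKETCH_LIFETIME_US, 120_000_000);
    assert_eq!(SKETCH_LEASE_SECS, 120);
    assert_eq!(SKETCH_FIRST_RENEWAL_US, 60_000_000);
}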
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct PortMapKey {
llpt: LowLevelProtocolType,
at: AddressType,
protocol_type: IGDProtocolType,
address_type: IGDAddressType,
local_port: u16,
}
@ -19,36 +18,67 @@ struct PortMapKey {
struct PortMapValue {
ext_ip: IpAddr,
mapped_port: u16,
timestamp: Timestamp,
renewal_lifetime: TimestampDuration,
timestamp: u64,
renewal_lifetime: u64,
renewal_attempts: u32,
}
struct IGDManagerInner {
local_ip_addrs: BTreeMap<AddressType, IpAddr>,
local_ip_addrs: BTreeMap<IGDAddressType, IpAddr>,
gateways: BTreeMap<IpAddr, Arc<Gateway>>,
port_maps: BTreeMap<PortMapKey, PortMapValue>,
}
#[derive(Clone)]
pub struct IGDManager {
config: VeilidConfig,
program_name: String,
inner: Arc<Mutex<IGDManagerInner>>,
}
fn convert_llpt(llpt: LowLevelProtocolType) -> PortMappingProtocol {
match llpt {
LowLevelProtocolType::UDP => PortMappingProtocol::UDP,
LowLevelProtocolType::TCP => PortMappingProtocol::TCP,
fn convert_protocol_type(igdpt: IGDProtocolType) -> PortMappingProtocol {
match igdpt {
IGDProtocolType::UDP => PortMappingProtocol::UDP,
IGDProtocolType::TCP => PortMappingProtocol::TCP,
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum IGDAddressType {
IPV6,
IPV4,
}
impl fmt::Display for IGDAddressType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IGDAddressType::IPV6 => write!(f, "IPV6"),
IGDAddressType::IPV4 => write!(f, "IPV4"),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum IGDProtocolType {
UDP,
TCP,
}
impl fmt::Display for IGDProtocolType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IGDProtocolType::UDP => write!(f, "UDP"),
IGDProtocolType::TCP => write!(f, "TCP"),
}
}
}
impl IGDManager {
//
/////////////////////////////////////////////////////////////////////
// Public Interface
pub fn new(config: VeilidConfig) -> Self {
pub fn new(program_name: String) -> Self {
Self {
config,
program_name,
inner: Arc::new(Mutex::new(IGDManagerInner {
local_ip_addrs: BTreeMap::new(),
gateways: BTreeMap::new(),
@ -58,10 +88,306 @@ impl IGDManager {
}
#[instrument(level = "trace", target = "net", skip_all)]
fn get_routed_local_ip_address(address_type: AddressType) -> Option<IpAddr> {
pub async fn unmap_port(
&self,
protocol_type: IGDProtocolType,
address_type: IGDAddressType,
mapped_port: u16,
) -> Option<()> {
let this = self.clone();
blocking_wrapper(
"igd unmap_port",
move || {
let mut inner = this.inner.lock();
// Find the existing port mapping entry for this port so it can be removed
let mut found = None;
for (pmk, pmv) in &inner.port_maps {
if pmk.protocol_type == protocol_type
&& pmk.address_type == address_type
&& pmv.mapped_port == mapped_port
{
found = Some(*pmk);
break;
}
}
let pmk = found?;
let _pmv = inner
.port_maps
.remove(&pmk)
.expect("key found but remove failed");
// Get local ip address
let local_ip = Self::find_local_ip(&mut inner, address_type)?;
// Find gateway
let gw = Self::find_gateway(&mut inner, local_ip)?;
// Unmap port
match gw.remove_port(convert_protocol_type(protocol_type), mapped_port) {
Ok(()) => (),
Err(e) => {
// Failed to unmap external port
log_net!(debug "upnp failed to remove external port: {}", e);
return None;
}
};
Some(())
},
None,
)
.await
}
#[instrument(level = "trace", target = "net", skip_all)]
pub async fn map_any_port(
&self,
protocol_type: IGDProtocolType,
address_type: IGDAddressType,
local_port: u16,
expected_external_address: Option<IpAddr>,
) -> Option<SocketAddr> {
let this = self.clone();
blocking_wrapper("igd map_any_port", move || {
let mut inner = this.inner.lock();
// If we already have this port mapped, just return the existing portmap
let pmkey = PortMapKey {
protocol_type,
address_type,
local_port,
};
if let Some(pmval) = inner.port_maps.get(&pmkey) {
return Some(SocketAddr::new(pmval.ext_ip, pmval.mapped_port));
}
// Get local ip address
let local_ip = Self::find_local_ip(&mut inner, address_type)?;
// Find gateway
let gw = Self::find_gateway(&mut inner, local_ip)?;
// Get external address
let ext_ip = match gw.get_external_ip() {
Ok(ip) => ip,
Err(e) => {
log_net!(debug "couldn't get external ip from igd: {}", e);
return None;
}
};
// Ensure external IP matches address type
if ext_ip.is_ipv4() && address_type != IGDAddressType::IPV4 {
log_net!(debug "mismatched ip address type from igd, wanted v4, got v6");
return None;
} else if ext_ip.is_ipv6() && address_type != IGDAddressType::IPV6 {
log_net!(debug "mismatched ip address type from igd, wanted v6, got v4");
return None;
}
if let Some(expected_external_address) = expected_external_address {
if ext_ip != expected_external_address {
log_net!(debug "gateway external address does not match calculated external address: expected={} vs gateway={}", expected_external_address, ext_ip);
return None;
}
}
// Map any port
let desc = this.get_description(protocol_type, local_port);
let mapped_port = match gw.add_any_port(convert_protocol_type(protocol_type), SocketAddr::new(local_ip, local_port), (UPNP_MAPPING_LIFETIME_MS + 999) / 1000, &desc) {
Ok(mapped_port) => mapped_port,
Err(e) => {
// Failed to map external port
log_net!(debug "upnp failed to map external port: {}", e);
return None;
}
};
// Add to mapping list to keep alive
let timestamp = get_timestamp();
inner.port_maps.insert(PortMapKey {
protocol_type,
address_type,
local_port,
}, PortMapValue {
ext_ip,
mapped_port,
timestamp,
renewal_lifetime: (UPNP_MAPPING_LIFETIME_MS / 2) as u64 * 1000u64,
renewal_attempts: 0,
});
// Succeeded, return the externally mapped port
Some(SocketAddr::new(ext_ip, mapped_port))
}, None)
.await
}
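// The igd gateway calls above are synchronous, which is why map_any_port()
// runs them inside blocking_wrapper() instead of on the async executor. A
// rough equivalent of that shape, using tokio's spawn_blocking as a stand-in
// for the project's runtime-agnostic helper (illustrative only):
async fn map_port_off_executor(local_port: u16) -> Option<std::net::SocketAddr> {
    tokio::task::spawn_blocking(move || -> Option<std::net::SocketAddr> {
        // Gateway discovery and add_any_port() would run here, entirely off
        // the async worker threads; return the mapped external address.
        let _ = local_port;
        None
    })
    .await
    .unwrap_or(None) // treat a cancelled/panicked blocking task as "no mapping"
}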
#[instrument(
level = "trace",
target = "net",
name = "IGDManager::tick",
skip_all,
err
)]
pub async fn tick(&self) -> EyreResult<bool> {
// Refresh mappings if we have them
// If an error is received, then return false to restart the local network
let mut full_renews: Vec<(PortMapKey, PortMapValue)> = Vec::new();
let mut renews: Vec<(PortMapKey, PortMapValue)> = Vec::new();
{
let inner = self.inner.lock();
let now = get_timestamp();
for (k, v) in &inner.port_maps {
let mapping_lifetime = now.saturating_sub(v.timestamp);
if mapping_lifetime >= UPNP_MAPPING_LIFETIME_US
|| v.renewal_attempts >= UPNP_MAPPING_ATTEMPTS
{
// Past expiration time or tried N times, do a full renew and fail out if we can't
full_renews.push((*k, *v));
} else if mapping_lifetime >= v.renewal_lifetime {
// Attempt a normal renewal
renews.push((*k, *v));
}
}
// See if we need to do some blocking operations
if full_renews.is_empty() && renews.is_empty() {
// Just return now since there's nothing to renew
return Ok(true);
}
}
let this = self.clone();
blocking_wrapper(
"igd tick",
move || {
let mut inner = this.inner.lock();
// Process full renewals
for (k, v) in full_renews {
// Get local ip for address type
let local_ip = match Self::get_local_ip(&mut inner, k.address_type) {
Some(ip) => ip,
None => {
return Err(eyre!("local ip missing for address type"));
}
};
// Get gateway for interface
let gw = match Self::get_gateway(&mut inner, local_ip) {
Some(gw) => gw,
None => {
return Err(eyre!("gateway missing for interface"));
}
};
// Delete the mapping if it exists, ignore any errors here
let _ = gw.remove_port(convert_protocol_type(k.protocol_type), v.mapped_port);
inner.port_maps.remove(&k);
let desc = this.get_description(k.protocol_type, k.local_port);
match gw.add_any_port(
convert_protocol_type(k.protocol_type),
SocketAddr::new(local_ip, k.local_port),
(UPNP_MAPPING_LIFETIME_MS + 999) / 1000,
&desc,
) {
Ok(mapped_port) => {
log_net!(debug "full-renewed mapped port {:?} -> {:?}", v, k);
inner.port_maps.insert(
k,
PortMapValue {
ext_ip: v.ext_ip,
mapped_port,
timestamp: get_timestamp(),
renewal_lifetime: (UPNP_MAPPING_LIFETIME_MS / 2) as u64
* 1000u64,
renewal_attempts: 0,
},
);
}
Err(e) => {
info!("failed to full-renew mapped port {:?} -> {:?}: {}", v, k, e);
// Must restart network now :(
return Ok(false);
}
};
}
// Process normal renewals
for (k, mut v) in renews {
// Get local ip for address type
let local_ip = match Self::get_local_ip(&mut inner, k.address_type) {
Some(ip) => ip,
None => {
return Err(eyre!("local ip missing for address type"));
}
};
// Get gateway for interface
let gw = match Self::get_gateway(&mut inner, local_ip) {
Some(gw) => gw,
None => {
return Err(eyre!("gateway missing for address type"));
}
};
let desc = this.get_description(k.protocol_type, k.local_port);
match gw.add_port(
convert_protocol_type(k.protocol_type),
v.mapped_port,
SocketAddr::new(local_ip, k.local_port),
(UPNP_MAPPING_LIFETIME_MS + 999) / 1000,
&desc,
) {
Ok(()) => {
log_net!("renewed mapped port {:?} -> {:?}", v, k);
inner.port_maps.insert(
k,
PortMapValue {
ext_ip: v.ext_ip,
mapped_port: v.mapped_port,
timestamp: get_timestamp(),
renewal_lifetime: (UPNP_MAPPING_LIFETIME_MS / 2) as u64
* 1000u64,
renewal_attempts: 0,
},
);
}
Err(e) => {
log_net!(debug "failed to renew mapped port {:?} -> {:?}: {}", v, k, e);
// Move the renewal deadline halfway closer to the maximum mapping lifetime each time
v.renewal_lifetime =
(v.renewal_lifetime + UPNP_MAPPING_LIFETIME_US) / 2u64;
v.renewal_attempts += 1;
// Store new value to try again
inner.port_maps.insert(k, v);
}
};
}
// Normal exit, no restart
Ok(true)
},
Err(eyre!("failed to process blocking task")),
)
.instrument(tracing::trace_span!("igd tick fut"))
.await
}
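// Worked example of the renewal schedule tick() implements above, using the
// constants from this file (the helper is illustrative, not part of the API).
// A mapping starts with its renewal deadline at half of the 120 s lease; each
// failed renewal moves the deadline halfway toward the full lease and bumps
// renewal_attempts, and after UPNP_MAPPING_ATTEMPTS failures (or once the
// lease itself has elapsed) the mapping falls through to a full re-map.
fn next_renewal_lifetime_us(current_us: u64, lifetime_us: u64) -> u64 {
    (current_us + lifetime_us) / 2
}

fn _renewal_schedule_example() {
    let lifetime_us = 120_000_000u64; // UPNP_MAPPING_LIFETIME_US
    let mut renewal_us = 60_000_000u64; // initial half-life deadline
    renewal_us = next_renewal_lifetime_us(renewal_us, lifetime_us);
    assert_eq!(renewal_us, 90_000_000); // after one failed renewal
    renewal_us = next_renewal_lifetime_us(renewal_us, lifetime_us);
    assert_eq!(renewal_us, 105_000_000); // after two
    renewal_us = next_renewal_lifetime_us(renewal_us, lifetime_us);
    assert_eq!(renewal_us, 112_500_000); // after three, a full renew follows
}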
/////////////////////////////////////////////////////////////////////
// Private Implementation
#[instrument(level = "trace", target = "net", skip_all)]
fn get_routed_local_ip_address(address_type: IGDAddressType) -> Option<IpAddr> {
let socket = match UdpSocket::bind(match address_type {
AddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
AddressType::IPV6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
IGDAddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
IGDAddressType::IPV6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
}) {
Ok(s) => s,
Err(e) => {
@ -75,8 +401,8 @@ impl IGDManager {
// using Google's DNS, but it won't actually send any packets to it
socket
.connect(match address_type {
AddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 80),
AddressType::IPV6 => SocketAddr::new(
IGDAddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 80),
IGDAddressType::IPV6 => SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
80,
),
@ -91,7 +417,7 @@ impl IGDManager {
}
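// Stand-alone sketch of the routed-local-IP trick used by
// get_routed_local_ip_address() above: "connecting" a UDP socket sends no
// packets, it only asks the kernel which route (and therefore which local
// source address) it would use toward the destination.
use std::net::UdpSocket as StdUdpSocket;

fn routed_local_ipv4() -> std::io::Result<std::net::IpAddr> {
    let socket = StdUdpSocket::bind("0.0.0.0:0")?;
    // Any globally routed destination works; 8.8.8.8:80 is just a well-known one.
    socket.connect("8.8.8.8:80")?;
    Ok(socket.local_addr()?.ip())
}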
#[instrument(level = "trace", target = "net", skip_all)]
fn find_local_ip(inner: &mut IGDManagerInner, address_type: AddressType) -> Option<IpAddr> {
fn find_local_ip(inner: &mut IGDManagerInner, address_type: IGDAddressType) -> Option<IpAddr> {
if let Some(ip) = inner.local_ip_addrs.get(&address_type) {
return Some(*ip);
}
@ -109,7 +435,7 @@ impl IGDManager {
}
#[instrument(level = "trace", target = "net", skip_all)]
fn get_local_ip(inner: &mut IGDManagerInner, address_type: AddressType) -> Option<IpAddr> {
fn get_local_ip(inner: &mut IGDManagerInner, address_type: IGDAddressType) -> Option<IpAddr> {
if let Some(ip) = inner.local_ip_addrs.get(&address_type) {
return Some(*ip);
}
@ -164,304 +490,10 @@ impl IGDManager {
None
}
fn get_description(&self, llpt: LowLevelProtocolType, local_port: u16) -> String {
fn get_description(&self, protocol_type: IGDProtocolType, local_port: u16) -> String {
format!(
"{} map {} for port {}",
self.config.get().program_name,
convert_llpt(llpt),
local_port
self.program_name, protocol_type, local_port
)
}
#[instrument(level = "trace", target = "net", skip_all)]
pub async fn unmap_port(
&self,
llpt: LowLevelProtocolType,
at: AddressType,
mapped_port: u16,
) -> Option<()> {
let this = self.clone();
blocking_wrapper(
"igd unmap_port",
move || {
let mut inner = this.inner.lock();
// Find the existing port mapping entry for this port so it can be removed
let mut found = None;
for (pmk, pmv) in &inner.port_maps {
if pmk.llpt == llpt && pmk.at == at && pmv.mapped_port == mapped_port {
found = Some(*pmk);
break;
}
}
let pmk = found?;
let _pmv = inner
.port_maps
.remove(&pmk)
.expect("key found but remove failed");
// Get local ip address
let local_ip = Self::find_local_ip(&mut inner, at)?;
// Find gateway
let gw = Self::find_gateway(&mut inner, local_ip)?;
// Unmap port
match gw.remove_port(convert_llpt(llpt), mapped_port) {
Ok(()) => (),
Err(e) => {
// Failed to unmap external port
log_net!(debug "upnp failed to remove external port: {}", e);
return None;
}
};
Some(())
},
None,
)
.await
}
#[instrument(level = "trace", target = "net", skip_all)]
pub async fn map_any_port(
&self,
llpt: LowLevelProtocolType,
at: AddressType,
local_port: u16,
expected_external_address: Option<IpAddr>,
) -> Option<SocketAddr> {
let this = self.clone();
blocking_wrapper("igd map_any_port", move || {
let mut inner = this.inner.lock();
// If we already have this port mapped, just return the existing portmap
let pmkey = PortMapKey {
llpt,
at,
local_port,
};
if let Some(pmval) = inner.port_maps.get(&pmkey) {
return Some(SocketAddr::new(pmval.ext_ip, pmval.mapped_port));
}
// Get local ip address
let local_ip = Self::find_local_ip(&mut inner, at)?;
// Find gateway
let gw = Self::find_gateway(&mut inner, local_ip)?;
// Get external address
let ext_ip = match gw.get_external_ip() {
Ok(ip) => ip,
Err(e) => {
log_net!(debug "couldn't get external ip from igd: {}", e);
return None;
}
};
// Ensure external IP matches address type
if ext_ip.is_ipv4() && at != AddressType::IPV4 {
log_net!(debug "mismatched ip address type from igd, wanted v4, got v6");
return None;
} else if ext_ip.is_ipv6() && at != AddressType::IPV6 {
log_net!(debug "mismatched ip address type from igd, wanted v6, got v4");
return None;
}
if let Some(expected_external_address) = expected_external_address {
if ext_ip != expected_external_address {
log_net!(debug "gateway external address does not match calculated external address: expected={} vs gateway={}", expected_external_address, ext_ip);
return None;
}
}
// Map any port
let desc = this.get_description(llpt, local_port);
let mapped_port = match gw.add_any_port(convert_llpt(llpt), SocketAddr::new(local_ip, local_port), (UPNP_MAPPING_LIFETIME_MS + 999) / 1000, &desc) {
Ok(mapped_port) => mapped_port,
Err(e) => {
// Failed to map external port
log_net!(debug "upnp failed to map external port: {}", e);
return None;
}
};
// Add to mapping list to keep alive
let timestamp = Timestamp::now();
inner.port_maps.insert(PortMapKey {
llpt,
at,
local_port,
}, PortMapValue {
ext_ip,
mapped_port,
timestamp,
renewal_lifetime: ((UPNP_MAPPING_LIFETIME_MS / 2) as u64 * 1000u64).into(),
renewal_attempts: 0,
});
// Succeeded, return the externally mapped port
Some(SocketAddr::new(ext_ip, mapped_port))
}, None)
.await
}
#[instrument(
level = "trace",
target = "net",
name = "IGDManager::tick",
skip_all,
err
)]
pub async fn tick(&self) -> EyreResult<bool> {
// Refresh mappings if we have them
// If an error is received, then return false to restart the local network
let mut full_renews: Vec<(PortMapKey, PortMapValue)> = Vec::new();
let mut renews: Vec<(PortMapKey, PortMapValue)> = Vec::new();
{
let inner = self.inner.lock();
let now = Timestamp::now();
for (k, v) in &inner.port_maps {
let mapping_lifetime = now.saturating_sub(v.timestamp);
if mapping_lifetime >= UPNP_MAPPING_LIFETIME_US
|| v.renewal_attempts >= UPNP_MAPPING_ATTEMPTS
{
// Past expiration time or tried N times, do a full renew and fail out if we can't
full_renews.push((*k, *v));
} else if mapping_lifetime >= v.renewal_lifetime {
// Attempt a normal renewal
renews.push((*k, *v));
}
}
// See if we need to do some blocking operations
if full_renews.is_empty() && renews.is_empty() {
// Just return now since there's nothing to renew
return Ok(true);
}
}
let this = self.clone();
blocking_wrapper(
"igd tick",
move || {
let mut inner = this.inner.lock();
// Process full renewals
for (k, v) in full_renews {
// Get local ip for address type
let local_ip = match Self::get_local_ip(&mut inner, k.at) {
Some(ip) => ip,
None => {
return Err(eyre!("local ip missing for address type"));
}
};
// Get gateway for interface
let gw = match Self::get_gateway(&mut inner, local_ip) {
Some(gw) => gw,
None => {
return Err(eyre!("gateway missing for interface"));
}
};
// Delete the mapping if it exists, ignore any errors here
let _ = gw.remove_port(convert_llpt(k.llpt), v.mapped_port);
inner.port_maps.remove(&k);
let desc = this.get_description(k.llpt, k.local_port);
match gw.add_any_port(
convert_llpt(k.llpt),
SocketAddr::new(local_ip, k.local_port),
(UPNP_MAPPING_LIFETIME_MS + 999) / 1000,
&desc,
) {
Ok(mapped_port) => {
log_net!(debug "full-renewed mapped port {:?} -> {:?}", v, k);
inner.port_maps.insert(
k,
PortMapValue {
ext_ip: v.ext_ip,
mapped_port,
timestamp: Timestamp::now(),
renewal_lifetime: TimestampDuration::new(
(UPNP_MAPPING_LIFETIME_MS / 2) as u64 * 1000u64,
),
renewal_attempts: 0,
},
);
}
Err(e) => {
info!("failed to full-renew mapped port {:?} -> {:?}: {}", v, k, e);
// Must restart network now :(
return Ok(false);
}
};
}
// Process normal renewals
for (k, mut v) in renews {
// Get local ip for address type
let local_ip = match Self::get_local_ip(&mut inner, k.at) {
Some(ip) => ip,
None => {
return Err(eyre!("local ip missing for address type"));
}
};
// Get gateway for interface
let gw = match Self::get_gateway(&mut inner, local_ip) {
Some(gw) => gw,
None => {
return Err(eyre!("gateway missing for address type"));
}
};
let desc = this.get_description(k.llpt, k.local_port);
match gw.add_port(
convert_llpt(k.llpt),
v.mapped_port,
SocketAddr::new(local_ip, k.local_port),
(UPNP_MAPPING_LIFETIME_MS + 999) / 1000,
&desc,
) {
Ok(()) => {
log_net!("renewed mapped port {:?} -> {:?}", v, k);
inner.port_maps.insert(
k,
PortMapValue {
ext_ip: v.ext_ip,
mapped_port: v.mapped_port,
timestamp: Timestamp::now(),
renewal_lifetime: ((UPNP_MAPPING_LIFETIME_MS / 2) as u64
* 1000u64)
.into(),
renewal_attempts: 0,
},
);
}
Err(e) => {
log_net!(debug "failed to renew mapped port {:?} -> {:?}: {}", v, k, e);
// Move the renewal deadline halfway closer to the maximum mapping lifetime each time
v.renewal_lifetime =
(v.renewal_lifetime + UPNP_MAPPING_LIFETIME_US) / 2u64;
v.renewal_attempts += 1;
// Store new value to try again
inner.port_maps.insert(k, v);
}
};
}
// Normal exit, no restart
Ok(true)
},
Err(eyre!("failed to process blocking task")),
)
.instrument(tracing::trace_span!("igd tick fut"))
.await
}
}

View File

@ -113,16 +113,13 @@ struct NetworkInner {
network_state: Option<NetworkState>,
}
struct NetworkUnlockedInner {
pub(super) struct NetworkUnlockedInner {
// Startup lock
startup_lock: StartupLock,
// Accessors
routing_table: RoutingTable,
network_manager: NetworkManager,
connection_manager: ConnectionManager,
// Network
interfaces: NetworkInterfaces,
// Background processes
update_network_class_task: TickTask<EyreReport>,
network_interfaces_task: TickTask<EyreReport>,
@ -135,11 +132,21 @@ struct NetworkUnlockedInner {
#[derive(Clone)]
pub(super) struct Network {
config: VeilidConfig,
registry: VeilidComponentRegistry,
inner: Arc<Mutex<NetworkInner>>,
unlocked_inner: Arc<NetworkUnlockedInner>,
}
impl_veilid_component_registry_accessor!(Network);
impl core::ops::Deref for Network {
type Target = NetworkUnlockedInner;
fn deref(&self) -> &Self::Target {
&self.unlocked_inner
}
}
impl Network {
fn new_inner() -> NetworkInner {
NetworkInner {
@ -161,17 +168,11 @@ impl Network {
}
}
fn new_unlocked_inner(
network_manager: NetworkManager,
routing_table: RoutingTable,
connection_manager: ConnectionManager,
) -> NetworkUnlockedInner {
let config = network_manager.config();
fn new_unlocked_inner(registry: VeilidComponentRegistry) -> NetworkUnlockedInner {
let config = registry.config();
let program_name = config.get().program_name.clone();
NetworkUnlockedInner {
startup_lock: StartupLock::new(),
network_manager,
routing_table,
connection_manager,
interfaces: NetworkInterfaces::new(),
update_network_class_task: TickTask::new(
"update_network_class_task",
@ -183,23 +184,15 @@ impl Network {
),
upnp_task: TickTask::new("upnp_task", UPNP_TASK_TICK_PERIOD_SECS),
network_task_lock: AsyncMutex::new(()),
igd_manager: igd_manager::IGDManager::new(config.clone()),
igd_manager: igd_manager::IGDManager::new(program_name),
}
}
pub fn new(
network_manager: NetworkManager,
routing_table: RoutingTable,
connection_manager: ConnectionManager,
) -> Self {
pub fn new(registry: VeilidComponentRegistry) -> Self {
let this = Self {
config: network_manager.config(),
inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner(
network_manager,
routing_table,
connection_manager,
)),
unlocked_inner: Arc::new(Self::new_unlocked_inner(registry.clone())),
registry,
};
this.setup_tasks();
@ -207,18 +200,6 @@ impl Network {
this
}
fn network_manager(&self) -> NetworkManager {
self.unlocked_inner.network_manager.clone()
}
fn routing_table(&self) -> RoutingTable {
self.unlocked_inner.routing_table.clone()
}
fn connection_manager(&self) -> ConnectionManager {
self.unlocked_inner.connection_manager.clone()
}
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
let cvec = certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?;
@ -248,7 +229,8 @@ impl Network {
}
fn load_server_config(&self) -> io::Result<ServerConfig> {
let c = self.config.get();
let config = self.config();
let c = config.get();
//
log_net!(
"loading certificate from {}",
@ -288,7 +270,10 @@ impl Network {
if !from.ip().is_unspecified() {
vec![from]
} else {
let addrs = self.last_network_state().stable_interface_addresses;
let addrs = self
.last_network_state()
.unwrap()
.stable_interface_addresses;
addrs
.iter()
.filter_map(|a| {
@ -346,16 +331,15 @@ impl Network {
dial_info: DialInfo,
data: Vec<u8>,
) -> EyreResult<NetworkResult<()>> {
let _guard = self.unlocked_inner.startup_lock.enter()?;
let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure(
dial_info.clone(),
async move {
let data_len = data.len();
let connect_timeout_ms = {
let c = self.config.get();
c.network.connection_initial_timeout_ms
};
let connect_timeout_ms = self
.config()
.with(|c| c.network.connection_initial_timeout_ms);
if self
.network_manager()
@ -368,10 +352,12 @@ impl Network {
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
let h =
RawUdpProtocolHandler::new_unspecified_bound_handler(&peer_socket_addr)
.await
.wrap_err("create socket failure")?;
let h = RawUdpProtocolHandler::new_unspecified_bound_handler(
self.registry(),
&peer_socket_addr,
)
.await
.wrap_err("create socket failure")?;
let _ = network_result_try!(h
.send_message(data, peer_socket_addr)
.await
@ -423,16 +409,15 @@ impl Network {
data: Vec<u8>,
timeout_ms: u32,
) -> EyreResult<NetworkResult<Vec<u8>>> {
let _guard = self.unlocked_inner.startup_lock.enter()?;
let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure(
dial_info.clone(),
async move {
let data_len = data.len();
let connect_timeout_ms = {
let c = self.config.get();
c.network.connection_initial_timeout_ms
};
let connect_timeout_ms = self
.config()
.with(|c| c.network.connection_initial_timeout_ms);
if self
.network_manager()
@ -445,10 +430,12 @@ impl Network {
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
let h =
RawUdpProtocolHandler::new_unspecified_bound_handler(&peer_socket_addr)
.await
.wrap_err("create socket failure")?;
let h = RawUdpProtocolHandler::new_unspecified_bound_handler(
self.registry(),
&peer_socket_addr,
)
.await
.wrap_err("create socket failure")?;
network_result_try!(h
.send_message(data, peer_socket_addr)
.await
@ -539,7 +526,7 @@ impl Network {
flow: Flow,
data: Vec<u8>,
) -> EyreResult<SendDataToExistingFlowResult> {
let _guard = self.unlocked_inner.startup_lock.enter()?;
let _guard = self.startup_lock.enter()?;
let data_len = data.len();
@ -573,7 +560,11 @@ impl Network {
// Handle connection-oriented protocols
// Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(flow) {
if let Some(conn) = self
.network_manager()
.connection_manager()
.get_connection(flow)
{
// connection exists, send over it
match conn.send_async(data).await {
ConnectionHandleSendResult::Sent => {
@ -606,7 +597,7 @@ impl Network {
dial_info: DialInfo,
data: Vec<u8>,
) -> EyreResult<NetworkResult<UniqueFlow>> {
let _guard = self.unlocked_inner.startup_lock.enter()?;
let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure(
dial_info.clone(),
@ -635,7 +626,8 @@ impl Network {
} else {
// Handle connection-oriented protocols
let conn = network_result_try!(
self.connection_manager()
self.network_manager()
.connection_manager()
.get_or_create_connection(dial_info.clone())
.await?
);
@ -678,14 +670,9 @@ impl Network {
}
// Start editing routing table
let mut editor_public_internet = self
.unlocked_inner
.routing_table
.edit_public_internet_routing_domain();
let mut editor_local_network = self
.unlocked_inner
.routing_table
.edit_local_network_routing_domain();
let routing_table = self.routing_table();
let mut editor_public_internet = routing_table.edit_public_internet_routing_domain();
let mut editor_local_network = routing_table.edit_local_network_routing_domain();
// Setup network
editor_local_network.set_local_networks(network_state.local_networks);
@ -763,8 +750,8 @@ impl Network {
#[instrument(level = "debug", err, skip_all)]
pub(super) async fn register_all_dial_info(
&self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet,
editor_local_network: &mut RoutingDomainEditorLocalNetwork,
editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> {
let Some(protocol_config) = ({
let inner = self.inner.lock();
@ -798,7 +785,7 @@ impl Network {
#[instrument(level = "debug", err, skip_all)]
pub async fn startup(&self) -> EyreResult<StartupDisposition> {
let guard = self.unlocked_inner.startup_lock.startup()?;
let guard = self.startup_lock.startup()?;
match self.startup_internal().await {
Ok(StartupDisposition::Success) => {
@ -824,7 +811,7 @@ impl Network {
}
pub fn is_started(&self) -> bool {
self.unlocked_inner.startup_lock.is_started()
self.startup_lock.is_started()
}
#[instrument(level = "debug", skip_all)]
@ -836,12 +823,6 @@ impl Network {
async fn shutdown_internal(&self) {
let routing_table = self.routing_table();
// Stop all tasks
log_net!(debug "stopping update network class task");
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
error!("update_network_class_task not cancelled: {}", e);
}
let mut unord = FuturesUnordered::new();
{
let mut inner = self.inner.lock();
@ -876,7 +857,7 @@ impl Network {
#[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) {
log_net!(debug "starting low level network shutdown");
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else {
let Ok(guard) = self.startup_lock.shutdown().await else {
log_net!(debug "low level network is already shut down");
return;
};
@ -892,7 +873,7 @@ impl Network {
&self,
punishment: Option<Box<dyn FnOnce() + Send + 'static>>,
) {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up");
return;
};
@ -902,7 +883,7 @@ impl Network {
}
pub fn needs_public_dial_info_check(&self) -> bool {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up");
return false;
};

View File

@ -28,7 +28,7 @@ pub(super) struct NetworkState {
impl Network {
fn make_stable_interface_addresses(&self) -> Vec<IpAddr> {
let addrs = self.unlocked_inner.interfaces.stable_addresses();
let addrs = self.interfaces.stable_addresses();
let mut addrs: Vec<IpAddr> = addrs
.into_iter()
.filter(|addr| {
@ -41,8 +41,8 @@ impl Network {
addrs
}
pub(super) fn last_network_state(&self) -> NetworkState {
self.inner.lock().network_state.clone().unwrap()
pub(super) fn last_network_state(&self) -> Option<NetworkState> {
self.inner.lock().network_state.clone()
}
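// last_network_state() now returns Option because the state only exists after
// the first make_network_state() pass; the caller shown earlier unwraps it. A
// minimal non-panicking read could look like this (stand-in type, not the real
// NetworkState):
use std::net::IpAddr;

#[derive(Clone, Default)]
struct SketchNetworkState {
    stable_interface_addresses: Vec<IpAddr>,
}

fn stable_addresses(last_state: Option<SketchNetworkState>) -> Vec<IpAddr> {
    last_state
        .map(|s| s.stable_interface_addresses)
        .unwrap_or_default() // empty until the state has been computed once
}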
pub(super) fn is_stable_interface_address(&self, addr: IpAddr) -> bool {
@ -57,8 +57,7 @@ impl Network {
pub(super) async fn make_network_state(&self) -> EyreResult<NetworkState> {
// refresh network interfaces
self.unlocked_inner
.interfaces
self.interfaces
.refresh()
.await
.wrap_err("failed to refresh network interfaces")?;
@ -66,22 +65,20 @@ impl Network {
// build the set of networks we should consider for the 'LocalNetwork' routing domain
let mut local_networks: HashSet<(IpAddr, IpAddr)> = HashSet::new();
self.unlocked_inner
.interfaces
.with_interfaces(|interfaces| {
for intf in interfaces.values() {
// Skip networks that we should never encounter
if intf.is_loopback() || !intf.is_running() {
continue;
}
// Add network to local networks table
for addr in &intf.addrs {
let netmask = addr.if_addr().netmask();
let network_ip = ipaddr_apply_netmask(addr.if_addr().ip(), netmask);
local_networks.insert((network_ip, netmask));
}
self.interfaces.with_interfaces(|interfaces| {
for intf in interfaces.values() {
// Skip networks that we should never encounter
if intf.is_loopback() || !intf.is_running() {
continue;
}
});
// Add network to local networks table
for addr in &intf.addrs {
let netmask = addr.if_addr().netmask();
let network_ip = ipaddr_apply_netmask(addr.if_addr().ip(), netmask);
local_networks.insert((network_ip, netmask));
}
}
});
let mut local_networks: Vec<(IpAddr, IpAddr)> = local_networks.into_iter().collect();
local_networks.sort();
@ -107,7 +104,8 @@ impl Network {
// Get protocol config
let protocol_config = {
let c = self.config.get();
let config = self.config();
let c = config.get();
let mut inbound = ProtocolTypeSet::new();
if c.network.protocol.udp.enabled {

View File

@ -1,6 +1,5 @@
use super::*;
use async_tls::TlsAcceptor;
use sockets::*;
use stop_token::future::FutureExt;
/////////////////////////////////////////////////////////////////
@ -122,8 +121,11 @@ impl Network {
}
};
// Check to see if it is punished
let address_filter = self.network_manager().address_filter();
if address_filter.is_ip_addr_punished(peer_addr.ip()) {
if self
.network_manager()
.address_filter()
.is_ip_addr_punished(peer_addr.ip())
{
return;
}
@ -135,39 +137,12 @@ impl Network {
}
};
#[cfg(all(feature = "rt-async-std", unix))]
if let Err(e) = set_tcp_stream_linger(&tcp_stream, Some(core::time::Duration::from_secs(0)))
{
// async-std does not directly support linger on TcpStream yet
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
if let Err(e) = unsafe {
let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
s.into_raw_fd();
res
} {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
}
#[cfg(all(feature = "rt-async-std", windows))]
{
// async-std does not directly support linger on TcpStream yet
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
if let Err(e) = unsafe {
let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
s.into_raw_socket();
res
} {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
}
#[cfg(not(feature = "rt-async-std"))]
if let Err(e) = tcp_stream.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
if let Err(e) = tcp_stream.set_nodelay(true) {
log_net!(debug "Couldn't set TCP nodelay: {}", e);
return;
@ -249,49 +224,19 @@ impl Network {
#[instrument(level = "trace", skip_all)]
async fn spawn_socket_listener(&self, addr: SocketAddr) -> EyreResult<bool> {
// Get config
let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) = {
let c = self.config.get();
(
c.network.connection_initial_timeout_ms,
c.network.tls.connection_initial_timeout_ms,
)
};
// Create a socket and bind it
let Some(socket) = new_bound_default_tcp_socket(addr)
.wrap_err("failed to create default socket listener")?
else {
return Ok(false);
};
// Drop the socket
drop(socket);
let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) =
self.config().with(|c| {
(
c.network.connection_initial_timeout_ms,
c.network.tls.connection_initial_timeout_ms,
)
});
// Create a shared socket and bind it once we have determined the port is free
let Some(socket) = new_bound_shared_tcp_socket(addr)
.wrap_err("failed to create shared socket listener")?
else {
let Some(listener) = bind_async_tcp_listener(addr)? else {
return Ok(false);
};
// Listen on the socket
if socket.listen(128).is_err() {
return Ok(false);
}
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let listener = TcpListener::from(std_listener);
} else if #[cfg(feature="rt-tokio")] {
std_listener.set_nonblocking(true).expect("failed to set nonblocking");
let listener = TcpListener::from_std(std_listener).wrap_err("failed to create tokio tcp listener")?;
} else {
compile_error!("needs executor implementation");
}
}
log_net!(debug "spawn_socket_listener: binding successful to {}", addr);
// Create protocol handler records
@ -304,22 +249,14 @@ impl Network {
// Spawn the socket task
let this = self.clone();
let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token();
let connection_manager = self.connection_manager();
let connection_manager = self.network_manager().connection_manager();
////////////////////////////////////////////////////////////
let jh = spawn(&format!("TCP listener {}", addr), async move {
// moves the listener object in and gets the incoming iterator
// when this task exits, the listener will close the socket
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let incoming_stream = listener.incoming();
} else if #[cfg(feature="rt-tokio")] {
let incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener);
} else {
compile_error!("needs executor implementation");
}
}
let incoming_stream = async_tcp_listener_incoming(listener);
let _ = incoming_stream
.for_each_concurrent(None, |tcp_stream| {

View File

@ -1,15 +1,13 @@
use super::*;
use sockets::*;
use stop_token::future::FutureExt;
impl Network {
#[instrument(level = "trace", skip_all)]
pub(super) async fn create_udp_listener_tasks(&self) -> EyreResult<()> {
// Spawn socket tasks
let mut task_count = {
let c = self.config.get();
c.network.protocol.udp.socket_pool_size
};
let mut task_count = self
.config()
.with(|c| c.network.protocol.udp.socket_pool_size);
if task_count == 0 {
task_count = get_concurrency() / 2;
if task_count == 0 {
@ -38,7 +36,6 @@ impl Network {
// Spawn a local async task for each socket
let mut protocol_handlers_unordered = FuturesUnordered::new();
let network_manager = this.network_manager();
let stop_token = {
let inner = this.inner.lock();
if inner.stop_source.is_none() {
@ -49,7 +46,7 @@ impl Network {
};
for ph in protocol_handlers {
let network_manager = network_manager.clone();
let network_manager = this.network_manager();
let stop_token = stop_token.clone();
let ph_future = async move {
let mut data = vec![0u8; 65536];
@ -114,28 +111,14 @@ impl Network {
async fn create_udp_protocol_handler(&self, addr: SocketAddr) -> EyreResult<bool> {
log_net!(debug "create_udp_protocol_handler on {:?}", &addr);
// Create a reusable socket
let Some(socket) = new_bound_default_udp_socket(addr)? else {
// Create a single-address-family UDP socket with default options bound to an address
let Some(udp_socket) = bind_async_udp_socket(addr)? else {
return Ok(false);
};
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let udp_socket = UdpSocket::from(std_udp_socket);
} else if #[cfg(feature="rt-tokio")] {
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make inbound tokio udpsocket")?;
} else {
compile_error!("needs executor implementation");
}
}
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let protocol_handler =
RawUdpProtocolHandler::new(socket_arc, Some(self.network_manager().address_filter()));
let protocol_handler = RawUdpProtocolHandler::new(self.registry(), socket_arc);
// Record protocol handler
let mut inner = self.inner.lock();

View File

@ -1,4 +1,3 @@
pub mod sockets;
pub mod tcp;
pub mod udp;
pub mod wrtc;
@ -22,7 +21,7 @@ impl ProtocolNetworkConnection {
local_address: Option<SocketAddr>,
dial_info: &DialInfo,
timeout_ms: u32,
address_filter: AddressFilter,
address_filter: &AddressFilter,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
if address_filter.is_ip_addr_punished(dial_info.address().ip_addr()) {
return Ok(NetworkResult::no_connection_other("punished"));

View File

@ -1,191 +0,0 @@
use crate::*;
use async_io::Async;
use std::io;
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
} else if #[cfg(feature="rt-tokio")] {
pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
pub use tokio_util::compat::*;
} else {
compile_error!("needs executor implementation");
}
}
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
// cfg_if! {
// if #[cfg(windows)] {
// use winapi::shared::ws2def::{ SOL_SOCKET, SO_EXCLUSIVEADDRUSE};
// use winapi::um::winsock2::{SOCKET_ERROR, setsockopt};
// use winapi::ctypes::c_int;
// use std::os::windows::io::AsRawSocket;
// fn set_exclusiveaddruse(socket: &Socket) -> io::Result<()> {
// unsafe {
// let optval:c_int = 1;
// if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
// std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
// return Err(io::Error::last_os_error());
// }
// Ok(())
// }
// }
// }
// }
#[instrument(level = "trace", ret)]
pub fn new_shared_udp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_default_udp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_bound_default_udp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_default_udp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound default udp socket on {:?}", &local_address);
Ok(Some(socket))
}
#[instrument(level = "trace", ret)]
pub fn new_default_tcp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_shared_tcp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_bound_default_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_default_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound default tcp socket on {:?}", &local_address);
Ok(Some(socket))
}
#[instrument(level = "trace", ret)]
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_shared_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound shared tcp socket on {:?}", &local_address);
Ok(Some(socket))
}
// Non-blocking connect is tricky when you want to start with a prepared socket
// Errors should not be logged as they are valid conditions for this function
#[instrument(level = "trace", ret)]
pub async fn nonblocking_connect(
socket: Socket,
addr: SocketAddr,
timeout_ms: u32,
) -> io::Result<TimeoutOr<TcpStream>> {
// Set for non blocking connect
socket.set_nonblocking(true)?;
// Make socket2 SockAddr
let socket2_addr = socket2::SockAddr::from(addr);
// Connect to the remote address
match socket.connect(&socket2_addr) {
Ok(()) => Ok(()),
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
Err(e) => Err(e),
}?;
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
// The stream becomes writable when connected
timeout_or_try!(
timeout(timeout_ms, async_stream.writable().in_current_span())
.await
.into_timeout_or()
.into_result()?
);
// Check low level error
let async_stream = match async_stream.get_ref().take_error()? {
None => Ok(async_stream),
Some(err) => Err(err),
}?;
// Convert back to inner and then return async version
cfg_if! {
if #[cfg(feature="rt-async-std")] {
Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
} else if #[cfg(feature="rt-tokio")] {
Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
} else {
compile_error!("needs executor implementation");
}
}
}

View File

@ -1,6 +1,5 @@
use super::*;
use futures_util::{AsyncReadExt, AsyncWriteExt};
use sockets::*;
pub struct RawTcpNetworkConnection {
flow: Flow,
@ -157,32 +156,28 @@ impl RawTcpProtocolHandler {
#[instrument(level = "trace", target = "protocol", err)]
pub async fn connect(
local_address: Option<SocketAddr>,
socket_addr: SocketAddr,
remote_address: SocketAddr,
timeout_ms: u32,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
// Make a shared socket
let socket = match local_address {
Some(a) => {
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_tcp_socket(socket2::Domain::for_address(socket_addr))?,
};
// Non-blocking connect to remote address
let ts = network_result_try!(nonblocking_connect(socket, socket_addr, timeout_ms)
.await
.folded()?);
let tcp_stream = network_result_try!(connect_async_tcp_stream(
local_address,
remote_address,
timeout_ms
)
.await
.folded()?);
// See what local address we ended up with and turn this into a stream
let actual_local_address = ts.local_addr()?;
let actual_local_address = tcp_stream.local_addr()?;
#[cfg(feature = "rt-tokio")]
let ts = ts.compat();
let ps = AsyncPeekStream::new(ts);
let tcp_stream = tcp_stream.compat();
let ps = AsyncPeekStream::new(tcp_stream);
// Wrap the stream in a network connection and return it
let flow = Flow::new(
PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
SocketAddress::from_socket_addr(remote_address),
ProtocolType::TCP,
),
SocketAddress::from_socket_addr(actual_local_address),

View File

@ -1,19 +1,20 @@
use super::*;
use sockets::*;
#[derive(Clone)]
pub struct RawUdpProtocolHandler {
registry: VeilidComponentRegistry,
socket: Arc<UdpSocket>,
assembly_buffer: AssemblyBuffer,
address_filter: Option<AddressFilter>,
}
impl_veilid_component_registry_accessor!(RawUdpProtocolHandler);
impl RawUdpProtocolHandler {
pub fn new(socket: Arc<UdpSocket>, address_filter: Option<AddressFilter>) -> Self {
pub fn new(registry: VeilidComponentRegistry, socket: Arc<UdpSocket>) -> Self {
Self {
registry,
socket,
assembly_buffer: AssemblyBuffer::new(),
address_filter,
}
}
@ -24,10 +25,12 @@ impl RawUdpProtocolHandler {
let (size, remote_addr) = network_result_value_or_log!(self.socket.recv_from(data).await.into_network_result()? => continue);
// Check to see if it is punished
if let Some(af) = self.address_filter.as_ref() {
if af.is_ip_addr_punished(remote_addr.ip()) {
continue;
}
if self
.network_manager()
.address_filter()
.is_ip_addr_punished(remote_addr.ip())
{
continue;
}
// Insert into assembly buffer
@ -91,10 +94,12 @@ impl RawUdpProtocolHandler {
}
// Check to see if it is punished
if let Some(af) = self.address_filter.as_ref() {
if af.is_ip_addr_punished(remote_addr.ip()) {
return Ok(NetworkResult::no_connection_other("punished"));
}
if self
.network_manager()
.address_filter()
.is_ip_addr_punished(remote_addr.ip())
{
return Ok(NetworkResult::no_connection_other("punished"));
}
// Fragment and send
@ -137,11 +142,13 @@ impl RawUdpProtocolHandler {
#[instrument(level = "trace", target = "protocol", err)]
pub async fn new_unspecified_bound_handler(
registry: VeilidComponentRegistry,
socket_addr: &SocketAddr,
) -> io::Result<RawUdpProtocolHandler> {
// get local wildcard address for bind
let local_socket_addr = compatible_unspecified_socket_addr(socket_addr);
let socket = UdpSocket::bind(local_socket_addr).await?;
Ok(RawUdpProtocolHandler::new(Arc::new(socket), None))
let socket = bind_async_udp_socket(local_socket_addr)?
.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?;
Ok(RawUdpProtocolHandler::new(registry, Arc::new(socket)))
}
}
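
The handler above now consults the network manager's address filter before accepting a datagram or sending to a punished peer. A stand-alone sketch of that gate, with a plain HashSet standing in for Veilid's AddressFilter (an assumption made only for illustration):

use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};

// Hypothetical stand-in for the real address filter.
struct SimpleAddressFilter {
    punished: HashSet<IpAddr>,
}

impl SimpleAddressFilter {
    fn is_ip_addr_punished(&self, ip: IpAddr) -> bool {
        self.punished.contains(&ip)
    }
}

// Drop datagrams from punished peers before doing any further work;
// this mirrors the `continue` in the receive loop above.
fn accept_datagram(filter: &SimpleAddressFilter, remote_addr: SocketAddr, _data: &[u8]) -> bool {
    !filter.is_ip_addr_punished(remote_addr.ip())
}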

View File

@ -10,7 +10,6 @@ use async_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFr
use async_tungstenite::tungstenite::Error;
use async_tungstenite::{accept_hdr_async, client_async, WebSocketStream};
use futures_util::{AsyncRead, AsyncWrite, SinkExt};
use sockets::*;
// Maximum number of websocket request headers to permit
const MAX_WS_HEADERS: usize = 24;
@ -316,21 +315,16 @@ impl WebsocketProtocolHandler {
let domain = split_url.host.clone();
// Resolve remote address
let remote_socket_addr = dial_info.to_socket_addr();
// Make a shared socket
let socket = match local_address {
Some(a) => {
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?,
};
let remote_address = dial_info.to_socket_addr();
// Non-blocking connect to remote address
let tcp_stream =
network_result_try!(nonblocking_connect(socket, remote_socket_addr, timeout_ms)
.await
.folded()?);
let tcp_stream = network_result_try!(connect_async_tcp_stream(
local_address,
remote_address,
timeout_ms
)
.await
.folded()?);
// See what local address we ended up with
let actual_local_addr = tcp_stream.local_addr()?;

View File

@ -140,14 +140,13 @@ impl Network {
#[instrument(level = "trace", skip_all)]
pub(super) async fn bind_udp_protocol_handlers(&self) -> EyreResult<StartupDisposition> {
log_net!("UDP: binding protocol handlers");
let (listen_address, public_address, detect_address_changes) = {
let c = self.config.get();
let (listen_address, public_address, detect_address_changes) = self.config().with(|c| {
(
c.network.protocol.udp.listen_address.clone(),
c.network.protocol.udp.public_address.clone(),
c.network.detect_address_changes,
)
};
});
// Get the binding parameters from the user-specified listen address
let bind_set = self
@ -187,18 +186,17 @@ impl Network {
#[instrument(level = "trace", skip_all)]
pub(super) async fn register_udp_dial_info(
&self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet,
editor_local_network: &mut RoutingDomainEditorLocalNetwork,
editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> {
log_net!("UDP: registering dial info");
let (public_address, detect_address_changes) = {
let c = self.config.get();
let (public_address, detect_address_changes) = self.config().with(|c| {
(
c.network.protocol.udp.public_address.clone(),
c.network.detect_address_changes,
)
};
});
let local_dial_info_list = {
let mut out = vec![];
@ -263,14 +261,13 @@ impl Network {
#[instrument(level = "trace", skip_all)]
pub(super) async fn start_ws_listeners(&self) -> EyreResult<StartupDisposition> {
log_net!("WS: binding protocol handlers");
let (listen_address, url, detect_address_changes) = {
let c = self.config.get();
let (listen_address, url, detect_address_changes) = self.config().with(|c| {
(
c.network.protocol.ws.listen_address.clone(),
c.network.protocol.ws.url.clone(),
c.network.detect_address_changes,
)
};
});
// Get the binding parameters from the user-specified listen address
let bind_set = self
@ -313,18 +310,17 @@ impl Network {
#[instrument(level = "trace", skip_all)]
pub(super) async fn register_ws_dial_info(
&self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet,
editor_local_network: &mut RoutingDomainEditorLocalNetwork,
editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> {
log_net!("WS: registering dial info");
let (url, path, detect_address_changes) = {
let c = self.config.get();
let (url, path, detect_address_changes) = self.config().with(|c| {
(
c.network.protocol.ws.url.clone(),
c.network.protocol.ws.path.clone(),
c.network.detect_address_changes,
)
};
});
let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
@ -409,14 +405,13 @@ impl Network {
pub(super) async fn start_wss_listeners(&self) -> EyreResult<StartupDisposition> {
log_net!("WSS: binding protocol handlers");
let (listen_address, url, detect_address_changes) = {
let c = self.config.get();
let (listen_address, url, detect_address_changes) = self.config().with(|c| {
(
c.network.protocol.wss.listen_address.clone(),
c.network.protocol.wss.url.clone(),
c.network.detect_address_changes,
)
};
});
// Get the binding parameters from the user-specified listen address
let bind_set = self
@ -460,18 +455,17 @@ impl Network {
#[instrument(level = "trace", skip_all)]
pub(super) async fn register_wss_dial_info(
&self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet,
editor_local_network: &mut RoutingDomainEditorLocalNetwork,
editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> {
log_net!("WSS: registering dialinfo");
let (url, _detect_address_changes) = {
let c = self.config.get();
let (url, _detect_address_changes) = self.config().with(|c| {
(
c.network.protocol.wss.url.clone(),
c.network.detect_address_changes,
)
};
});
// NOTE: No interface dial info for WSS, as there is no way to connect to a local dialinfo via TLS
// If the hostname is specified, it is the public dialinfo via the URL. If no hostname
@ -520,14 +514,13 @@ impl Network {
pub(super) async fn start_tcp_listeners(&self) -> EyreResult<StartupDisposition> {
log_net!("TCP: binding protocol handlers");
let (listen_address, public_address, detect_address_changes) = {
let c = self.config.get();
let (listen_address, public_address, detect_address_changes) = self.config().with(|c| {
(
c.network.protocol.tcp.listen_address.clone(),
c.network.protocol.tcp.public_address.clone(),
c.network.detect_address_changes,
)
};
});
// Get the binding parameters from the user-specified listen address
let bind_set = self
@ -570,18 +563,17 @@ impl Network {
#[instrument(level = "trace", skip_all)]
pub(super) async fn register_tcp_dial_info(
&self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet,
editor_local_network: &mut RoutingDomainEditorLocalNetwork,
editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> {
log_net!("TCP: registering dialinfo");
let (public_address, detect_address_changes) = {
let c = self.config.get();
let (public_address, detect_address_changes) = self.config().with(|c| {
(
c.network.protocol.tcp.public_address.clone(),
c.network.detect_address_changes,
)
};
});
let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
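
These hunks replace the `let c = self.config.get();` blocks with a closure-based `self.config().with(|c| ...)` accessor, so the config is only borrowed for the duration of the closure. A minimal sketch of such an accessor over an RwLock-guarded config struct (field names and types here are placeholders, not Veilid's actual VeilidConfig API):

use std::sync::{Arc, RwLock};

#[derive(Default)]
struct ConfigInner {
    listen_address: String,
    detect_address_changes: bool,
}

#[derive(Clone, Default)]
struct Config {
    inner: Arc<RwLock<ConfigInner>>,
}

impl Config {
    // Run a closure against a read-locked view of the config and return its result.
    fn with<R>(&self, f: impl FnOnce(&ConfigInner) -> R) -> R {
        let guard = self.inner.read().unwrap();
        f(&guard)
    }
}

fn read_udp_settings(config: &Config) -> (String, bool) {
    // The read lock is held only while the closure runs.
    config.with(|c| (c.listen_address.clone(), c.detect_address_changes))
}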

View File

@ -7,46 +7,41 @@ use super::*;
impl Network {
pub fn setup_tasks(&self) {
// Set update network class tick task
{
let this = self.clone();
self.unlocked_inner
.update_network_class_task
.set_routine(move |s, l, t| {
Box::pin(this.clone().update_network_class_task_routine(
s,
Timestamp::new(l),
Timestamp::new(t),
))
});
}
let this = self.clone();
self.update_network_class_task.set_routine(move |s, l, t| {
let this = this.clone();
Box::pin(async move {
this.update_network_class_task_routine(s, Timestamp::new(l), Timestamp::new(t))
.await
})
});
// Set network interfaces tick task
{
let this = self.clone();
self.unlocked_inner
.network_interfaces_task
.set_routine(move |s, l, t| {
Box::pin(this.clone().network_interfaces_task_routine(
s,
Timestamp::new(l),
Timestamp::new(t),
))
});
}
let this = self.clone();
self.network_interfaces_task.set_routine(move |s, l, t| {
let this = this.clone();
Box::pin(async move {
this.network_interfaces_task_routine(s, Timestamp::new(l), Timestamp::new(t))
.await
})
});
// Set upnp tick task
{
let this = self.clone();
self.unlocked_inner.upnp_task.set_routine(move |s, l, t| {
Box::pin(
this.clone()
.upnp_task_routine(s, Timestamp::new(l), Timestamp::new(t)),
)
self.upnp_task.set_routine(move |s, l, t| {
let this = this.clone();
Box::pin(async move {
this.upnp_task_routine(s, Timestamp::new(l), Timestamp::new(t))
.await
})
});
}
}
#[instrument(level = "trace", target = "net", name = "Network::tick", skip_all, err)]
pub async fn tick(&self) -> EyreResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up");
return Ok(());
};
@ -65,7 +60,7 @@ impl Network {
// If we need to figure out our network class, tick the task for it
if detect_address_changes {
// Check our network interfaces to see if they have changed
self.unlocked_inner.network_interfaces_task.tick().await?;
self.network_interfaces_task.tick().await?;
// Check our public dial info to see if it has changed
let public_internet_network_class = self
@ -95,16 +90,31 @@ impl Network {
}
if has_at_least_two {
self.unlocked_inner.update_network_class_task.tick().await?;
self.update_network_class_task.tick().await?;
}
}
}
// If we need to tick upnp, do it
if upnp {
self.unlocked_inner.upnp_task.tick().await?;
self.upnp_task.tick().await?;
}
Ok(())
}
pub async fn cancel_tasks(&self) {
log_net!(debug "stopping upnp task");
if let Err(e) = self.upnp_task.stop().await {
warn!("upnp_task not stopped: {}", e);
}
log_net!(debug "stopping network interfaces task");
if let Err(e) = self.network_interfaces_task.stop().await {
warn!("network_interfaces_task not stopped: {}", e);
}
log_net!(debug "stopping update network class task");
if let Err(e) = self.update_network_class_task.stop().await {
warn!("update_network_class_task not stopped: {}", e);
}
}
}
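
The rewritten setup_tasks clones `self` once outside each closure and again inside it, so every tick's `async move` block owns its own handle. A minimal sketch of that ownership pattern against a hypothetical routine setter (TickTask's real signature, with stop token and timestamps, is not reproduced here):

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

type PinBoxFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
type Routine = Box<dyn Fn() -> PinBoxFuture + Send + Sync + 'static>;

#[derive(Default)]
struct Ticker {
    routine: Mutex<Option<Routine>>,
}

impl Ticker {
    fn set_routine(&self, r: Routine) {
        *self.routine.lock().unwrap() = Some(r);
    }
}

#[derive(Clone)]
struct Network {
    ticker: Arc<Ticker>,
}

impl Network {
    async fn update_network_class_task_routine(&self) {
        // ... the real routine body would go here ...
    }

    fn setup_tasks(&self) {
        // Clone once to move into the closure, then clone per invocation so the
        // async block owns its own handle for the lifetime of that tick.
        let this = self.clone();
        self.ticker.set_routine(Box::new(move || {
            let this = this.clone();
            Box::pin(async move {
                this.update_network_class_task_routine().await;
            })
        }));
    }
}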

View File

@ -3,20 +3,28 @@ use super::*;
impl Network {
#[instrument(level = "trace", target = "net", skip_all, err)]
pub(super) async fn network_interfaces_task_routine(
self,
_stop_token: StopToken,
&self,
stop_token: StopToken,
_l: Timestamp,
_t: Timestamp,
) -> EyreResult<()> {
let _guard = self.unlocked_inner.network_task_lock.lock().await;
// The network task lock ensures that only one task can operate on the
// low-level network state at a time.
let _guard = match self.network_task_lock.try_lock() {
Ok(v) => v,
Err(_) => {
// If we can't get the lock right now, another task is already operating on the network state, so skip this tick.
return Ok(());
}
};
self.update_network_state().await?;
self.update_network_state(stop_token).await?;
Ok(())
}
// See if our interface addresses have changed; if so, redo public dial info if necessary
async fn update_network_state(&self) -> EyreResult<bool> {
async fn update_network_state(&self, _stop_token: StopToken) -> EyreResult<bool> {
let mut local_network_changed = false;
let mut public_internet_changed = false;
@ -29,7 +37,7 @@ impl Network {
}
};
if new_network_state != last_network_state {
if last_network_state.is_none() || new_network_state != last_network_state.unwrap() {
// Save new network state
{
let mut inner = self.inner.lock();
@ -37,17 +45,13 @@ impl Network {
}
// network state has changed
let mut editor_local_network = self
.unlocked_inner
.routing_table
.edit_local_network_routing_domain();
let routing_table = self.routing_table();
let mut editor_local_network = routing_table.edit_local_network_routing_domain();
editor_local_network.set_local_networks(new_network_state.local_networks);
editor_local_network.clear_dial_info_details(None, None);
let mut editor_public_internet = self
.unlocked_inner
.routing_table
.edit_public_internet_routing_domain();
let mut editor_public_internet = routing_table.edit_public_internet_routing_domain();
// Update protocols
self.register_all_dial_info(&mut editor_public_internet, &mut editor_local_network)
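
Both long-running network tasks now take the shared task lock with try_lock, so overlapping ticks are skipped rather than queued up behind one another. A minimal sketch of that coalescing idea using a std Mutex (the real code uses an async lock; only the shape is shown):

use std::sync::Mutex;

struct Tasks {
    network_task_lock: Mutex<()>,
}

impl Tasks {
    fn tick(&self) {
        // If another tick already holds the lock, skip this one entirely
        // instead of blocking until it finishes.
        let _guard = match self.network_task_lock.try_lock() {
            Ok(g) => g,
            Err(_) => return,
        };
        // ... perform the actual low-level network state update here ...
    }
}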

View File

@ -8,12 +8,20 @@ type InboundProtocolMap = HashMap<(AddressType, LowLevelProtocolType, u16), Vec<
impl Network {
#[instrument(parent = None, level = "trace", skip(self), err)]
pub async fn update_network_class_task_routine(
self,
&self,
stop_token: StopToken,
l: Timestamp,
t: Timestamp,
) -> EyreResult<()> {
let _guard = self.unlocked_inner.network_task_lock.lock().await;
// The network task lock ensures that only one task can operate on the
// low-level network state at a time.
let _guard = match self.network_task_lock.try_lock() {
Ok(v) => v,
Err(_) => {
// If we can't get the lock right now, another task is already operating on the network state, so skip this tick.
return Ok(());
}
};
// Do the public dial info check
let finished = self.do_public_dial_info_check(stop_token, l, t).await?;
@ -125,8 +133,9 @@ impl Network {
};
// Save off existing public dial info for change detection later
let existing_public_dial_info: HashSet<DialInfoDetail> = self
.routing_table()
let routing_table = self.routing_table();
let existing_public_dial_info: HashSet<DialInfoDetail> = routing_table
.all_filtered_dial_info_details(
RoutingDomain::PublicInternet.into(),
&DialInfoFilter::all(),
@ -135,7 +144,7 @@ impl Network {
.collect();
// Set most permissive network config and start from scratch
let mut editor = self.routing_table().edit_public_internet_routing_domain();
let mut editor = routing_table.edit_public_internet_routing_domain();
editor.setup_network(
protocol_config.outbound,
protocol_config.inbound,
@ -156,7 +165,7 @@ impl Network {
port,
};
context_configs.insert(dcc);
let discovery_context = DiscoveryContext::new(self.routing_table(), self.clone(), dcc);
let discovery_context = DiscoveryContext::new(self.registry(), dcc);
discovery_context.discover(&mut unord).await;
}
@ -247,22 +256,18 @@ impl Network {
match protocol_type {
ProtocolType::UDP => DialInfo::udp(addr),
ProtocolType::TCP => DialInfo::tcp(addr),
ProtocolType::WS => {
let c = self.config.get();
DialInfo::try_ws(
addr,
format!("ws://{}/{}", addr, c.network.protocol.ws.path),
)
.unwrap()
}
ProtocolType::WSS => {
let c = self.config.get();
DialInfo::try_wss(
addr,
format!("wss://{}/{}", addr, c.network.protocol.wss.path),
)
.unwrap()
}
ProtocolType::WS => DialInfo::try_ws(
addr,
self.config()
.with(|c| format!("ws://{}/{}", addr, c.network.protocol.ws.path)),
)
.unwrap(),
ProtocolType::WSS => DialInfo::try_wss(
addr,
self.config()
.with(|c| format!("wss://{}/{}", addr, c.network.protocol.wss.path)),
)
.unwrap(),
}
}
}

View File

@ -3,12 +3,12 @@ use super::*;
impl Network {
#[instrument(parent = None, level = "trace", target = "net", skip_all, err)]
pub(super) async fn upnp_task_routine(
self,
&self,
_stop_token: StopToken,
_l: Timestamp,
_t: Timestamp,
) -> EyreResult<()> {
if !self.unlocked_inner.igd_manager.tick().await? {
if !self.igd_manager.tick().await? {
info!("upnp failed, restarting local network");
let mut inner = self.inner.lock();
inner.network_needs_restart = true;

View File

@ -4,7 +4,7 @@ use std::{io, sync::Arc};
use stop_token::prelude::*;
cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] {
if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
// No accept support for WASM
} else {
@ -307,8 +307,7 @@ impl NetworkConnection {
flow
);
let network_manager = connection_manager.network_manager();
let address_filter = network_manager.address_filter();
let registry = connection_manager.registry();
let mut unord = FuturesUnordered::new();
let mut need_receiver = true;
let mut need_sender = true;
@ -364,14 +363,17 @@ impl NetworkConnection {
// Add another message receiver future if necessary
if need_receiver {
need_receiver = false;
let registry = registry.clone();
let receiver_fut = Self::recv_internal(&protocol_connection, stats.clone())
.then(|res| async {
let registry = registry;
let network_manager = registry.network_manager();
match res {
Ok(v) => {
let peer_address = protocol_connection.flow().remote();
// Check to see if it is punished
if address_filter.is_ip_addr_punished(peer_address.socket_addr().ip()) {
if network_manager.address_filter().is_ip_addr_punished(peer_address.socket_addr().ip()) {
return RecvLoopAction::Finish;
}
@ -383,7 +385,7 @@ impl NetworkConnection {
// Punish invalid framing (tcp framing or websocket framing)
if v.is_invalid_message() {
address_filter.punish_ip_addr(peer_address.socket_addr().ip(), PunishmentReason::InvalidFraming);
network_manager.address_filter().punish_ip_addr(peer_address.socket_addr().ip(), PunishmentReason::InvalidFraming);
return RecvLoopAction::Finish;
}

View File

@ -309,13 +309,7 @@ impl ReceiptManager {
Ok(())
}
pub async fn shutdown(&self) {
log_net!(debug "starting receipt manager shutdown");
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else {
log_net!(debug "receipt manager is already shut down");
return;
};
pub async fn cancel_tasks(&self) {
// Stop all tasks
let timeout_task = {
let mut inner = self.inner.lock();
@ -329,6 +323,14 @@ impl ReceiptManager {
if timeout_task.join().await.is_err() {
panic!("joining timeout task failed");
}
}
pub async fn shutdown(&self) {
log_net!(debug "starting receipt manager shutdown");
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else {
log_net!(debug "receipt manager is already shut down");
return;
};
*self.inner.lock() = Self::new_inner();
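
Stopping the timeout task has moved out of shutdown into a dedicated cancel_tasks step, so background work is stopped before state is torn down. A rough sketch of that two-phase teardown (the handle type and runtime are assumptions; Tokio is used here purely for illustration, and the real code joins cooperatively rather than aborting):

use std::sync::Mutex;
use tokio::task::JoinHandle;

struct Manager {
    timeout_task: Mutex<Option<JoinHandle<()>>>,
}

impl Manager {
    // Phase 1: stop and join background tasks.
    async fn cancel_tasks(&self) {
        let handle = self.timeout_task.lock().unwrap().take();
        if let Some(handle) = handle {
            handle.abort();
            // An aborted task resolves with a JoinError; either outcome is fine here.
            let _ = handle.await;
        }
    }

    // Phase 2: reset state once nothing else is running.
    async fn shutdown(&self) {
        self.cancel_tasks().await;
        // ... clear inner state here ...
    }
}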

View File

@ -40,9 +40,10 @@ impl NetworkManager {
destination_node_ref: FilteredNodeRef,
data: Vec<u8>,
) -> SendPinBoxFuture<EyreResult<NetworkResult<SendDataMethod>>> {
let this = self.clone();
let registry = self.registry();
Box::pin(
async move {
let this = registry.network_manager();
// If we need to relay, do it
let (contact_method, target_node_ref, opt_relayed_contact_method) = match possibly_relayed_contact_method.clone() {
@ -652,17 +653,14 @@ impl NetworkManager {
data: Vec<u8>,
) -> EyreResult<NetworkResult<UniqueFlow>> {
// Detect if network is stopping so we can break out of this
let Some(stop_token) = self.unlocked_inner.startup_lock.stop_token() else {
let Some(stop_token) = self.startup_context.startup_lock.stop_token() else {
return Ok(NetworkResult::service_unavailable("network is stopping"));
};
// Build a return receipt for the signal
let receipt_timeout = TimestampDuration::new_ms(
self.unlocked_inner
.config
.get()
.network
.reverse_connection_receipt_time_ms as u64,
self.config()
.with(|c| c.network.reverse_connection_receipt_time_ms as u64),
);
let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?;
@ -763,7 +761,7 @@ impl NetworkManager {
data: Vec<u8>,
) -> EyreResult<NetworkResult<UniqueFlow>> {
// Detect if network is stopping so we can break out of this
let Some(stop_token) = self.unlocked_inner.startup_lock.stop_token() else {
let Some(stop_token) = self.startup_context.startup_lock.stop_token() else {
return Ok(NetworkResult::service_unavailable("network is stopping"));
};
@ -776,11 +774,8 @@ impl NetworkManager {
// Build a return receipt for the signal
let receipt_timeout = TimestampDuration::new_ms(
self.unlocked_inner
.config
.get()
.network
.hole_punch_receipt_time_ms as u64,
self.config()
.with(|c| c.network.hole_punch_receipt_time_ms as u64),
);
let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?;

View File

@ -1,7 +1,7 @@
use super::*;
// Statistics per address
#[derive(Clone, Default)]
#[derive(Clone, Debug, Default)]
pub struct PerAddressStats {
pub last_seen_ts: Timestamp,
pub transfer_stats_accounting: TransferStatsAccounting,
@ -18,7 +18,7 @@ impl Default for PerAddressStatsKey {
}
// Statistics about the low-level network
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct NetworkManagerStats {
pub self_stats: PerAddressStats,
pub per_address_stats: LruCache<PerAddressStatsKey, PerAddressStats>,
@ -116,12 +116,10 @@ impl NetworkManager {
})
}
pub(super) fn send_network_update(&self) {
let update_cb = self.unlocked_inner.update_callback.read().clone();
if update_cb.is_none() {
return;
}
pub fn send_network_update(&self) {
let update_cb = self.update_callback();
let state = self.get_veilid_state();
(update_cb.unwrap())(VeilidUpdate::Network(state));
update_cb(VeilidUpdate::Network(state));
}
}
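
PerAddressStats are held in an LruCache so statistics for the least recently seen addresses age out on their own. A small sketch of that bookkeeping using the lru crate (an assumption; the real cache type, key, and capacity are not shown in this hunk):

use lru::LruCache;
use std::net::IpAddr;
use std::num::NonZeroUsize;

#[derive(Clone, Debug, Default)]
struct PerAddressStats {
    messages_seen: u64,
}

struct Stats {
    per_address: LruCache<IpAddr, PerAddressStats>,
}

impl Stats {
    fn new(capacity: usize) -> Self {
        Self {
            per_address: LruCache::new(NonZeroUsize::new(capacity).expect("nonzero capacity")),
        }
    }

    // Bump the counter for an address, inserting a fresh entry if needed; the
    // least recently used address is evicted once capacity is reached.
    fn record(&mut self, addr: IpAddr) {
        if let Some(stats) = self.per_address.get_mut(&addr) {
            stats.messages_seen += 1;
        } else {
            self.per_address.put(addr, PerAddressStats { messages_seen: 1 });
        }
    }
}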

View File

@ -5,48 +5,39 @@ use super::*;
impl NetworkManager {
pub fn setup_tasks(&self) {
// Set rolling transfers tick task
{
let this = self.clone();
self.unlocked_inner
.rolling_transfers_task
.set_routine(move |s, l, t| {
Box::pin(this.clone().rolling_transfers_task_routine(
s,
Timestamp::new(l),
Timestamp::new(t),
))
});
}
impl_setup_task!(
self,
Self,
rolling_transfers_task,
rolling_transfers_task_routine
);
// Set address filter task
{
let this = self.clone();
self.unlocked_inner
.address_filter_task
.set_routine(move |s, l, t| {
Box::pin(this.address_filter().address_filter_task_routine(
s,
Timestamp::new(l),
Timestamp::new(t),
))
});
let registry = self.registry();
self.address_filter_task.set_routine(move |s, l, t| {
let registry = registry.clone();
Box::pin(async move {
registry
.network_manager()
.address_filter()
.address_filter_task_routine(s, Timestamp::new(l), Timestamp::new(t))
.await
})
});
}
}
#[instrument(level = "trace", name = "NetworkManager::tick", skip_all, err)]
pub async fn tick(&self) -> EyreResult<()> {
let routing_table = self.routing_table();
let net = self.net();
let receipt_manager = self.receipt_manager();
// Run the rolling transfers task
self.unlocked_inner.rolling_transfers_task.tick().await?;
self.rolling_transfers_task.tick().await?;
// Run the address filter task
self.unlocked_inner.address_filter_task.tick().await?;
// Run the routing table tick
routing_table.tick().await?;
self.address_filter_task.tick().await?;
// Run the low level network tick
net.tick().await?;
@ -61,15 +52,21 @@ impl NetworkManager {
}
pub async fn cancel_tasks(&self) {
log_net!(debug "stopping receipt manager tasks");
let receipt_manager = self.receipt_manager();
receipt_manager.cancel_tasks().await;
let net = self.net();
net.cancel_tasks().await;
log_net!(debug "stopping rolling transfers task");
if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
if let Err(e) = self.rolling_transfers_task.stop().await {
warn!("rolling_transfers_task not stopped: {}", e);
}
log_net!(debug "stopping routing table tasks");
let routing_table = self.routing_table();
routing_table.cancel_tasks().await;
// other tasks will get cancelled via the 'shutdown' mechanism
log_net!(debug "stopping address filter task");
if let Err(e) = self.address_filter_task.stop().await {
warn!("address_filter_task not stopped: {}", e);
}
}
}

View File

@ -4,7 +4,7 @@ impl NetworkManager {
// Compute transfer statistics for the low level network
#[instrument(level = "trace", skip(self), err)]
pub async fn rolling_transfers_task_routine(
self,
&self,
_stop_token: StopToken,
last_ts: Timestamp,
cur_ts: Timestamp,

View File

@ -1,13 +1,12 @@
use super::*;
use super::connection_table::*;
use crate::tests::common::test_veilid_config::*;
use crate::tests::mock_routing_table;
use crate::tests::mock_registry;
pub async fn test_add_get_remove() {
let config = get_config();
let address_filter = AddressFilter::new(config.clone(), mock_routing_table());
let table = ConnectionTable::new(config, address_filter);
let registry = mock_registry::init("").await;
let table = ConnectionTable::new(registry.clone());
let a1 = Flow::new_no_local(PeerAddress::new(
SocketAddress::new(Address::IPV4(Ipv4Addr::new(192, 168, 0, 1)), 8080),
@ -122,6 +121,8 @@ pub async fn test_add_get_remove() {
a4
);
assert_eq!(table.connection_count(), 0);
mock_registry::terminate(registry).await;
}
pub async fn test_all() {

View File

@ -30,7 +30,7 @@ pub async fn test_signed_node_info() {
// Test correct validation
let keypair = vcrypto.generate_keypair();
let sni = SignedDirectNodeInfo::make_signatures(
crypto.clone(),
&crypto,
vec![TypedKeyPair::new(ck, keypair)],
node_info.clone(),
)
@ -42,7 +42,7 @@ pub async fn test_signed_node_info() {
sni.timestamp(),
sni.signatures().to_vec(),
);
let tks_validated = sdni.validate(&tks, crypto.clone()).unwrap();
let tks_validated = sdni.validate(&tks, &crypto).unwrap();
assert_eq!(tks_validated.len(), oldtkslen);
assert_eq!(tks_validated.len(), sni.signatures().len());
@ -54,7 +54,7 @@ pub async fn test_signed_node_info() {
sni.timestamp(),
sni.signatures().to_vec(),
);
sdni.validate(&tks1, crypto.clone()).unwrap_err();
sdni.validate(&tks1, &crypto).unwrap_err();
// Test unsupported cryptosystem validation
let fake_crypto_kind: CryptoKind = FourCC::from([0, 1, 2, 3]);
@ -65,7 +65,7 @@ pub async fn test_signed_node_info() {
tksfake.add(TypedKey::new(ck, keypair.key));
let sdnifake =
SignedDirectNodeInfo::new(node_info.clone(), sni.timestamp(), sigsfake.clone());
let tksfake_validated = sdnifake.validate(&tksfake, crypto.clone()).unwrap();
let tksfake_validated = sdnifake.validate(&tksfake, &crypto).unwrap();
assert_eq!(tksfake_validated.len(), 1);
assert_eq!(sdnifake.signatures().len(), sigsfake.len());
@ -89,7 +89,7 @@ pub async fn test_signed_node_info() {
let oldtks2len = tks2.len();
let sni2 = SignedRelayedNodeInfo::make_signatures(
crypto.clone(),
&crypto,
vec![TypedKeyPair::new(ck, keypair2)],
node_info2.clone(),
tks.clone(),
@ -103,7 +103,7 @@ pub async fn test_signed_node_info() {
sni2.timestamp(),
sni2.signatures().to_vec(),
);
let tks2_validated = srni.validate(&tks2, crypto.clone()).unwrap();
let tks2_validated = srni.validate(&tks2, &crypto).unwrap();
assert_eq!(tks2_validated.len(), oldtks2len);
assert_eq!(tks2_validated.len(), sni2.signatures().len());
@ -119,7 +119,7 @@ pub async fn test_signed_node_info() {
sni2.timestamp(),
sni2.signatures().to_vec(),
);
srni.validate(&tks3, crypto.clone()).unwrap_err();
srni.validate(&tks3, &crypto).unwrap_err();
// Test unsupported cryptosystem validation
let fake_crypto_kind: CryptoKind = FourCC::from([0, 1, 2, 3]);
@ -135,7 +135,7 @@ pub async fn test_signed_node_info() {
sni2.timestamp(),
sigsfake3.clone(),
);
let tksfake3_validated = srnifake.validate(&tksfake3, crypto.clone()).unwrap();
let tksfake3_validated = srnifake.validate(&tksfake3, &crypto).unwrap();
assert_eq!(tksfake3_validated.len(), 1);
assert_eq!(srnifake.signatures().len(), sigsfake3.len());
}

View File

@ -20,7 +20,7 @@ impl Address {
SocketAddr::V6(v6) => Address::IPV6(*v6.ip()),
}
}
#[cfg_attr(target_arch = "wasm32", expect(dead_code))]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub fn from_ip_addr(addr: IpAddr) -> Address {
match addr {
IpAddr::V4(v4) => Address::IPV4(v4),

View File

@ -268,7 +268,7 @@ impl DialInfo {
Self::WSS(di) => di.socket_address.port(),
}
}
#[cfg_attr(target_arch = "wasm32", expect(dead_code))]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub fn set_port(&mut self, port: u16) {
match self {
Self::UDP(di) => di.socket_address.set_port(port),
@ -366,7 +366,7 @@ impl DialInfo {
// This will not be used on signed dialinfo, only for bootstrapping, so we don't need to worry about
// the '0.0.0.0' address being propagated across the routing table
cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] {
if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0,0,0,0)), port)]
} else {
match split_url.host {

View File

@ -21,7 +21,7 @@ pub(crate) enum SignalInfo {
}
impl SignalInfo {
pub fn validate(&self, crypto: Crypto) -> Result<(), RPCError> {
pub fn validate(&self, crypto: &Crypto) -> Result<(), RPCError> {
match self {
SignalInfo::HolePunch { receipt, peer_info } => {
if receipt.len() < MIN_RECEIPT_SIZE {

View File

@ -1,2 +1,2 @@
[build]
target = "wasm32-unknown-unknown"
all(target_arch = "wasm32", target_os = "unknown")

View File

@ -3,8 +3,6 @@ mod protocol;
use super::*;
use crate::routing_table::*;
use connection_manager::*;
use protocol::ws::WebsocketProtocolHandler;
pub use protocol::*;
use std::io;
@ -64,23 +62,28 @@ struct NetworkInner {
protocol_config: ProtocolConfig,
}
struct NetworkUnlockedInner {
pub(super) struct NetworkUnlockedInner {
// Startup lock
startup_lock: StartupLock,
// Accessors
routing_table: RoutingTable,
network_manager: NetworkManager,
connection_manager: ConnectionManager,
}
#[derive(Clone)]
pub(super) struct Network {
config: VeilidConfig,
registry: VeilidComponentRegistry,
inner: Arc<Mutex<NetworkInner>>,
unlocked_inner: Arc<NetworkUnlockedInner>,
}
impl_veilid_component_registry_accessor!(Network);
impl core::ops::Deref for Network {
type Target = NetworkUnlockedInner;
fn deref(&self) -> &Self::Target {
&self.unlocked_inner
}
}
impl Network {
fn new_inner() -> NetworkInner {
NetworkInner {
@ -89,45 +92,20 @@ impl Network {
}
}
fn new_unlocked_inner(
network_manager: NetworkManager,
routing_table: RoutingTable,
connection_manager: ConnectionManager,
) -> NetworkUnlockedInner {
fn new_unlocked_inner() -> NetworkUnlockedInner {
NetworkUnlockedInner {
startup_lock: StartupLock::new(),
network_manager,
routing_table,
connection_manager,
}
}
pub fn new(
network_manager: NetworkManager,
routing_table: RoutingTable,
connection_manager: ConnectionManager,
) -> Self {
pub fn new(registry: VeilidComponentRegistry) -> Self {
Self {
config: network_manager.config(),
registry,
inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner(
network_manager,
routing_table,
connection_manager,
)),
unlocked_inner: Arc::new(Self::new_unlocked_inner()),
}
}
fn network_manager(&self) -> NetworkManager {
self.unlocked_inner.network_manager.clone()
}
fn routing_table(&self) -> RoutingTable {
self.unlocked_inner.routing_table.clone()
}
fn connection_manager(&self) -> ConnectionManager {
self.unlocked_inner.connection_manager.clone()
}
/////////////////////////////////////////////////////////////////
// Record DialInfo failures
@ -159,10 +137,9 @@ impl Network {
self.record_dial_info_failure(dial_info.clone(), async move {
let data_len = data.len();
let timeout_ms = {
let c = self.config.get();
c.network.connection_initial_timeout_ms
};
let timeout_ms = self
.config()
.with(|c| c.network.connection_initial_timeout_ms);
if self
.network_manager()
@ -180,7 +157,7 @@ impl Network {
bail!("no support for TCP protocol")
}
ProtocolType::WS | ProtocolType::WSS => {
let pnc = network_result_try!(WebsocketProtocolHandler::connect(
let pnc = network_result_try!(ws::WebsocketProtocolHandler::connect(
&dial_info, timeout_ms
)
.await
@ -210,14 +187,13 @@ impl Network {
data: Vec<u8>,
timeout_ms: u32,
) -> EyreResult<NetworkResult<Vec<u8>>> {
let _guard = self.unlocked_inner.startup_lock.enter()?;
let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure(dial_info.clone(), async move {
let data_len = data.len();
let connect_timeout_ms = {
let c = self.config.get();
c.network.connection_initial_timeout_ms
};
let connect_timeout_ms = self
.config()
.with(|c| c.network.connection_initial_timeout_ms);
if self
.network_manager()
@ -239,7 +215,7 @@ impl Network {
ProtocolType::UDP => unreachable!(),
ProtocolType::TCP => unreachable!(),
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::connect(&dial_info, connect_timeout_ms)
ws::WebsocketProtocolHandler::connect(&dial_info, connect_timeout_ms)
.await
.wrap_err("connect failure")?
}
@ -271,7 +247,7 @@ impl Network {
flow: Flow,
data: Vec<u8>,
) -> EyreResult<SendDataToExistingFlowResult> {
let _guard = self.unlocked_inner.startup_lock.enter()?;
let _guard = self.startup_lock.enter()?;
let data_len = data.len();
match flow.protocol_type() {
@ -287,7 +263,11 @@ impl Network {
// Handle connection-oriented protocols
// Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(flow) {
if let Some(conn) = self
.network_manager()
.connection_manager()
.get_connection(flow)
{
// connection exists, send over it
match conn.send_async(data).await {
ConnectionHandleSendResult::Sent => {
@ -320,7 +300,7 @@ impl Network {
dial_info: DialInfo,
data: Vec<u8>,
) -> EyreResult<NetworkResult<UniqueFlow>> {
let _guard = self.unlocked_inner.startup_lock.enter()?;
let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure(dial_info.clone(), async move {
let data_len = data.len();
@ -333,7 +313,8 @@ impl Network {
// Handle connection-oriented protocols
let conn = network_result_try!(
self.connection_manager()
self.network_manager()
.connection_manager()
.get_or_create_connection(dial_info.clone())
.await?
);
@ -361,7 +342,8 @@ impl Network {
log_net!(debug "starting network");
// get protocol config
let protocol_config = {
let c = self.config.get();
let config = self.config();
let c = config.get();
let inbound = ProtocolTypeSet::new();
let mut outbound = ProtocolTypeSet::new();
@ -398,10 +380,8 @@ impl Network {
self.inner.lock().protocol_config = protocol_config.clone();
// Start editing routing table
let mut editor_public_internet = self
.unlocked_inner
.routing_table
.edit_public_internet_routing_domain();
let routing_table = self.routing_table();
let mut editor_public_internet = routing_table.edit_public_internet_routing_domain();
// set up the routing table's network config
editor_public_internet.setup_network(
@ -421,7 +401,7 @@ impl Network {
#[instrument(level = "debug", err, skip_all)]
pub async fn startup(&self) -> EyreResult<StartupDisposition> {
let guard = self.unlocked_inner.startup_lock.startup()?;
let guard = self.startup_lock.startup()?;
match self.startup_internal().await {
Ok(StartupDisposition::Success) => {
@ -445,7 +425,7 @@ impl Network {
}
pub fn is_started(&self) -> bool {
self.unlocked_inner.startup_lock.is_started()
self.startup_lock.is_started()
}
#[instrument(level = "debug", skip_all)]
@ -456,7 +436,7 @@ impl Network {
#[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) {
log_net!(debug "starting low level network shutdown");
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else {
let Ok(guard) = self.startup_lock.shutdown().await else {
log_net!(debug "low level network is already shut down");
return;
};
@ -493,14 +473,14 @@ impl Network {
&self,
_punishment: Option<Box<dyn FnOnce() + Send + 'static>>,
) {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up");
return;
};
}
pub fn needs_public_dial_info_check(&self) -> bool {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up");
return false;
};
@ -511,11 +491,12 @@ impl Network {
//////////////////////////////////////////
#[instrument(level = "trace", target = "net", name = "Network::tick", skip_all, err)]
pub async fn tick(&self) -> EyreResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "ignoring due to not started up");
return Ok(());
};
Ok(())
}
pub async fn cancel_tasks(&self) {}
}
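
Network now derefs to its NetworkUnlockedInner, which is why fields such as startup_lock can be reached as self.startup_lock instead of self.unlocked_inner.startup_lock throughout this file. A minimal sketch of that locked/unlocked split with a Deref impl (field types are placeholders):

use std::ops::Deref;
use std::sync::{Arc, Mutex};

struct Inner {
    needs_restart: bool,
}

struct UnlockedInner {
    startup_lock: Mutex<bool>, // placeholder for the real StartupLock
}

#[derive(Clone)]
struct Network {
    inner: Arc<Mutex<Inner>>,
    unlocked_inner: Arc<UnlockedInner>,
}

impl Deref for Network {
    type Target = UnlockedInner;
    fn deref(&self) -> &Self::Target {
        &self.unlocked_inner
    }
}

impl Network {
    fn is_started(&self) -> bool {
        // Thanks to Deref, unlocked fields read as if they lived on Network itself.
        *self.startup_lock.lock().unwrap()
    }

    fn needs_restart(&self) -> bool {
        self.inner.lock().unwrap().needs_restart
    }
}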

View File

@ -16,7 +16,7 @@ impl ProtocolNetworkConnection {
_local_address: Option<SocketAddr>,
dial_info: &DialInfo,
timeout_ms: u32,
address_filter: AddressFilter,
address_filter: &AddressFilter,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
if address_filter.is_ip_addr_punished(dial_info.address().ip_addr()) {
return Ok(NetworkResult::no_connection_other("punished"));

View File

@ -9,29 +9,6 @@ struct WebsocketNetworkConnectionInner {
ws_stream: CloneStream<WsStream>,
}
fn to_io(err: WsErr) -> io::Error {
match err {
WsErr::InvalidWsState { supplied: _ } => {
io::Error::new(io::ErrorKind::InvalidInput, err.to_string())
}
WsErr::ConnectionNotOpen => io::Error::new(io::ErrorKind::NotConnected, err.to_string()),
WsErr::InvalidUrl { supplied: _ } => {
io::Error::new(io::ErrorKind::InvalidInput, err.to_string())
}
WsErr::InvalidCloseCode { supplied: _ } => {
io::Error::new(io::ErrorKind::InvalidInput, err.to_string())
}
WsErr::ReasonStringToLong => io::Error::new(io::ErrorKind::InvalidInput, err.to_string()),
WsErr::ConnectionFailed { event: _ } => {
io::Error::new(io::ErrorKind::ConnectionRefused, err.to_string())
}
WsErr::InvalidEncoding => io::Error::new(io::ErrorKind::InvalidInput, err.to_string()),
WsErr::CantDecodeBlob => io::Error::new(io::ErrorKind::InvalidInput, err.to_string()),
WsErr::UnknownDataType => io::Error::new(io::ErrorKind::InvalidInput, err.to_string()),
_ => io::Error::new(io::ErrorKind::Other, err.to_string()),
}
}
#[derive(Clone)]
pub struct WebsocketNetworkConnection {
flow: Flow,
@ -65,7 +42,7 @@ impl WebsocketNetworkConnection {
)]
pub async fn close(&self) -> io::Result<NetworkResult<()>> {
#[allow(unused_variables)]
let x = self.inner.ws_meta.close().await.map_err(to_io);
let x = self.inner.ws_meta.close().await.map_err(ws_err_to_io_error);
#[cfg(feature = "verbose-tracing")]
log_net!(debug "close result: {:?}", x);
Ok(NetworkResult::value(()))
@ -83,7 +60,7 @@ impl WebsocketNetworkConnection {
.send(WsMessage::Binary(message)),
)
.await
.map_err(to_io)
.map_err(ws_err_to_io_error)
.into_network_result()?;
#[cfg(feature = "verbose-tracing")]
@ -140,7 +117,9 @@ impl WebsocketProtocolHandler {
}
let fut = SendWrapper::new(timeout(timeout_ms, async move {
WsMeta::connect(request, None).await.map_err(to_io)
WsMeta::connect(request, None)
.await
.map_err(ws_err_to_io_error)
}));
let (wsmeta, wsio) = network_result_try!(network_result_try!(fut

View File

@ -640,7 +640,7 @@ impl BucketEntryInner {
only_live: bool,
filter: NodeRefFilter,
) -> Vec<(Flow, Timestamp)> {
let opt_connection_manager = rti.unlocked_inner.network_manager.opt_connection_manager();
let opt_connection_manager = rti.network_manager().opt_connection_manager();
let mut out: Vec<(Flow, Timestamp)> = self
.last_flows

View File

@ -35,7 +35,6 @@ impl RoutingTable {
let valid_envelope_versions = VALID_ENVELOPE_VERSIONS.map(|x| x.to_string()).join(",");
let node_ids = self
.unlocked_inner
.node_ids()
.iter()
.map(|x| x.to_string())
@ -57,7 +56,7 @@ impl RoutingTable {
pub fn debug_info_nodeid(&self) -> String {
let mut out = String::new();
for nid in self.unlocked_inner.node_ids().iter() {
for nid in self.node_ids().iter() {
out += &format!("{}\n", nid);
}
out
@ -66,7 +65,7 @@ impl RoutingTable {
pub fn debug_info_nodeinfo(&self) -> String {
let mut out = String::new();
let inner = self.inner.read();
out += &format!("Node Ids: {}\n", self.unlocked_inner.node_ids());
out += &format!("Node Ids: {}\n", self.node_ids());
out += &format!(
"Self Transfer Stats:\n{}",
indent_all_string(&inner.self_transfer_stats)
@ -250,7 +249,7 @@ impl RoutingTable {
out += &format!("{:?}: {}: {}\n", routing_domain, crypto_kind, count);
}
for ck in &VALID_CRYPTO_KINDS {
let our_node_id = self.unlocked_inner.node_id(*ck);
let our_node_id = self.node_id(*ck);
let mut filtered_total = 0;
let mut b = 0;
@ -319,7 +318,7 @@ impl RoutingTable {
) -> String {
let cur_ts = Timestamp::now();
let relay_node_filter = self.make_public_internet_relay_node_filter();
let our_node_ids = self.unlocked_inner.node_ids();
let our_node_ids = self.node_ids();
let mut relay_count = 0usize;
let mut relaying_count = 0usize;
@ -340,7 +339,7 @@ impl RoutingTable {
node_count,
filters,
|_rti, entry: Option<Arc<BucketEntry>>| {
NodeRef::new(self.clone(), entry.unwrap().clone())
NodeRef::new(self.registry(), entry.unwrap().clone())
},
);
let mut out = String::new();
@ -376,9 +375,10 @@ impl RoutingTable {
relaying_count += 1;
}
let best_node_id = node.best_node_id();
out += " ";
out += &node
.operate(|_rti, e| Self::format_entry(cur_ts, node.best_node_id(), e, &relay_tag));
out += &node.operate(|_rti, e| Self::format_entry(cur_ts, best_node_id, e, &relay_tag));
out += "\n";
}

View File

@ -42,10 +42,9 @@ impl RoutingTable {
) as RoutingTableEntryFilter;
let filters = VecDeque::from([filter]);
let node_count = {
let c = self.config.get();
c.network.dht.max_find_node_count as usize
};
let node_count = self
.config()
.with(|c| c.network.dht.max_find_node_count as usize);
let closest_nodes = match self.find_preferred_closest_nodes(
node_count,
@ -82,11 +81,13 @@ impl RoutingTable {
// find N nodes closest to the target node in our routing table
// ensure the nodes returned are only the ones closer to the target node than ourself
let Some(vcrypto) = self.crypto().get(crypto_kind) else {
let crypto = self.crypto();
let Some(vcrypto) = crypto.get(crypto_kind) else {
return NetworkResult::invalid_message("unsupported cryptosystem");
};
let vcrypto = &vcrypto;
let own_distance = vcrypto.distance(&own_node_id.value, &key.value);
let vcrypto2 = vcrypto.clone();
let filter = Box::new(
move |rti: &RoutingTableInner, opt_entry: Option<Arc<BucketEntry>>| {
@ -121,10 +122,9 @@ impl RoutingTable {
) as RoutingTableEntryFilter;
let filters = VecDeque::from([filter]);
let node_count = {
let c = self.config.get();
c.network.dht.max_find_node_count as usize
};
let node_count = self
.config()
.with(|c| c.network.dht.max_find_node_count as usize);
//
let closest_nodes = match self.find_preferred_closest_nodes(
@ -147,7 +147,7 @@ impl RoutingTable {
// Validate peers returned are, in fact, closer to the key than the node we sent this to
// This same test is used on the other side so we vet things here
let valid = match Self::verify_peers_closer(vcrypto2, own_node_id, key, &closest_nodes) {
let valid = match Self::verify_peers_closer(vcrypto, own_node_id, key, &closest_nodes) {
Ok(v) => v,
Err(e) => {
panic!("missing cryptosystem in peers node ids: {}", e);
@ -166,7 +166,7 @@ impl RoutingTable {
/// Determine if set of peers is closer to key_near than key_far is to key_near
#[instrument(level = "trace", target = "rtab", skip_all, err)]
pub fn verify_peers_closer(
vcrypto: CryptoSystemVersion,
vcrypto: &crypto::CryptoSystemGuard<'_>,
key_far: TypedKey,
key_near: TypedKey,
peers: &[Arc<PeerInfo>],
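
verify_peers_closer vets that each returned peer is closer to the target key than the far key, using the cryptosystem's distance metric. For a Kademlia-style XOR metric over raw 32-byte keys the check reduces to the sketch below (the key length and the plain XOR metric are assumptions; Veilid's CryptoSystemGuard is not modeled):

// XOR distance between two 32-byte keys, compared as big-endian integers.
fn xor_distance(a: &[u8; 32], b: &[u8; 32]) -> [u8; 32] {
    let mut out = [0u8; 32];
    for i in 0..32 {
        out[i] = a[i] ^ b[i];
    }
    out
}

// True if every peer key is strictly closer to `key_near` than `key_far` is.
fn peers_closer(key_far: &[u8; 32], key_near: &[u8; 32], peers: &[[u8; 32]]) -> bool {
    let far_distance = xor_distance(key_far, key_near);
    peers
        .iter()
        .all(|peer| xor_distance(peer, key_near) < far_distance)
}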

View File

@ -91,16 +91,12 @@ pub struct RecentPeersEntry {
pub last_connection: Flow,
}
pub(crate) struct RoutingTableUnlockedInner {
// Accessors
event_bus: EventBus,
config: VeilidConfig,
network_manager: NetworkManager,
pub(crate) struct RoutingTable {
registry: VeilidComponentRegistry,
inner: RwLock<RoutingTableInner>,
/// The current node's public DHT keys
node_id: TypedKeyGroup,
/// The current node's public DHT secrets
node_id_secret: TypedSecretGroup,
/// Route spec store
route_spec_store: RouteSpecStore,
/// Buckets to kick on our next kick task
kick_queue: Mutex<BTreeSet<BucketIndex>>,
/// Background process for computing statistics
@ -131,103 +127,27 @@ pub(crate) struct RoutingTableUnlockedInner {
private_route_management_task: TickTask<EyreReport>,
}
impl RoutingTableUnlockedInner {
pub fn network_manager(&self) -> NetworkManager {
self.network_manager.clone()
}
pub fn crypto(&self) -> Crypto {
self.network_manager().crypto()
}
pub fn rpc_processor(&self) -> RPCProcessor {
self.network_manager().rpc_processor()
}
pub fn update_callback(&self) -> UpdateCallback {
self.network_manager().update_callback()
}
pub fn with_config<F, R>(&self, f: F) -> R
where
F: FnOnce(&VeilidConfigInner) -> R,
{
f(&self.config.get())
}
pub fn node_id(&self, kind: CryptoKind) -> TypedKey {
self.node_id.get(kind).unwrap()
}
pub fn node_id_secret_key(&self, kind: CryptoKind) -> SecretKey {
self.node_id_secret.get(kind).unwrap().value
}
pub fn node_ids(&self) -> TypedKeyGroup {
self.node_id.clone()
}
pub fn node_id_typed_key_pairs(&self) -> Vec<TypedKeyPair> {
let mut tkps = Vec::new();
for ck in VALID_CRYPTO_KINDS {
tkps.push(TypedKeyPair::new(
ck,
KeyPair::new(self.node_id(ck).value, self.node_id_secret_key(ck)),
));
}
tkps
}
pub fn matches_own_node_id(&self, node_ids: &[TypedKey]) -> bool {
for ni in node_ids {
if let Some(v) = self.node_id.get(ni.kind) {
if v.value == ni.value {
return true;
}
}
}
false
}
pub fn matches_own_node_id_key(&self, node_id_key: &PublicKey) -> bool {
for tk in self.node_id.iter() {
if tk.value == *node_id_key {
return true;
}
}
false
}
pub fn calculate_bucket_index(&self, node_id: &TypedKey) -> BucketIndex {
let crypto = self.crypto();
let self_node_id_key = self.node_id(node_id.kind).value;
let vcrypto = crypto.get(node_id.kind).unwrap();
(
node_id.kind,
vcrypto
.distance(&node_id.value, &self_node_id_key)
.first_nonzero_bit()
.unwrap(),
)
impl fmt::Debug for RoutingTable {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RoutingTable")
// .field("inner", &self.inner)
// .field("unlocked_inner", &self.unlocked_inner)
.finish()
}
}
#[derive(Clone)]
pub(crate) struct RoutingTable {
inner: Arc<RwLock<RoutingTableInner>>,
unlocked_inner: Arc<RoutingTableUnlockedInner>,
}
impl_veilid_component!(RoutingTable);
impl RoutingTable {
fn new_unlocked_inner(
event_bus: EventBus,
config: VeilidConfig,
network_manager: NetworkManager,
) -> RoutingTableUnlockedInner {
pub fn new(registry: VeilidComponentRegistry) -> Self {
let config = registry.config();
let c = config.get();
RoutingTableUnlockedInner {
event_bus,
config: config.clone(),
network_manager,
node_id: c.network.routing_table.node_id.clone(),
node_id_secret: c.network.routing_table.node_id_secret.clone(),
let inner = RwLock::new(RoutingTableInner::new(registry.clone()));
let route_spec_store = RouteSpecStore::new(registry.clone());
let this = Self {
registry,
inner,
route_spec_store,
kick_queue: Mutex::new(BTreeSet::default()),
rolling_transfers_task: TickTask::new(
"rolling_transfers_task",
@ -269,16 +189,6 @@ impl RoutingTable {
"private_route_management_task",
PRIVATE_ROUTE_MANAGEMENT_INTERVAL_SECS,
),
}
}
pub fn new(network_manager: NetworkManager) -> Self {
let event_bus = network_manager.event_bus();
let config = network_manager.config();
let unlocked_inner = Arc::new(Self::new_unlocked_inner(event_bus, config, network_manager));
let inner = Arc::new(RwLock::new(RoutingTableInner::new(unlocked_inner.clone())));
let this = Self {
inner,
unlocked_inner,
};
this.setup_tasks();
@ -290,7 +200,7 @@ impl RoutingTable {
/// Initialization
/// Called to initialize the routing table after it is created
pub async fn init(&self) -> EyreResult<()> {
async fn init_async(&self) -> EyreResult<()> {
log_rtab!(debug "starting routing table init");
// Set up routing buckets
@ -309,42 +219,35 @@ impl RoutingTable {
// Set up routespecstore
log_rtab!(debug "starting route spec store init");
let route_spec_store = match RouteSpecStore::load(self.clone()).await {
Ok(v) => v,
Err(e) => {
log_rtab!(debug "Error loading route spec store: {:#?}. Resetting.", e);
RouteSpecStore::new(self.clone())
}
if let Err(e) = self.route_spec_store().load().await {
log_rtab!(debug "Error loading route spec store: {:#?}. Resetting.", e);
self.route_spec_store().reset();
};
log_rtab!(debug "finished route spec store init");
{
let mut inner = self.inner.write();
inner.route_spec_store = Some(route_spec_store);
}
// Inform storage manager we are up
self.network_manager
.storage_manager()
.set_routing_table(Some(self.clone()))
.await;
log_rtab!(debug "finished routing table init");
Ok(())
}
/// Called to shut down the routing table
pub async fn terminate(&self) {
log_rtab!(debug "starting routing table terminate");
async fn post_init_async(&self) -> EyreResult<()> {
Ok(())
}
// Stop storage manager from using us
self.network_manager
.storage_manager()
.set_routing_table(None)
.await;
pub(crate) async fn startup(&self) -> EyreResult<()> {
Ok(())
}
pub(crate) async fn shutdown(&self) {
// Stop tasks
log_net!(debug "stopping routing table tasks");
self.cancel_tasks().await;
}
async fn pre_terminate_async(&self) {}
/// Called to shut down the routing table
async fn terminate_async(&self) {
log_rtab!(debug "starting routing table terminate");
// Save bucket entries to the table db if possible
log_rtab!(debug "saving routing table entries");
@ -365,11 +268,73 @@ impl RoutingTable {
log_rtab!(debug "shutting down routing table");
let mut inner = self.inner.write();
*inner = RoutingTableInner::new(self.unlocked_inner.clone());
*inner = RoutingTableInner::new(self.registry());
log_rtab!(debug "finished routing table terminate");
}
///////////////////////////////////////////////////////////////////
pub fn node_id(&self, kind: CryptoKind) -> TypedKey {
self.config()
.with(|c| c.network.routing_table.node_id.get(kind).unwrap())
}
pub fn node_id_secret_key(&self, kind: CryptoKind) -> SecretKey {
self.config()
.with(|c| c.network.routing_table.node_id_secret.get(kind).unwrap())
.value
}
pub fn node_ids(&self) -> TypedKeyGroup {
self.config()
.with(|c| c.network.routing_table.node_id.clone())
}
pub fn node_id_typed_key_pairs(&self) -> Vec<TypedKeyPair> {
let mut tkps = Vec::new();
for ck in VALID_CRYPTO_KINDS {
tkps.push(TypedKeyPair::new(
ck,
KeyPair::new(self.node_id(ck).value, self.node_id_secret_key(ck)),
));
}
tkps
}
pub fn matches_own_node_id(&self, node_ids: &[TypedKey]) -> bool {
for ni in node_ids {
if let Some(v) = self.node_ids().get(ni.kind) {
if v.value == ni.value {
return true;
}
}
}
false
}
pub fn matches_own_node_id_key(&self, node_id_key: &PublicKey) -> bool {
for tk in self.node_ids().iter() {
if tk.value == *node_id_key {
return true;
}
}
false
}
pub fn calculate_bucket_index(&self, node_id: &TypedKey) -> BucketIndex {
let crypto = self.crypto();
let self_node_id_key = self.node_id(node_id.kind).value;
let vcrypto = crypto.get(node_id.kind).unwrap();
(
node_id.kind,
vcrypto
.distance(&node_id.value, &self_node_id_key)
.first_nonzero_bit()
.unwrap(),
)
}
/// Serialize the routing table.
fn serialized_buckets(&self) -> (SerializedBucketMap, SerializedBuckets) {
// Since entries are shared by multiple buckets per cryptokind
@ -406,7 +371,7 @@ impl RoutingTable {
async fn save_buckets(&self) -> EyreResult<()> {
let (serialized_bucket_map, all_entry_bytes) = self.serialized_buckets();
let table_store = self.unlocked_inner.network_manager().table_store();
let table_store = self.table_store();
let tdb = table_store.open(ROUTING_TABLE, 1).await?;
let dbx = tdb.transact();
if let Err(e) = dbx.store_json(0, SERIALIZED_BUCKET_MAP, &serialized_bucket_map) {
@ -420,12 +385,14 @@ impl RoutingTable {
dbx.commit().await?;
Ok(())
}
/// Deserialize routing table from table store
async fn load_buckets(&self) -> EyreResult<()> {
// Make a cache validity key of all our node ids and our bootstrap choice
let mut cache_validity_key: Vec<u8> = Vec::new();
{
let c = self.unlocked_inner.config.get();
let config = self.config();
let c = config.get();
for ck in VALID_CRYPTO_KINDS {
if let Some(nid) = c.network.routing_table.node_id.get(ck) {
cache_validity_key.append(&mut nid.value.bytes.to_vec());
@ -446,7 +413,7 @@ impl RoutingTable {
};
// Deserialize bucket map and all entries from the table store
let table_store = self.unlocked_inner.network_manager().table_store();
let table_store = self.table_store();
let db = table_store.open(ROUTING_TABLE, 1).await?;
let caches_valid = match db.load(0, CACHE_VALIDITY_KEY).await? {
@ -479,14 +446,13 @@ impl RoutingTable {
// Reconstruct all entries
let inner = &mut *self.inner.write();
self.populate_routing_table(inner, serialized_bucket_map, all_entry_bytes)?;
Self::populate_routing_table_inner(inner, serialized_bucket_map, all_entry_bytes)?;
Ok(())
}
/// Write the deserialized table store data to the routing table.
pub fn populate_routing_table(
&self,
pub fn populate_routing_table_inner(
inner: &mut RoutingTableInner,
serialized_bucket_map: SerializedBucketMap,
all_entry_bytes: SerializedBuckets,
@ -542,8 +508,8 @@ impl RoutingTable {
self.inner.read().routing_domain_for_address(address)
}
pub fn route_spec_store(&self) -> RouteSpecStore {
self.inner.read().route_spec_store.as_ref().unwrap().clone()
pub fn route_spec_store(&self) -> &RouteSpecStore {
&self.route_spec_store
}
pub fn relay_node(&self, domain: RoutingDomain) -> Option<FilteredNodeRef> {
@ -600,12 +566,12 @@ impl RoutingTable {
/// Edit the PublicInternet RoutingDomain
pub fn edit_public_internet_routing_domain(&self) -> RoutingDomainEditorPublicInternet {
RoutingDomainEditorPublicInternet::new(self.clone())
RoutingDomainEditorPublicInternet::new(self)
}
/// Edit the LocalNetwork RoutingDomain
pub fn edit_local_network_routing_domain(&self) -> RoutingDomainEditorLocalNetwork {
RoutingDomainEditorLocalNetwork::new(self.clone())
RoutingDomainEditorLocalNetwork::new(self)
}
/// Return a copy of our node's peerinfo (may not yet be published)
@ -619,7 +585,7 @@ impl RoutingTable {
}
/// Return the domain's currently registered network class
#[cfg_attr(target_arch = "wasm32", expect(dead_code))]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub fn get_network_class(&self, routing_domain: RoutingDomain) -> Option<NetworkClass> {
self.inner.read().get_network_class(routing_domain)
}
@ -656,7 +622,7 @@ impl RoutingTable {
) -> Vec<FilteredNodeRef> {
self.inner
.read()
.get_nodes_needing_ping(self.clone(), routing_domain, cur_ts)
.get_nodes_needing_ping(routing_domain, cur_ts)
}
fn queue_bucket_kicks(&self, node_ids: TypedKeyGroup) {
@ -667,21 +633,19 @@ impl RoutingTable {
}
// Put it in the kick queue
let x = self.unlocked_inner.calculate_bucket_index(node_id);
self.unlocked_inner.kick_queue.lock().insert(x);
let x = self.calculate_bucket_index(node_id);
self.kick_queue.lock().insert(x);
}
}
/// Resolve an existing routing table entry using any crypto kind and return a reference to it
pub fn lookup_any_node_ref(&self, node_id_key: PublicKey) -> EyreResult<Option<NodeRef>> {
self.inner
.read()
.lookup_any_node_ref(self.clone(), node_id_key)
self.inner.read().lookup_any_node_ref(node_id_key)
}
/// Resolve an existing routing table entry and return a reference to it
pub fn lookup_node_ref(&self, node_id: TypedKey) -> EyreResult<Option<NodeRef>> {
self.inner.read().lookup_node_ref(self.clone(), node_id)
self.inner.read().lookup_node_ref(node_id)
}
/// Resolve an existing routing table entry and return a filtered reference to it
@ -692,12 +656,9 @@ impl RoutingTable {
routing_domain_set: RoutingDomainSet,
dial_info_filter: DialInfoFilter,
) -> EyreResult<Option<FilteredNodeRef>> {
self.inner.read().lookup_and_filter_noderef(
self.clone(),
node_id,
routing_domain_set,
dial_info_filter,
)
self.inner
.read()
.lookup_and_filter_noderef(node_id, routing_domain_set, dial_info_filter)
}
/// Shortcut function to add a node to our routing table if it doesn't exist
@ -711,7 +672,7 @@ impl RoutingTable {
) -> EyreResult<FilteredNodeRef> {
self.inner
.write()
.register_node_with_peer_info(self.clone(), peer_info, allow_invalid)
.register_node_with_peer_info(peer_info, allow_invalid)
}
/// Shortcut function to add a node to our routing table if it doesn't exist
@ -726,7 +687,7 @@ impl RoutingTable {
) -> EyreResult<FilteredNodeRef> {
self.inner
.write()
.register_node_with_id(self.clone(), routing_domain, node_id, timestamp)
.register_node_with_id(routing_domain, node_id, timestamp)
}
//////////////////////////////////////////////////////////////////////
@ -824,7 +785,7 @@ impl RoutingTable {
}
/// Makes a filter that finds nodes with a matching inbound dialinfo
#[cfg_attr(target_arch = "wasm32", expect(dead_code))]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
pub fn make_inbound_dial_info_entry_filter<'a>(
routing_domain: RoutingDomain,
dial_info_filter: DialInfoFilter,
@ -885,7 +846,7 @@ impl RoutingTable {
filters: VecDeque<RoutingTableEntryFilter>,
) -> Vec<NodeRef> {
self.inner.read().find_fast_non_local_nodes_filtered(
self.clone(),
self.registry(),
routing_domain,
node_count,
filters,
@ -971,7 +932,7 @@ impl RoutingTable {
protocol_types_len * 2 * max_per_type,
filters,
|_rti, entry: Option<Arc<BucketEntry>>| {
NodeRef::new(self.clone(), entry.unwrap().clone())
NodeRef::new(self.registry(), entry.unwrap().clone())
},
)
}
@ -1073,7 +1034,6 @@ impl RoutingTable {
let res = network_result_try!(
rpc_processor
.clone()
.rpc_call_find_node(
Destination::direct(node_ref.default_filtered()),
node_id,
@ -1162,11 +1122,3 @@ impl RoutingTable {
}
}
}
impl core::ops::Deref for RoutingTable {
type Target = RoutingTableUnlockedInner;
fn deref(&self) -> &Self::Target {
&self.unlocked_inner
}
}
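
calculate_bucket_index above derives a bucket from the first nonzero bit of the distance between our node id and the other node id. A small sketch of that bit scan over a 32-byte distance (again assuming a plain XOR metric purely for illustration):

// Index (0..=255) of the most significant set bit of the distance, or None if
// the distance is zero (i.e. the two keys are identical).
fn first_nonzero_bit(distance: &[u8; 32]) -> Option<usize> {
    for (byte_index, byte) in distance.iter().enumerate() {
        if *byte != 0 {
            return Some(byte_index * 8 + byte.leading_zeros() as usize);
        }
    }
    None
}

fn bucket_index(own_key: &[u8; 32], other_key: &[u8; 32]) -> Option<usize> {
    let mut distance = [0u8; 32];
    for i in 0..32 {
        distance[i] = own_key[i] ^ other_key[i];
    }
    first_nonzero_bit(&distance)
}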

View File

@ -1,7 +1,7 @@
use super::*;
pub(crate) struct FilteredNodeRef {
routing_table: RoutingTable,
registry: VeilidComponentRegistry,
entry: Arc<BucketEntry>,
filter: NodeRefFilter,
sequencing: Sequencing,
@ -9,9 +9,11 @@ pub(crate) struct FilteredNodeRef {
track_id: usize,
}
impl_veilid_component_registry_accessor!(FilteredNodeRef);
impl FilteredNodeRef {
pub fn new(
routing_table: RoutingTable,
registry: VeilidComponentRegistry,
entry: Arc<BucketEntry>,
filter: NodeRefFilter,
sequencing: Sequencing,
@ -19,7 +21,7 @@ impl FilteredNodeRef {
entry.ref_count.fetch_add(1u32, Ordering::AcqRel);
Self {
routing_table,
registry,
entry,
filter,
sequencing,
@ -29,7 +31,7 @@ impl FilteredNodeRef {
}
pub fn unfiltered(&self) -> NodeRef {
NodeRef::new(self.routing_table.clone(), self.entry.clone())
NodeRef::new(self.registry(), self.entry.clone())
}
pub fn filtered_clone(&self, filter: NodeRefFilter) -> FilteredNodeRef {
@ -40,7 +42,7 @@ impl FilteredNodeRef {
pub fn sequencing_clone(&self, sequencing: Sequencing) -> FilteredNodeRef {
FilteredNodeRef::new(
self.routing_table.clone(),
self.registry.clone(),
self.entry.clone(),
self.filter(),
sequencing,
@ -70,9 +72,6 @@ impl FilteredNodeRef {
}
impl NodeRefAccessorsTrait for FilteredNodeRef {
fn routing_table(&self) -> RoutingTable {
self.routing_table.clone()
}
fn entry(&self) -> Arc<BucketEntry> {
self.entry.clone()
}
@ -105,7 +104,8 @@ impl NodeRefOperateTrait for FilteredNodeRef {
where
F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T,
{
let inner = &*self.routing_table.inner.read();
let routing_table = self.registry.routing_table();
let inner = &*routing_table.inner.read();
self.entry.with(inner, f)
}
@ -113,7 +113,8 @@ impl NodeRefOperateTrait for FilteredNodeRef {
where
F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T,
{
let inner = &mut *self.routing_table.inner.write();
let routing_table = self.registry.routing_table();
let inner = &mut *routing_table.inner.write();
self.entry.with_mut(inner, f)
}
}
@ -125,7 +126,7 @@ impl Clone for FilteredNodeRef {
self.entry.ref_count.fetch_add(1u32, Ordering::AcqRel);
Self {
routing_table: self.routing_table.clone(),
registry: self.registry.clone(),
entry: self.entry.clone(),
filter: self.filter,
sequencing: self.sequencing,
@ -162,7 +163,7 @@ impl Drop for FilteredNodeRef {
// get node ids with inner unlocked because nothing could be referencing this entry now
// and we don't know when it will get dropped, possibly inside a lock
let node_ids = self.entry.with_inner(|e| e.node_ids());
self.routing_table.queue_bucket_kicks(node_ids);
self.routing_table().queue_bucket_kicks(node_ids);
}
}
}

View File

@ -16,18 +16,20 @@ pub(crate) use traits::*;
// Default NodeRef
pub(crate) struct NodeRef {
routing_table: RoutingTable,
registry: VeilidComponentRegistry,
entry: Arc<BucketEntry>,
#[cfg(feature = "tracking")]
track_id: usize,
}
impl_veilid_component_registry_accessor!(NodeRef);
impl NodeRef {
pub fn new(routing_table: RoutingTable, entry: Arc<BucketEntry>) -> Self {
pub fn new(registry: VeilidComponentRegistry, entry: Arc<BucketEntry>) -> Self {
entry.ref_count.fetch_add(1u32, Ordering::AcqRel);
Self {
routing_table,
registry,
entry,
#[cfg(feature = "tracking")]
track_id: entry.track(),
@ -36,7 +38,7 @@ impl NodeRef {
pub fn default_filtered(&self) -> FilteredNodeRef {
FilteredNodeRef::new(
self.routing_table.clone(),
self.registry.clone(),
self.entry.clone(),
NodeRefFilter::new(),
Sequencing::default(),
@ -45,7 +47,7 @@ impl NodeRef {
pub fn sequencing_filtered(&self, sequencing: Sequencing) -> FilteredNodeRef {
FilteredNodeRef::new(
self.routing_table.clone(),
self.registry.clone(),
self.entry.clone(),
NodeRefFilter::new(),
sequencing,
@ -57,7 +59,7 @@ impl NodeRef {
routing_domain_set: R,
) -> FilteredNodeRef {
FilteredNodeRef::new(
self.routing_table.clone(),
self.registry.clone(),
self.entry.clone(),
NodeRefFilter::new().with_routing_domain_set(routing_domain_set.into()),
Sequencing::default(),
@ -66,7 +68,7 @@ impl NodeRef {
pub fn custom_filtered(&self, filter: NodeRefFilter) -> FilteredNodeRef {
FilteredNodeRef::new(
self.routing_table.clone(),
self.registry.clone(),
self.entry.clone(),
filter,
Sequencing::default(),
@ -76,7 +78,7 @@ impl NodeRef {
#[expect(dead_code)]
pub fn dial_info_filtered(&self, filter: DialInfoFilter) -> FilteredNodeRef {
FilteredNodeRef::new(
self.routing_table.clone(),
self.registry.clone(),
self.entry.clone(),
NodeRefFilter::new().with_dial_info_filter(filter),
Sequencing::default(),
@ -92,9 +94,6 @@ impl NodeRef {
}
impl NodeRefAccessorsTrait for NodeRef {
fn routing_table(&self) -> RoutingTable {
self.routing_table.clone()
}
fn entry(&self) -> Arc<BucketEntry> {
self.entry.clone()
}
@ -125,7 +124,8 @@ impl NodeRefOperateTrait for NodeRef {
where
F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T,
{
let inner = &*self.routing_table.inner.read();
let routing_table = self.routing_table();
let inner = &*routing_table.inner.read();
self.entry.with(inner, f)
}
@ -133,7 +133,8 @@ impl NodeRefOperateTrait for NodeRef {
where
F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T,
{
let inner = &mut *self.routing_table.inner.write();
let routing_table = self.routing_table();
let inner = &mut *routing_table.inner.write();
self.entry.with_mut(inner, f)
}
}
@ -145,7 +146,7 @@ impl Clone for NodeRef {
self.entry.ref_count.fetch_add(1u32, Ordering::AcqRel);
Self {
routing_table: self.routing_table.clone(),
registry: self.registry.clone(),
entry: self.entry.clone(),
#[cfg(feature = "tracking")]
track_id: self.entry.write().track(),
@ -178,7 +179,7 @@ impl Drop for NodeRef {
// get node ids with inner unlocked because nothing could be referencing this entry now
// and we don't know when it will get dropped, possibly inside a lock
let node_ids = self.entry.with_inner(|e| e.node_ids());
self.routing_table.queue_bucket_kicks(node_ids);
self.routing_table().queue_bucket_kicks(node_ids);
}
}
}
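
For orientation, here is a minimal, self-contained sketch of the registry-accessor pattern these node-ref changes move to: the ref stores a `VeilidComponentRegistry` and resolves the routing table on demand. The stub types and the assumed expansion of `impl_veilid_component_registry_accessor!` are illustrative guesses, not the real Veilid API.

```rust
// Hypothetical sketch only: stand-in types plus a guessed macro expansion.
#[derive(Clone)]
struct RoutingTable;

#[derive(Clone)]
struct VeilidComponentRegistry {
    routing_table: RoutingTable,
}

trait VeilidComponentRegistryAccessor {
    fn registry(&self) -> VeilidComponentRegistry;

    // Convenience accessor; the real macro/trait presumably provides similar helpers.
    fn routing_table(&self) -> RoutingTable {
        self.registry().routing_table.clone()
    }
}

struct NodeRefExample {
    registry: VeilidComponentRegistry,
}

// What `impl_veilid_component_registry_accessor!(NodeRefExample)` is assumed to expand to.
impl VeilidComponentRegistryAccessor for NodeRefExample {
    fn registry(&self) -> VeilidComponentRegistry {
        self.registry.clone()
    }
}

fn main() {
    let registry = VeilidComponentRegistry { routing_table: RoutingTable };
    let nr = NodeRefExample { registry };
    let _rt = nr.routing_table(); // resolved on demand, not stored in the ref
}
```

The practical effect visible in the diff is that `NodeRef` and `FilteredNodeRef` no longer pin a `RoutingTable` clone for their whole lifetime.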

View File

@ -15,6 +15,21 @@ pub(crate) struct NodeRefLock<
nr: N,
}
impl<
'a,
N: NodeRefAccessorsTrait
+ NodeRefOperateTrait
+ VeilidComponentRegistryAccessor
+ fmt::Debug
+ fmt::Display
+ Clone,
> VeilidComponentRegistryAccessor for NodeRefLock<'a, N>
{
fn registry(&self) -> VeilidComponentRegistry {
self.nr.registry()
}
}
impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone>
NodeRefLock<'a, N>
{
@ -33,9 +48,6 @@ impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Disp
impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone>
NodeRefAccessorsTrait for NodeRefLock<'a, N>
{
fn routing_table(&self) -> RoutingTable {
self.nr.routing_table()
}
fn entry(&self) -> Arc<BucketEntry> {
self.nr.entry()
}

View File

@ -15,6 +15,21 @@ pub(crate) struct NodeRefLockMut<
nr: N,
}
impl<
'a,
N: NodeRefAccessorsTrait
+ NodeRefOperateTrait
+ VeilidComponentRegistryAccessor
+ fmt::Debug
+ fmt::Display
+ Clone,
> VeilidComponentRegistryAccessor for NodeRefLockMut<'a, N>
{
fn registry(&self) -> VeilidComponentRegistry {
self.nr.registry()
}
}
impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone>
NodeRefLockMut<'a, N>
{
@ -34,9 +49,6 @@ impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Disp
impl<'a, N: NodeRefAccessorsTrait + NodeRefOperateTrait + fmt::Debug + fmt::Display + Clone>
NodeRefAccessorsTrait for NodeRefLockMut<'a, N>
{
fn routing_table(&self) -> RoutingTable {
self.nr.routing_table()
}
fn entry(&self) -> Arc<BucketEntry> {
self.nr.entry()
}

View File

@ -2,7 +2,6 @@ use super::*;
// Field accessors
pub(crate) trait NodeRefAccessorsTrait {
fn routing_table(&self) -> RoutingTable;
fn entry(&self) -> Arc<BucketEntry>;
fn sequencing(&self) -> Sequencing;
fn routing_domain_set(&self) -> RoutingDomainSet;
@ -125,12 +124,12 @@ pub(crate) trait NodeRefCommonTrait: NodeRefAccessorsTrait + NodeRefOperateTrait
};
// If relay is ourselves, then return None, because we can't relay through ourselves
// and to contact this node we should have had an existing inbound connection
if rti.unlocked_inner.matches_own_node_id(rpi.node_ids()) {
if rti.routing_table().matches_own_node_id(rpi.node_ids()) {
bail!("Can't relay though ourselves");
}
// Register relay node and return noderef
let nr = rti.register_node_with_peer_info(self.routing_table(), rpi, false)?;
let nr = rti.register_node_with_peer_info(rpi, false)?;
Ok(Some(nr))
})
}
@ -253,7 +252,7 @@ pub(crate) trait NodeRefCommonTrait: NodeRefAccessorsTrait + NodeRefOperateTrait
else {
return false;
};
let our_node_ids = rti.unlocked_inner.node_ids();
let our_node_ids = rti.routing_table().node_ids();
our_node_ids.contains_any(&relay_ids)
})
}

View File

@ -31,7 +31,7 @@ pub(crate) enum RouteNode {
}
impl RouteNode {
pub fn validate(&self, crypto: Crypto) -> VeilidAPIResult<()> {
pub fn validate(&self, crypto: &Crypto) -> VeilidAPIResult<()> {
match self {
RouteNode::NodeId(_) => Ok(()),
RouteNode::PeerInfo(pi) => pi.validate(crypto),
@ -40,7 +40,7 @@ impl RouteNode {
pub fn node_ref(
&self,
routing_table: RoutingTable,
routing_table: &RoutingTable,
crypto_kind: CryptoKind,
) -> Option<NodeRef> {
match self {
@ -91,7 +91,7 @@ pub(crate) struct RouteHop {
pub next_hop: Option<RouteHopData>,
}
impl RouteHop {
pub fn validate(&self, crypto: Crypto) -> VeilidAPIResult<()> {
pub fn validate(&self, crypto: &Crypto) -> VeilidAPIResult<()> {
self.node.validate(crypto)
}
}
@ -108,7 +108,7 @@ pub(crate) enum PrivateRouteHops {
}
impl PrivateRouteHops {
pub fn validate(&self, crypto: Crypto) -> VeilidAPIResult<()> {
pub fn validate(&self, crypto: &Crypto) -> VeilidAPIResult<()> {
match self {
PrivateRouteHops::FirstHop(rh) => rh.validate(crypto),
PrivateRouteHops::Data(_) => Ok(()),
@ -138,7 +138,7 @@ impl PrivateRoute {
}
}
pub fn validate(&self, crypto: Crypto) -> VeilidAPIResult<()> {
pub fn validate(&self, crypto: &Crypto) -> VeilidAPIResult<()> {
self.hops.validate(crypto)
}

View File

@ -34,85 +34,71 @@ struct RouteSpecStoreInner {
cache: RouteSpecStoreCache,
}
struct RouteSpecStoreUnlockedInner {
/// Handle to routing table
routing_table: RoutingTable,
/// The routing table's storage for private/safety routes
#[derive(Debug)]
pub(crate) struct RouteSpecStore {
registry: VeilidComponentRegistry,
inner: Mutex<RouteSpecStoreInner>,
/// Maximum number of hops in a route
max_route_hop_count: usize,
/// Default number of hops in a route
default_route_hop_count: usize,
}
impl fmt::Debug for RouteSpecStoreUnlockedInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RouteSpecStoreUnlockedInner")
.field("max_route_hop_count", &self.max_route_hop_count)
.field("default_route_hop_count", &self.default_route_hop_count)
.finish()
}
}
/// The routing table's storage for private/safety routes
#[derive(Clone, Debug)]
pub(crate) struct RouteSpecStore {
inner: Arc<Mutex<RouteSpecStoreInner>>,
unlocked_inner: Arc<RouteSpecStoreUnlockedInner>,
}
impl_veilid_component_registry_accessor!(RouteSpecStore);
impl RouteSpecStore {
pub fn new(routing_table: RoutingTable) -> Self {
let config = routing_table.network_manager().config();
pub fn new(registry: VeilidComponentRegistry) -> Self {
let config = registry.config();
let c = config.get();
Self {
unlocked_inner: Arc::new(RouteSpecStoreUnlockedInner {
max_route_hop_count: c.network.rpc.max_route_hop_count.into(),
default_route_hop_count: c.network.rpc.default_route_hop_count.into(),
routing_table,
}),
inner: Arc::new(Mutex::new(RouteSpecStoreInner {
registry,
inner: Mutex::new(RouteSpecStoreInner {
content: RouteSpecStoreContent::new(),
cache: Default::default(),
})),
}),
max_route_hop_count: c.network.rpc.max_route_hop_count.into(),
default_route_hop_count: c.network.rpc.default_route_hop_count.into(),
}
}
#[instrument(level = "trace", target = "route", skip(routing_table), err)]
pub async fn load(routing_table: RoutingTable) -> EyreResult<RouteSpecStore> {
let (max_route_hop_count, default_route_hop_count) = {
let config = routing_table.network_manager().config();
let c = config.get();
(
c.network.rpc.max_route_hop_count as usize,
c.network.rpc.default_route_hop_count as usize,
)
};
// Get frozen blob from table store
let content = RouteSpecStoreContent::load(routing_table.clone()).await?;
let mut inner = RouteSpecStoreInner {
content,
#[instrument(level = "trace", target = "route", skip_all)]
pub fn reset(&self) {
*self.inner.lock() = RouteSpecStoreInner {
content: RouteSpecStoreContent::new(),
cache: Default::default(),
};
}
// Rebuild the routespecstore cache
let rti = &*routing_table.inner.read();
for (_, rssd) in inner.content.iter_details() {
inner.cache.add_to_cache(rti, rssd);
}
#[instrument(level = "trace", target = "route", skip_all, err)]
pub async fn load(&self) -> EyreResult<()> {
let inner = {
let table_store = self.table_store();
let routing_table = self.routing_table();
// Return the loaded RouteSpecStore
let rss = RouteSpecStore {
unlocked_inner: Arc::new(RouteSpecStoreUnlockedInner {
max_route_hop_count,
default_route_hop_count,
routing_table: routing_table.clone(),
}),
inner: Arc::new(Mutex::new(inner)),
// Get frozen blob from table store
let content = RouteSpecStoreContent::load(&table_store, &routing_table).await?;
let mut inner = RouteSpecStoreInner {
content,
cache: Default::default(),
};
// Rebuild the routespecstore cache
let rti = &*routing_table.inner.read();
for (_, rssd) in inner.content.iter_details() {
inner.cache.add_to_cache(rti, rssd);
}
inner
};
Ok(rss)
// Return the loaded RouteSpecStore
*self.inner.lock() = inner;
Ok(())
}
#[instrument(level = "trace", target = "route", skip(self), err)]
@ -123,9 +109,8 @@ impl RouteSpecStore {
};
// Save our content
content
.save(self.unlocked_inner.routing_table.clone())
.await?;
let table_store = self.table_store();
content.save(&table_store).await?;
Ok(())
}
@ -146,16 +131,17 @@ impl RouteSpecStore {
dead_remote_routes,
}));
let update_callback = self.unlocked_inner.routing_table.update_callback();
let update_callback = self.registry.update_callback();
update_callback(update);
}
/// Purge the route spec store
pub async fn purge(&self) -> VeilidAPIResult<()> {
// Briefly pause routing table ticker while changes are made
let _tick_guard = self.unlocked_inner.routing_table.pause_tasks().await;
self.unlocked_inner.routing_table.cancel_tasks().await;
let routing_table = self.routing_table();
let _tick_guard = routing_table.pause_tasks().await;
routing_table.cancel_tasks().await;
{
let inner = &mut *self.inner.lock();
inner.content = Default::default();
@ -181,7 +167,7 @@ impl RouteSpecStore {
automatic: bool,
) -> VeilidAPIResult<RouteId> {
let inner = &mut *self.inner.lock();
let routing_table = self.unlocked_inner.routing_table.clone();
let routing_table = self.routing_table();
let rti = &mut *routing_table.inner.write();
self.allocate_route_inner(
@ -213,12 +199,10 @@ impl RouteSpecStore {
apibail_generic!("safety_spec.preferred_route must be empty when allocating new route");
}
let ip6_prefix_size = rti
.unlocked_inner
.config
.get()
.network
.max_connections_per_ip6_prefix_size as usize;
let ip6_prefix_size = self
.registry()
.config()
.with(|c| c.network.max_connections_per_ip6_prefix_size as usize);
if safety_spec.hop_count < 1 {
apibail_invalid_argument!(
@ -228,7 +212,7 @@ impl RouteSpecStore {
);
}
if safety_spec.hop_count > self.unlocked_inner.max_route_hop_count {
if safety_spec.hop_count > self.max_route_hop_count {
apibail_invalid_argument!(
"Not allocating route longer than max route hop count",
"hop_count",
@ -492,9 +476,8 @@ impl RouteSpecStore {
})
};
let routing_table = self.unlocked_inner.routing_table.clone();
let transform = |_rti: &RoutingTableInner, entry: Option<Arc<BucketEntry>>| -> NodeRef {
NodeRef::new(routing_table.clone(), entry.unwrap())
NodeRef::new(self.registry(), entry.unwrap())
};
// Pull the whole routing table in sorted order
@ -667,13 +650,9 @@ impl RouteSpecStore {
// Got a unique route, let's build the details, register it, and return it
let hop_node_refs: Vec<NodeRef> = route_nodes.iter().map(|k| nodes[*k].clone()).collect();
let mut route_set = BTreeMap::<PublicKey, RouteSpecDetail>::new();
let crypto = self.crypto();
for crypto_kind in crypto_kinds.iter().copied() {
let vcrypto = self
.unlocked_inner
.routing_table
.crypto()
.get(crypto_kind)
.unwrap();
let vcrypto = crypto.get(crypto_kind).unwrap();
let keypair = vcrypto.generate_keypair();
let hops: Vec<PublicKey> = route_nodes
.iter()
@ -734,7 +713,7 @@ impl RouteSpecStore {
R: fmt::Debug,
{
let inner = &*self.inner.lock();
let crypto = self.unlocked_inner.routing_table.crypto();
let crypto = self.crypto();
let Some(vcrypto) = crypto.get(public_key.kind) else {
log_rpc!(debug "can't handle route with public key: {:?}", public_key);
return None;
@ -852,7 +831,7 @@ impl RouteSpecStore {
};
// Test with a double round-trip ping to self
let rpc_processor = self.unlocked_inner.routing_table.rpc_processor();
let rpc_processor = self.rpc_processor();
let _res = match rpc_processor.rpc_call_status(dest).await? {
NetworkResult::Value(v) => v,
_ => {
@ -886,7 +865,7 @@ impl RouteSpecStore {
// Get a safety route that is good enough
let safety_spec = SafetySpec {
preferred_route: None,
hop_count: self.unlocked_inner.default_route_hop_count,
hop_count: self.default_route_hop_count,
stability,
sequencing,
};
@ -900,8 +879,7 @@ impl RouteSpecStore {
};
// Test with a double round-trip ping to self
let rpc_processor = self.unlocked_inner.routing_table.rpc_processor();
let _res = match rpc_processor.rpc_call_status(dest).await? {
let _res = match self.rpc_processor().rpc_call_status(dest).await? {
NetworkResult::Value(v) => v,
_ => {
// Did not error, but did not come back, just return false
@ -921,7 +899,8 @@ impl RouteSpecStore {
};
// Remove from hop cache
let rti = &*self.unlocked_inner.routing_table.inner.read();
let routing_table = self.routing_table();
let rti = &*routing_table.inner.read();
if !inner.cache.remove_from_cache(rti, id, &rssd) {
panic!("hop cache should have contained cache key");
}
@ -1097,7 +1076,7 @@ impl RouteSpecStore {
) -> VeilidAPIResult<CompiledRoute> {
// let profile_start_ts = get_timestamp();
let inner = &mut *self.inner.lock();
let routing_table = self.unlocked_inner.routing_table.clone();
let routing_table = self.routing_table();
let rti = &mut *routing_table.inner.write();
// Get useful private route properties
@ -1108,7 +1087,7 @@ impl RouteSpecStore {
};
let pr_pubkey = private_route.public_key.value;
let pr_hopcount = private_route.hop_count as usize;
let max_route_hop_count = self.unlocked_inner.max_route_hop_count;
let max_route_hop_count = self.max_route_hop_count;
// Check private route hop count isn't larger than the max route hop count plus one for the 'first hop' header
if pr_hopcount > (max_route_hop_count + 1) {
@ -1130,10 +1109,10 @@ impl RouteSpecStore {
let opt_first_hop = match pr_first_hop_node {
RouteNode::NodeId(id) => rti
.lookup_node_ref(routing_table.clone(), TypedKey::new(crypto_kind, id))
.lookup_node_ref(TypedKey::new(crypto_kind, id))
.map_err(VeilidAPIError::internal)?,
RouteNode::PeerInfo(pi) => Some(
rti.register_node_with_peer_info(routing_table.clone(), pi, false)
rti.register_node_with_peer_info(pi, false)
.map_err(VeilidAPIError::internal)?
.unfiltered(),
),
@ -1362,7 +1341,7 @@ impl RouteSpecStore {
avoid_nodes: &[TypedKey],
) -> VeilidAPIResult<PublicKey> {
// Ensure the total hop count isn't too long for our config
let max_route_hop_count = self.unlocked_inner.max_route_hop_count;
let max_route_hop_count = self.max_route_hop_count;
if safety_spec.hop_count == 0 {
apibail_invalid_argument!(
"safety route hop count is zero",
@ -1438,7 +1417,7 @@ impl RouteSpecStore {
avoid_nodes: &[TypedKey],
) -> VeilidAPIResult<PublicKey> {
let inner = &mut *self.inner.lock();
let routing_table = self.unlocked_inner.routing_table.clone();
let routing_table = self.routing_table();
let rti = &mut *routing_table.inner.write();
self.get_route_for_safety_spec_inner(
@ -1457,7 +1436,7 @@ impl RouteSpecStore {
rsd: &RouteSpecDetail,
optimized: bool,
) -> VeilidAPIResult<PrivateRoute> {
let routing_table = self.unlocked_inner.routing_table.clone();
let routing_table = self.routing_table();
let rti = &*routing_table.inner.read();
// Ensure we get the crypto for it
@ -1732,8 +1711,7 @@ impl RouteSpecStore {
cur_ts: Timestamp,
) -> VeilidAPIResult<()> {
let Some(our_node_info_ts) = self
.unlocked_inner
.routing_table
.routing_table()
.get_published_peer_info(RoutingDomain::PublicInternet)
.map(|pi| pi.signed_node_info().timestamp())
else {
@ -1767,11 +1745,7 @@ impl RouteSpecStore {
let inner = &mut *self.inner.lock();
// Check for stub route
if self
.unlocked_inner
.routing_table
.matches_own_node_id_key(key)
{
if self.routing_table().matches_own_node_id_key(key) {
return None;
}
@ -1869,7 +1843,7 @@ impl RouteSpecStore {
/// Convert binary blob to private route vector
pub fn blob_to_private_routes(&self, blob: Vec<u8>) -> VeilidAPIResult<Vec<PrivateRoute>> {
// Get crypto
let crypto = self.unlocked_inner.routing_table.crypto();
let crypto = self.crypto();
// Deserialize count
if blob.is_empty() {
@ -1904,7 +1878,7 @@ impl RouteSpecStore {
let private_route = decode_private_route(&decode_context, &pr_reader).map_err(|e| {
VeilidAPIError::invalid_argument("failed to decode private route", "e", e)
})?;
private_route.validate(crypto.clone()).map_err(|e| {
private_route.validate(&crypto).map_err(|e| {
VeilidAPIError::invalid_argument("failed to validate private route", "e", e)
})?;
@ -1920,7 +1894,7 @@ impl RouteSpecStore {
/// Generate RouteId from typed key set of route public keys
fn generate_allocated_route_id(&self, rssd: &RouteSetSpecDetail) -> VeilidAPIResult<RouteId> {
let route_set_keys = rssd.get_route_set_keys();
let crypto = self.unlocked_inner.routing_table.crypto();
let crypto = self.crypto();
let mut idbytes = Vec::with_capacity(PUBLIC_KEY_LENGTH * route_set_keys.len());
let mut best_kind: Option<CryptoKind> = None;
@ -1945,7 +1919,7 @@ impl RouteSpecStore {
&self,
private_routes: &[PrivateRoute],
) -> VeilidAPIResult<RouteId> {
let crypto = self.unlocked_inner.routing_table.crypto();
let crypto = self.crypto();
let mut idbytes = Vec::with_capacity(PUBLIC_KEY_LENGTH * private_routes.len());
let mut best_kind: Option<CryptoKind> = None;
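
The construction flow changes from a single async `load()` constructor to `new(registry)` plus instance-level `load()`/`reset()` that swap `RouteSpecStoreInner` under the mutex. Below is a deliberately simplified, std-only sketch of that two-phase pattern; the types are stand-ins, and the real `load()` is async and reads a frozen blob from the table store.

```rust
// Simplified stand-ins, not the real RouteSpecStore API.
use std::sync::Mutex;

#[derive(Default)]
struct StoreInner {
    content: Vec<String>,
}

struct Store {
    inner: Mutex<StoreInner>,
    max_route_hop_count: usize,
}

impl Store {
    fn new(max_route_hop_count: usize) -> Self {
        Self {
            inner: Mutex::new(StoreInner::default()),
            max_route_hop_count,
        }
    }

    fn reset(&self) {
        // Return to the empty state without reconstructing the store itself
        *self.inner.lock().unwrap() = StoreInner::default();
    }

    fn load(&self) -> Result<(), String> {
        // The real load() is async and deserializes persisted content;
        // here we just swap in some fake loaded state under the lock.
        let loaded = StoreInner { content: vec!["route".to_string()] };
        *self.inner.lock().unwrap() = loaded;
        Ok(())
    }
}

fn main() {
    let store = Store::new(4);
    store.load().unwrap();
    assert_eq!(store.inner.lock().unwrap().content.len(), 1);
    store.reset();
    assert!(store.inner.lock().unwrap().content.is_empty());
    let _ = store.max_route_hop_count;
}
```

This also lets the hop-count limits live as plain fields on the store rather than in a separate `unlocked_inner`.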

View File

@ -17,9 +17,11 @@ impl RouteSpecStoreContent {
}
}
pub async fn load(routing_table: RoutingTable) -> EyreResult<RouteSpecStoreContent> {
pub async fn load(
table_store: &TableStore,
routing_table: &RoutingTable,
) -> EyreResult<RouteSpecStoreContent> {
// Deserialize what we can
let table_store = routing_table.network_manager().table_store();
let rsstdb = table_store.open("RouteSpecStore", 1).await?;
let mut content: RouteSpecStoreContent =
rsstdb.load_json(0, b"content").await?.unwrap_or_default();
@ -59,10 +61,9 @@ impl RouteSpecStoreContent {
Ok(content)
}
pub async fn save(&self, routing_table: RoutingTable) -> EyreResult<()> {
pub async fn save(&self, table_store: &TableStore) -> EyreResult<()> {
// Save all the fields we care about to the frozen blob in table storage
// This skips saving the #[with(Skip)] secret keys; we save them in the protected store instead
let table_store = routing_table.network_manager().table_store();
let rsstdb = table_store.open("RouteSpecStore", 1).await?;
rsstdb.store_json(0, b"content", self).await?;

View File

@ -15,8 +15,8 @@ pub type EntryCounts = BTreeMap<(RoutingDomain, CryptoKind), usize>;
/// RoutingTable rwlock-internal data
pub struct RoutingTableInner {
/// Extra pointer to unlocked members to simplify access
pub(super) unlocked_inner: Arc<RoutingTableUnlockedInner>,
/// Convenience accessor for the global component registry
pub(super) registry: VeilidComponentRegistry,
/// Routing table buckets that hold references to entries, per crypto kind
pub(super) buckets: BTreeMap<CryptoKind, Vec<Bucket>>,
/// A weak set of all the entries we have in the buckets for faster iteration
@ -44,10 +44,12 @@ pub struct RoutingTableInner {
pub(super) opt_active_watch_keepalive_ts: Option<Timestamp>,
}
impl_veilid_component_registry_accessor!(RoutingTableInner);
impl RoutingTableInner {
pub(super) fn new(unlocked_inner: Arc<RoutingTableUnlockedInner>) -> RoutingTableInner {
pub(super) fn new(registry: VeilidComponentRegistry) -> RoutingTableInner {
RoutingTableInner {
unlocked_inner,
registry,
buckets: BTreeMap::new(),
public_internet_routing_domain: PublicInternetRoutingDomainDetail::default(),
local_network_routing_domain: LocalNetworkRoutingDomainDetail::default(),
@ -458,7 +460,6 @@ impl RoutingTableInner {
// Collect all entries that are 'needs_ping' and have some node info making them reachable somehow
pub(super) fn get_nodes_needing_ping(
&self,
outer_self: RoutingTable,
routing_domain: RoutingDomain,
cur_ts: Timestamp,
) -> Vec<FilteredNodeRef> {
@ -559,7 +560,7 @@ impl RoutingTableInner {
let transform = |_rti: &RoutingTableInner, v: Option<Arc<BucketEntry>>| {
FilteredNodeRef::new(
outer_self.clone(),
self.registry.clone(),
v.unwrap().clone(),
NodeRefFilter::new().with_routing_domain(routing_domain),
Sequencing::default(),
@ -570,10 +571,10 @@ impl RoutingTableInner {
}
#[expect(dead_code)]
pub fn get_all_alive_nodes(&self, outer_self: RoutingTable, cur_ts: Timestamp) -> Vec<NodeRef> {
pub fn get_all_alive_nodes(&self, cur_ts: Timestamp) -> Vec<NodeRef> {
let mut node_refs = Vec::<NodeRef>::with_capacity(self.bucket_entry_count());
self.with_entries(cur_ts, BucketEntryState::Unreliable, |_rti, entry| {
node_refs.push(NodeRef::new(outer_self.clone(), entry));
node_refs.push(NodeRef::new(self.registry(), entry));
Option::<()>::None
});
node_refs
@ -601,6 +602,8 @@ impl RoutingTableInner {
entry: Arc<BucketEntry>,
node_ids: &[TypedKey],
) -> EyreResult<()> {
let routing_table = self.routing_table();
entry.with_mut_inner(|e| {
let mut existing_node_ids = e.node_ids();
@ -631,21 +634,21 @@ impl RoutingTableInner {
if let Some(old_node_id) = e.add_node_id(*node_id)? {
// Remove any old node id for this crypto kind
if VALID_CRYPTO_KINDS.contains(&ck) {
let bucket_index = self.unlocked_inner.calculate_bucket_index(&old_node_id);
let bucket_index = routing_table.calculate_bucket_index(&old_node_id);
let bucket = self.get_bucket_mut(bucket_index);
bucket.remove_entry(&old_node_id.value);
self.unlocked_inner.kick_queue.lock().insert(bucket_index);
routing_table.kick_queue.lock().insert(bucket_index);
}
}
// Bucket the entry appropriately
if VALID_CRYPTO_KINDS.contains(&ck) {
let bucket_index = self.unlocked_inner.calculate_bucket_index(node_id);
let bucket_index = routing_table.calculate_bucket_index(node_id);
let bucket = self.get_bucket_mut(bucket_index);
bucket.add_existing_entry(node_id.value, entry.clone());
// Kick bucket
self.unlocked_inner.kick_queue.lock().insert(bucket_index);
routing_table.kick_queue.lock().insert(bucket_index);
}
}
@ -653,7 +656,7 @@ impl RoutingTableInner {
for node_id in existing_node_ids.iter() {
let ck = node_id.kind;
if VALID_CRYPTO_KINDS.contains(&ck) {
let bucket_index = self.unlocked_inner.calculate_bucket_index(node_id);
let bucket_index = routing_table.calculate_bucket_index(node_id);
let bucket = self.get_bucket_mut(bucket_index);
bucket.remove_entry(&node_id.value);
entry.with_mut_inner(|e| e.remove_node_id(ck));
@ -687,15 +690,16 @@ impl RoutingTableInner {
#[instrument(level = "trace", skip_all, err)]
fn create_node_ref<F>(
&mut self,
outer_self: RoutingTable,
node_ids: &TypedKeyGroup,
update_func: F,
) -> EyreResult<NodeRef>
where
F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner),
{
let routing_table = self.routing_table();
// Ensure someone isn't trying to register this node itself
if self.unlocked_inner.matches_own_node_id(node_ids) {
if routing_table.matches_own_node_id(node_ids) {
bail!("can't register own node");
}
@ -708,7 +712,7 @@ impl RoutingTableInner {
continue;
}
// Find the first in crypto sort order
let bucket_index = self.unlocked_inner.calculate_bucket_index(node_id);
let bucket_index = routing_table.calculate_bucket_index(node_id);
let bucket = self.get_bucket(bucket_index);
if let Some(entry) = bucket.entry(&node_id.value) {
// Best entry is the first one in sorted order that exists from the node id list
@ -730,7 +734,7 @@ impl RoutingTableInner {
}
// Make a noderef to return
let nr = NodeRef::new(outer_self.clone(), best_entry.clone());
let nr = NodeRef::new(self.registry(), best_entry.clone());
// Update the entry with the update func
best_entry.with_mut_inner(|e| update_func(self, e));
@ -741,11 +745,11 @@ impl RoutingTableInner {
// If no entry exists yet, add the first entry to a bucket, possibly evicting a bucket member
let first_node_id = node_ids[0];
let bucket_entry = self.unlocked_inner.calculate_bucket_index(&first_node_id);
let bucket_entry = routing_table.calculate_bucket_index(&first_node_id);
let bucket = self.get_bucket_mut(bucket_entry);
let new_entry = bucket.add_new_entry(first_node_id.value);
self.all_entries.insert(new_entry.clone());
self.unlocked_inner.kick_queue.lock().insert(bucket_entry);
routing_table.kick_queue.lock().insert(bucket_entry);
// Update the other bucket entries with the remaining node ids
if let Err(e) = self.update_bucket_entry_node_ids(new_entry.clone(), node_ids) {
@ -753,7 +757,7 @@ impl RoutingTableInner {
}
// Make node ref to return
let nr = NodeRef::new(outer_self.clone(), new_entry.clone());
let nr = NodeRef::new(self.registry(), new_entry.clone());
// Update the entry with the update func
new_entry.with_mut_inner(|e| update_func(self, e));
@ -766,15 +770,9 @@ impl RoutingTableInner {
/// Resolve an existing routing table entry using any crypto kind and return a reference to it
#[instrument(level = "trace", skip_all, err)]
pub fn lookup_any_node_ref(
&self,
outer_self: RoutingTable,
node_id_key: PublicKey,
) -> EyreResult<Option<NodeRef>> {
pub fn lookup_any_node_ref(&self, node_id_key: PublicKey) -> EyreResult<Option<NodeRef>> {
for ck in VALID_CRYPTO_KINDS {
if let Some(nr) =
self.lookup_node_ref(outer_self.clone(), TypedKey::new(ck, node_id_key))?
{
if let Some(nr) = self.lookup_node_ref(TypedKey::new(ck, node_id_key))? {
return Ok(Some(nr));
}
}
@ -783,35 +781,30 @@ impl RoutingTableInner {
/// Resolve an existing routing table entry and return a reference to it
#[instrument(level = "trace", skip_all, err)]
pub fn lookup_node_ref(
&self,
outer_self: RoutingTable,
node_id: TypedKey,
) -> EyreResult<Option<NodeRef>> {
if self.unlocked_inner.matches_own_node_id(&[node_id]) {
pub fn lookup_node_ref(&self, node_id: TypedKey) -> EyreResult<Option<NodeRef>> {
if self.routing_table().matches_own_node_id(&[node_id]) {
bail!("can't look up own node id in routing table");
}
if !VALID_CRYPTO_KINDS.contains(&node_id.kind) {
bail!("can't look up node id with invalid crypto kind");
}
let bucket_index = self.unlocked_inner.calculate_bucket_index(&node_id);
let bucket_index = self.routing_table().calculate_bucket_index(&node_id);
let bucket = self.get_bucket(bucket_index);
Ok(bucket
.entry(&node_id.value)
.map(|e| NodeRef::new(outer_self, e)))
.map(|e| NodeRef::new(self.registry(), e)))
}
/// Resolve an existing routing table entry and return a filtered reference to it
#[instrument(level = "trace", skip_all, err)]
pub fn lookup_and_filter_noderef(
&self,
outer_self: RoutingTable,
node_id: TypedKey,
routing_domain_set: RoutingDomainSet,
dial_info_filter: DialInfoFilter,
) -> EyreResult<Option<FilteredNodeRef>> {
let nr = self.lookup_node_ref(outer_self, node_id)?;
let nr = self.lookup_node_ref(node_id)?;
Ok(nr.map(|nr| {
nr.custom_filtered(
NodeRefFilter::new()
@ -826,7 +819,7 @@ impl RoutingTableInner {
where
F: FnOnce(Arc<BucketEntry>) -> R,
{
if self.unlocked_inner.matches_own_node_id(&[node_id]) {
if self.routing_table().matches_own_node_id(&[node_id]) {
log_rtab!(error "can't look up own node id in routing table");
return None;
}
@ -834,7 +827,7 @@ impl RoutingTableInner {
log_rtab!(error "can't look up node id with invalid crypto kind");
return None;
}
let bucket_entry = self.unlocked_inner.calculate_bucket_index(&node_id);
let bucket_entry = self.routing_table().calculate_bucket_index(&node_id);
let bucket = self.get_bucket(bucket_entry);
bucket.entry(&node_id.value).map(f)
}
@ -845,7 +838,6 @@ impl RoutingTableInner {
#[instrument(level = "trace", skip_all, err)]
pub fn register_node_with_peer_info(
&mut self,
outer_self: RoutingTable,
peer_info: Arc<PeerInfo>,
allow_invalid: bool,
) -> EyreResult<FilteredNodeRef> {
@ -853,7 +845,7 @@ impl RoutingTableInner {
// if our own node is in the list, then ignore it as we don't add ourselves to our own routing table
if self
.unlocked_inner
.routing_table()
.matches_own_node_id(peer_info.node_ids())
{
bail!("can't register own node id in routing table");
@ -891,10 +883,10 @@ impl RoutingTableInner {
if let Some(relay_peer_info) = peer_info.signed_node_info().relay_peer_info(routing_domain)
{
if !self
.unlocked_inner
.routing_table()
.matches_own_node_id(relay_peer_info.node_ids())
{
self.register_node_with_peer_info(outer_self.clone(), relay_peer_info, false)?;
self.register_node_with_peer_info(relay_peer_info, false)?;
}
}
@ -902,7 +894,7 @@ impl RoutingTableInner {
Arc::unwrap_or_clone(peer_info).destructure();
let mut updated = false;
let mut old_peer_info = None;
let nr = self.create_node_ref(outer_self, &node_ids, |_rti, e| {
let nr = self.create_node_ref(&node_ids, |_rti, e| {
old_peer_info = e.make_peer_info(routing_domain);
updated = e.update_signed_node_info(routing_domain, &signed_node_info);
})?;
@ -922,12 +914,11 @@ impl RoutingTableInner {
#[instrument(level = "trace", skip_all, err)]
pub fn register_node_with_id(
&mut self,
outer_self: RoutingTable,
routing_domain: RoutingDomain,
node_id: TypedKey,
timestamp: Timestamp,
) -> EyreResult<FilteredNodeRef> {
let nr = self.create_node_ref(outer_self, &TypedKeyGroup::from(node_id), |_rti, e| {
let nr = self.create_node_ref(&TypedKeyGroup::from(node_id), |_rti, e| {
//e.make_not_dead(timestamp);
e.touch_last_seen(timestamp);
})?;
@ -1057,7 +1048,7 @@ impl RoutingTableInner {
#[instrument(level = "trace", skip_all)]
pub fn find_fast_non_local_nodes_filtered(
&self,
outer_self: RoutingTable,
registry: VeilidComponentRegistry,
routing_domain: RoutingDomain,
node_count: usize,
mut filters: VecDeque<RoutingTableEntryFilter>,
@ -1089,7 +1080,7 @@ impl RoutingTableInner {
node_count,
filters,
|_rti: &RoutingTableInner, v: Option<Arc<BucketEntry>>| {
NodeRef::new(outer_self.clone(), v.unwrap().clone())
NodeRef::new(registry.clone(), v.unwrap().clone())
},
)
}
@ -1283,10 +1274,12 @@ impl RoutingTableInner {
T: for<'r> FnMut(&'r RoutingTableInner, Option<Arc<BucketEntry>>) -> O,
{
let cur_ts = Timestamp::now();
let routing_table = self.routing_table();
// Get the crypto kind
let crypto_kind = node_id.kind;
let Some(vcrypto) = self.unlocked_inner.crypto().get(crypto_kind) else {
let crypto = self.crypto();
let Some(vcrypto) = crypto.get(crypto_kind) else {
apibail_generic!("invalid crypto kind");
};
@ -1338,12 +1331,12 @@ impl RoutingTableInner {
let a_key = if let Some(a_entry) = a_entry {
a_entry.with_inner(|e| e.node_ids().get(crypto_kind).unwrap())
} else {
self.unlocked_inner.node_id(crypto_kind)
routing_table.node_id(crypto_kind)
};
let b_key = if let Some(b_entry) = b_entry {
b_entry.with_inner(|e| e.node_ids().get(crypto_kind).unwrap())
} else {
self.unlocked_inner.node_id(crypto_kind)
routing_table.node_id(crypto_kind)
};
// distance is the next metric, closer nodes first
@ -1379,7 +1372,8 @@ impl RoutingTableInner {
.collect();
// Sort closest
let sort = make_closest_noderef_sort(self.unlocked_inner.crypto(), node_id);
let crypto = self.crypto();
let sort = make_closest_noderef_sort(&crypto, node_id);
closest_nodes_locked.sort_by(sort);
// Unlock noderefs
@ -1388,10 +1382,10 @@ impl RoutingTableInner {
}
#[instrument(level = "trace", skip_all)]
pub fn make_closest_noderef_sort(
crypto: Crypto,
pub fn make_closest_noderef_sort<'a>(
crypto: &'a Crypto,
node_id: TypedKey,
) -> impl Fn(&LockedNodeRef, &LockedNodeRef) -> core::cmp::Ordering {
) -> impl Fn(&LockedNodeRef, &LockedNodeRef) -> core::cmp::Ordering + 'a {
let kind = node_id.kind;
// Get cryptoversion to check distance with
let vcrypto = crypto.get(kind).unwrap();
@ -1418,9 +1412,9 @@ pub fn make_closest_noderef_sort(
}
pub fn make_closest_node_id_sort(
crypto: Crypto,
crypto: &Crypto,
node_id: TypedKey,
) -> impl Fn(&CryptoKey, &CryptoKey) -> core::cmp::Ordering {
) -> impl Fn(&CryptoKey, &CryptoKey) -> core::cmp::Ordering + '_ {
let kind = node_id.kind;
// Get cryptoversion to check distance with
let vcrypto = crypto.get(kind).unwrap();
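
Because `make_closest_noderef_sort` and `make_closest_node_id_sort` now borrow `Crypto` instead of taking it by value, the returned comparator must be tied to that borrow with `+ 'a` (or `+ '_`). A minimal illustration of the lifetime pattern, using stand-in types rather than the Veilid ones:

```rust
// Stand-in type; the xor "distance" is only a placeholder for the real metric.
struct Crypto {
    salt: u8,
}

fn make_sort<'a>(crypto: &'a Crypto) -> impl Fn(&u8, &u8) -> std::cmp::Ordering + 'a {
    // The closure captures the &Crypto borrow, so the returned type must
    // carry the 'a lifetime; without `+ 'a` this would not compile.
    move |a: &u8, b: &u8| (*a ^ crypto.salt).cmp(&(*b ^ crypto.salt))
}

fn main() {
    let crypto = Crypto { salt: 0x5a };
    let mut keys = vec![3u8, 250, 17];
    let sort = make_sort(&crypto);
    keys.sort_by(|a, b| sort(a, b));
    println!("{:?}", keys);
}
```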

View File

@ -7,7 +7,7 @@ pub trait RoutingDomainEditorCommonTrait {
protocol_type: Option<ProtocolType>,
) -> &mut Self;
fn set_relay_node(&mut self, relay_node: Option<NodeRef>) -> &mut Self;
#[cfg_attr(target_arch = "wasm32", expect(dead_code))]
#[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
fn add_dial_info(&mut self, dial_info: DialInfo, class: DialInfoClass) -> &mut Self;
fn setup_network(
&mut self,
@ -83,7 +83,7 @@ pub(super) enum RoutingDomainChangeCommon {
AddDialInfo {
dial_info_detail: DialInfoDetail,
},
// #[cfg_attr(target_arch = "wasm32", expect(dead_code))]
// #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
// RemoveDialInfoDetail {
// dial_info_detail: DialInfoDetail,
// },

View File

@ -1,4 +1,4 @@
#![cfg_attr(target_arch = "wasm32", expect(dead_code))]
#![cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), expect(dead_code))]
use super::*;
@ -10,15 +10,15 @@ enum RoutingDomainChangeLocalNetwork {
Common(RoutingDomainChangeCommon),
}
pub struct RoutingDomainEditorLocalNetwork {
routing_table: RoutingTable,
pub struct RoutingDomainEditorLocalNetwork<'a> {
routing_table: &'a RoutingTable,
changes: Vec<RoutingDomainChangeLocalNetwork>,
}
impl RoutingDomainEditorLocalNetwork {
pub(in crate::routing_table) fn new(routing_table: RoutingTable) -> Self {
impl<'a> RoutingDomainEditorLocalNetwork<'a> {
pub(in crate::routing_table) fn new(routing_table: &'a RoutingTable) -> Self {
Self {
routing_table: routing_table.clone(),
routing_table,
changes: Vec::new(),
}
}
@ -30,7 +30,7 @@ impl RoutingDomainEditorLocalNetwork {
}
}
impl RoutingDomainEditorCommonTrait for RoutingDomainEditorLocalNetwork {
impl<'a> RoutingDomainEditorCommonTrait for RoutingDomainEditorLocalNetwork<'a> {
#[instrument(level = "debug", skip(self))]
fn clear_dial_info_details(
&mut self,

View File

@ -144,11 +144,7 @@ impl RoutingDomainDetail for LocalNetworkRoutingDomainDetail {
pi
};
if let Err(e) = rti
.unlocked_inner
.event_bus
.post(PeerInfoChangeEvent { peer_info })
{
if let Err(e) = rti.event_bus().post(PeerInfoChangeEvent { peer_info }) {
log_rtab!(debug "Failed to post event: {}", e);
}

View File

@ -143,7 +143,7 @@ impl RoutingDomainDetailCommon {
pub fn network_class(&self) -> Option<NetworkClass> {
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
Some(NetworkClass::WebApp)
} else {
if self.address_types.is_empty() {
@ -312,6 +312,9 @@ impl RoutingDomainDetailCommon {
// Internal functions
fn make_peer_info(&self, rti: &RoutingTableInner) -> PeerInfo {
let crypto = rti.crypto();
let routing_table = rti.routing_table();
let node_info = NodeInfo::new(
self.network_class().unwrap_or(NetworkClass::Invalid),
self.outbound_protocols,
@ -343,8 +346,8 @@ impl RoutingDomainDetailCommon {
let signed_node_info = match relay_info {
Some((relay_ids, relay_sdni)) => SignedNodeInfo::Relayed(
SignedRelayedNodeInfo::make_signatures(
rti.unlocked_inner.crypto(),
rti.unlocked_inner.node_id_typed_key_pairs(),
&crypto,
routing_table.node_id_typed_key_pairs(),
node_info,
relay_ids,
relay_sdni,
@ -353,8 +356,8 @@ impl RoutingDomainDetailCommon {
),
None => SignedNodeInfo::Direct(
SignedDirectNodeInfo::make_signatures(
rti.unlocked_inner.crypto(),
rti.unlocked_inner.node_id_typed_key_pairs(),
&crypto,
routing_table.node_id_typed_key_pairs(),
node_info,
)
.unwrap(),
@ -363,7 +366,7 @@ impl RoutingDomainDetailCommon {
PeerInfo::new(
self.routing_domain,
rti.unlocked_inner.node_ids(),
routing_table.node_ids(),
signed_node_info,
)
}

View File

@ -5,13 +5,13 @@ enum RoutingDomainChangePublicInternet {
Common(RoutingDomainChangeCommon),
}
pub struct RoutingDomainEditorPublicInternet {
routing_table: RoutingTable,
pub struct RoutingDomainEditorPublicInternet<'a> {
routing_table: &'a RoutingTable,
changes: Vec<RoutingDomainChangePublicInternet>,
}
impl RoutingDomainEditorPublicInternet {
pub(in crate::routing_table) fn new(routing_table: RoutingTable) -> Self {
impl<'a> RoutingDomainEditorPublicInternet<'a> {
pub(in crate::routing_table) fn new(routing_table: &'a RoutingTable) -> Self {
Self {
routing_table,
changes: Vec::new(),
@ -41,7 +41,7 @@ impl RoutingDomainEditorPublicInternet {
}
}
impl RoutingDomainEditorCommonTrait for RoutingDomainEditorPublicInternet {
impl<'a> RoutingDomainEditorCommonTrait for RoutingDomainEditorPublicInternet<'a> {
#[instrument(level = "debug", skip(self))]
fn clear_dial_info_details(
&mut self,
@ -263,8 +263,7 @@ impl RoutingDomainEditorCommonTrait for RoutingDomainEditorPublicInternet {
if changed {
// Clear the routespecstore cache if our PublicInternet dial info has changed
let rss = self.routing_table.route_spec_store();
rss.reset_cache();
self.routing_table.route_spec_store().reset_cache();
}
}
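
Both routing-domain editors now borrow the routing table for their lifetime (`&'a RoutingTable`) rather than cloning a handle. A condensed, self-contained sketch of that borrowed-editor shape, with simplified stand-in types and a hypothetical `commit()` in place of the real apply logic:

```rust
// Simplified stand-ins, not the real editor API.
struct RoutingTable {
    dial_info: Vec<String>,
}

struct Editor<'a> {
    routing_table: &'a RoutingTable,
    changes: Vec<String>,
}

impl<'a> Editor<'a> {
    fn new(routing_table: &'a RoutingTable) -> Self {
        Self { routing_table, changes: Vec::new() }
    }

    fn add_dial_info(&mut self, di: &str) -> &mut Self {
        // Queue the change; nothing is applied until commit
        self.changes.push(di.to_string());
        self
    }

    fn commit(&self) {
        // A real editor would apply `self.changes` to the routing table's
        // interior-mutable state; here we just report what would change.
        println!(
            "{} existing entries, {} queued changes",
            self.routing_table.dial_info.len(),
            self.changes.len()
        );
    }
}

fn main() {
    let rt = RoutingTable { dial_info: vec!["udp/0.0.0.0:5150".to_string()] };
    let mut editor = Editor::new(&rt);
    editor.add_dial_info("tcp/0.0.0.0:5150").add_dial_info("ws/0.0.0.0:5150");
    editor.commit();
}
```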

View File

@ -122,11 +122,7 @@ impl RoutingDomainDetail for PublicInternetRoutingDomainDetail {
pi
};
if let Err(e) = rti
.unlocked_inner
.event_bus
.post(PeerInfoChangeEvent { peer_info })
{
if let Err(e) = rti.event_bus().post(PeerInfoChangeEvent { peer_info }) {
log_rtab!(debug "Failed to post event: {}", e);
}
@ -167,11 +163,8 @@ impl RoutingDomainDetail for PublicInternetRoutingDomainDetail {
dif_sort: Option<Arc<DialInfoDetailSort>>,
) -> ContactMethod {
let ip6_prefix_size = rti
.unlocked_inner
.config
.get()
.network
.max_connections_per_ip6_prefix_size as usize;
.config()
.with(|c| c.network.max_connections_per_ip6_prefix_size as usize);
// Get the nodeinfos for convenience
let node_a = peer_a.signed_node_info().node_info();

View File

@ -81,7 +81,7 @@ impl RoutingTable {
}
// If this is our own node id, then we skip it for bootstrap, in case we are a bootstrap node
if self.unlocked_inner.matches_own_node_id(&node_ids) {
if self.matches_own_node_id(&node_ids) {
return Ok(None);
}
@ -255,7 +255,7 @@ impl RoutingTable {
//#[instrument(level = "trace", skip(self), err)]
pub fn bootstrap_with_peer(
self,
&self,
crypto_kinds: Vec<CryptoKind>,
pi: Arc<PeerInfo>,
unord: &FuturesUnordered<SendPinBoxFuture<()>>,
@ -280,19 +280,20 @@ impl RoutingTable {
for crypto_kind in crypto_kinds {
// Bootstrap this crypto kind
let nr = nr.unfiltered();
let routing_table = self.clone();
unord.push(Box::pin(
async move {
let network_manager = nr.network_manager();
let routing_table = nr.routing_table();
// Get what contact method would be used for contacting the bootstrap
let bsdi = match routing_table
.network_manager()
let bsdi = match network_manager
.get_node_contact_method(nr.default_filtered())
{
Ok(NodeContactMethod::Direct(v)) => v,
Ok(v) => {
log_rtab!(debug "invalid contact method for bootstrap, ignoring peer: {:?}", v);
// let _ = routing_table
// .network_manager()
// let _ =
// network_manager
// .get_node_contact_method(nr.clone());
return;
}
@ -312,7 +313,7 @@ impl RoutingTable {
log_rtab!(debug "bootstrap server is not responding for dialinfo: {}", bsdi);
// Try a different dialinfo next time
routing_table.network_manager().address_filter().set_dial_info_failed(bsdi);
network_manager.address_filter().set_dial_info_failed(bsdi);
} else {
// otherwise this bootstrap is valid, let's ask it to find ourselves now
routing_table.reverse_find_node(crypto_kind, nr, true, vec![]).await
@ -325,7 +326,7 @@ impl RoutingTable {
#[instrument(level = "trace", skip(self), err)]
pub async fn bootstrap_with_peer_list(
self,
&self,
peers: Vec<Arc<PeerInfo>>,
stop_token: StopToken,
) -> EyreResult<()> {
@ -339,8 +340,7 @@ impl RoutingTable {
// Run all bootstrap operations concurrently
let mut unord = FuturesUnordered::<SendPinBoxFuture<()>>::new();
for peer in peers {
self.clone()
.bootstrap_with_peer(crypto_kinds.clone(), peer, &unord);
self.bootstrap_with_peer(crypto_kinds.clone(), peer, &unord);
}
// Wait for all bootstrap operations to complete before we complete the singlefuture
@ -364,10 +364,15 @@ impl RoutingTable {
}
#[instrument(level = "trace", skip(self), err)]
pub async fn bootstrap_task_routine(self, stop_token: StopToken) -> EyreResult<()> {
pub async fn bootstrap_task_routine(
&self,
stop_token: StopToken,
_last_ts: Timestamp,
_cur_ts: Timestamp,
) -> EyreResult<()> {
let bootstrap = self
.unlocked_inner
.with_config(|c| c.network.routing_table.bootstrap.clone());
.config()
.with(|c| c.network.routing_table.bootstrap.clone());
// Don't bother if bootstraps aren't configured
if bootstrap.is_empty() {
@ -445,8 +450,6 @@ impl RoutingTable {
peers
};
self.clone()
.bootstrap_with_peer_list(peers, stop_token)
.await
self.bootstrap_with_peer_list(peers, stop_token).await
}
}

Some files were not shown because too many files have changed in this diff