[skip ci] more refactor

This commit is contained in:
Christien Rioux 2025-01-31 20:19:24 -05:00
parent 6319db2b06
commit 7edbf28e36
25 changed files with 346 additions and 423 deletions

View File

@ -100,7 +100,7 @@ impl fmt::Debug for CryptoInner {
/// Crypto factory implementation /// Crypto factory implementation
pub struct Crypto { pub struct Crypto {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<Mutex<CryptoInner>>, inner: Mutex<CryptoInner>,
#[cfg(feature = "enable-crypto-vld0")] #[cfg(feature = "enable-crypto-vld0")]
crypto_vld0: Arc<dyn CryptoSystem + Send + Sync>, crypto_vld0: Arc<dyn CryptoSystem + Send + Sync>,
#[cfg(feature = "enable-crypto-none")] #[cfg(feature = "enable-crypto-none")]
@ -131,7 +131,7 @@ impl Crypto {
pub fn new(registry: VeilidComponentRegistry) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self {
registry: registry.clone(), registry: registry.clone(),
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Mutex::new(Self::new_inner()),
#[cfg(feature = "enable-crypto-vld0")] #[cfg(feature = "enable-crypto-vld0")]
crypto_vld0: Arc::new(vld0::CryptoSystemVLD0::new(registry.clone())), crypto_vld0: Arc::new(vld0::CryptoSystemVLD0::new(registry.clone())),
#[cfg(feature = "enable-crypto-none")] #[cfg(feature = "enable-crypto-none")]

View File

@ -10,10 +10,10 @@ impl fmt::Debug for BlockStoreInner {
} }
} }
#[derive(Clone, Debug)] #[derive(Debug)]
pub struct BlockStore { pub struct BlockStore {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<Mutex<BlockStoreInner>>, inner: Mutex<BlockStoreInner>,
} }
impl_veilid_component!(BlockStore); impl_veilid_component!(BlockStore);
@ -25,7 +25,7 @@ impl BlockStore {
pub fn new(registry: VeilidComponentRegistry) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self {
registry, registry,
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Mutex::new(Self::new_inner()),
} }
} }

View File

@ -15,7 +15,7 @@ impl fmt::Debug for ProtectedStoreInner {
#[derive(Debug)] #[derive(Debug)]
pub struct ProtectedStore { pub struct ProtectedStore {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<Mutex<ProtectedStoreInner>>, inner: Mutex<ProtectedStoreInner>,
} }
impl_veilid_component!(ProtectedStore); impl_veilid_component!(ProtectedStore);
@ -30,7 +30,7 @@ impl ProtectedStore {
pub fn new(registry: VeilidComponentRegistry) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self {
registry, registry,
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Mutex::new(Self::new_inner()),
} }
} }

View File

@ -13,7 +13,7 @@ impl fmt::Debug for BlockStoreInner {
#[derive(Debug)] #[derive(Debug)]
pub struct BlockStore { pub struct BlockStore {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<Mutex<BlockStoreInner>>, inner: Mutex<BlockStoreInner>,
} }
impl_veilid_component!(BlockStore); impl_veilid_component!(BlockStore);
@ -25,7 +25,7 @@ impl BlockStore {
pub fn new(registry: VeilidComponentRegistry) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self {
registry, registry,
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Mutex::new(Self::new_inner()),
} }
} }

View File

@ -32,7 +32,10 @@ struct AddressFilterInner {
dial_info_failures: BTreeMap<DialInfo, Timestamp>, dial_info_failures: BTreeMap<DialInfo, Timestamp>,
} }
struct AddressFilterUnlockedInner { #[derive(Debug)]
pub(crate) struct AddressFilter {
registry: VeilidComponentRegistry,
inner: Mutex<AddressFilterInner>,
max_connections_per_ip4: usize, max_connections_per_ip4: usize,
max_connections_per_ip6_prefix: usize, max_connections_per_ip6_prefix: usize,
max_connections_per_ip6_prefix_size: usize, max_connections_per_ip6_prefix_size: usize,
@ -41,38 +44,6 @@ struct AddressFilterUnlockedInner {
dial_info_failure_duration_min: usize, dial_info_failure_duration_min: usize,
} }
impl fmt::Debug for AddressFilterUnlockedInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AddressFilterUnlockedInner")
.field("max_connections_per_ip4", &self.max_connections_per_ip4)
.field(
"max_connections_per_ip6_prefix",
&self.max_connections_per_ip6_prefix,
)
.field(
"max_connections_per_ip6_prefix_size",
&self.max_connections_per_ip6_prefix_size,
)
.field(
"max_connection_frequency_per_min",
&self.max_connection_frequency_per_min,
)
.field("punishment_duration_min", &self.punishment_duration_min)
.field(
"dial_info_failure_duration_min",
&self.dial_info_failure_duration_min,
)
.finish()
}
}
#[derive(Clone, Debug)]
pub(crate) struct AddressFilter {
registry: VeilidComponentRegistry,
unlocked_inner: Arc<AddressFilterUnlockedInner>,
inner: Arc<Mutex<AddressFilterInner>>,
}
impl_veilid_component_registry_accessor!(AddressFilter); impl_veilid_component_registry_accessor!(AddressFilter);
impl AddressFilter { impl AddressFilter {
@ -81,17 +52,7 @@ impl AddressFilter {
let c = config.get(); let c = config.get();
Self { Self {
registry, registry,
unlocked_inner: Arc::new(AddressFilterUnlockedInner { inner: Mutex::new(AddressFilterInner {
max_connections_per_ip4: c.network.max_connections_per_ip4 as usize,
max_connections_per_ip6_prefix: c.network.max_connections_per_ip6_prefix as usize,
max_connections_per_ip6_prefix_size: c.network.max_connections_per_ip6_prefix_size
as usize,
max_connection_frequency_per_min: c.network.max_connection_frequency_per_min
as usize,
punishment_duration_min: PUNISHMENT_DURATION_MIN,
dial_info_failure_duration_min: DIAL_INFO_FAILURE_DURATION_MIN,
}),
inner: Arc::new(Mutex::new(AddressFilterInner {
conn_count_by_ip4: BTreeMap::new(), conn_count_by_ip4: BTreeMap::new(),
conn_count_by_ip6_prefix: BTreeMap::new(), conn_count_by_ip6_prefix: BTreeMap::new(),
conn_timestamps_by_ip4: BTreeMap::new(), conn_timestamps_by_ip4: BTreeMap::new(),
@ -100,7 +61,14 @@ impl AddressFilter {
punishments_by_ip6_prefix: BTreeMap::new(), punishments_by_ip6_prefix: BTreeMap::new(),
punishments_by_node_id: BTreeMap::new(), punishments_by_node_id: BTreeMap::new(),
dial_info_failures: BTreeMap::new(), dial_info_failures: BTreeMap::new(),
})), }),
max_connections_per_ip4: c.network.max_connections_per_ip4 as usize,
max_connections_per_ip6_prefix: c.network.max_connections_per_ip6_prefix as usize,
max_connections_per_ip6_prefix_size: c.network.max_connections_per_ip6_prefix_size
as usize,
max_connection_frequency_per_min: c.network.max_connection_frequency_per_min as usize,
punishment_duration_min: PUNISHMENT_DURATION_MIN,
dial_info_failure_duration_min: DIAL_INFO_FAILURE_DURATION_MIN,
} }
} }
@ -154,7 +122,7 @@ impl AddressFilter {
for (key, value) in &mut inner.punishments_by_ip4 { for (key, value) in &mut inner.punishments_by_ip4 {
// Drop punishments older than the punishment duration // Drop punishments older than the punishment duration
if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64()) if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64())
> self.unlocked_inner.punishment_duration_min as u64 * 60_000_000u64 > self.punishment_duration_min as u64 * 60_000_000u64
{ {
dead_keys.push(*key); dead_keys.push(*key);
} }
@ -170,7 +138,7 @@ impl AddressFilter {
for (key, value) in &mut inner.punishments_by_ip6_prefix { for (key, value) in &mut inner.punishments_by_ip6_prefix {
// Drop punishments older than the punishment duration // Drop punishments older than the punishment duration
if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64()) if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64())
> self.unlocked_inner.punishment_duration_min as u64 * 60_000_000u64 > self.punishment_duration_min as u64 * 60_000_000u64
{ {
dead_keys.push(*key); dead_keys.push(*key);
} }
@ -186,7 +154,7 @@ impl AddressFilter {
for (key, value) in &mut inner.punishments_by_node_id { for (key, value) in &mut inner.punishments_by_node_id {
// Drop punishments older than the punishment duration // Drop punishments older than the punishment duration
if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64()) if cur_ts.as_u64().saturating_sub(value.timestamp.as_u64())
> self.unlocked_inner.punishment_duration_min as u64 * 60_000_000u64 > self.punishment_duration_min as u64 * 60_000_000u64
{ {
dead_keys.push(*key); dead_keys.push(*key);
} }
@ -206,7 +174,7 @@ impl AddressFilter {
for (key, value) in &mut inner.dial_info_failures { for (key, value) in &mut inner.dial_info_failures {
// Drop failures older than the failure duration // Drop failures older than the failure duration
if cur_ts.as_u64().saturating_sub(value.as_u64()) if cur_ts.as_u64().saturating_sub(value.as_u64())
> self.unlocked_inner.dial_info_failure_duration_min as u64 * 60_000_000u64 > self.dial_info_failure_duration_min as u64 * 60_000_000u64
{ {
dead_keys.push(key.clone()); dead_keys.push(key.clone());
} }
@ -244,10 +212,7 @@ impl AddressFilter {
pub fn is_ip_addr_punished(&self, addr: IpAddr) -> bool { pub fn is_ip_addr_punished(&self, addr: IpAddr) -> bool {
let inner = self.inner.lock(); let inner = self.inner.lock();
let ipblock = ip_to_ipblock( let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
self.is_ip_addr_punished_inner(&inner, ipblock) self.is_ip_addr_punished_inner(&inner, ipblock)
} }
@ -276,8 +241,9 @@ impl AddressFilter {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.punishments_by_ip4.clear(); inner.punishments_by_ip4.clear();
inner.punishments_by_ip6_prefix.clear(); inner.punishments_by_ip6_prefix.clear();
self.unlocked_inner.routing_table.clear_punishments();
inner.punishments_by_node_id.clear(); inner.punishments_by_node_id.clear();
self.routing_table().clear_punishments();
} }
pub fn punish_ip_addr(&self, addr: IpAddr, reason: PunishmentReason) { pub fn punish_ip_addr(&self, addr: IpAddr, reason: PunishmentReason) {
@ -285,10 +251,7 @@ impl AddressFilter {
let timestamp = Timestamp::now(); let timestamp = Timestamp::now();
let punishment = Punishment { reason, timestamp }; let punishment = Punishment { reason, timestamp };
let ipblock = ip_to_ipblock( let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
match ipblock { match ipblock {
@ -318,7 +281,7 @@ impl AddressFilter {
} }
pub fn punish_node_id(&self, node_id: TypedKey, reason: PunishmentReason) { pub fn punish_node_id(&self, node_id: TypedKey, reason: PunishmentReason) {
if let Ok(Some(nr)) = self.unlocked_inner.routing_table.lookup_node_ref(node_id) { if let Ok(Some(nr)) = self.routing_table().lookup_node_ref(node_id) {
// make the entry dead if it's punished // make the entry dead if it's punished
nr.operate_mut(|_rti, e| e.set_punished(Some(reason))); nr.operate_mut(|_rti, e| e.set_punished(Some(reason)));
} }
@ -357,10 +320,7 @@ impl AddressFilter {
pub fn add_connection(&self, addr: IpAddr) -> Result<(), AddressFilterError> { pub fn add_connection(&self, addr: IpAddr) -> Result<(), AddressFilterError> {
let inner = &mut *self.inner.lock(); let inner = &mut *self.inner.lock();
let ipblock = ip_to_ipblock( let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
if self.is_ip_addr_punished_inner(inner, ipblock) { if self.is_ip_addr_punished_inner(inner, ipblock) {
return Err(AddressFilterError::Punished); return Err(AddressFilterError::Punished);
} }
@ -372,8 +332,8 @@ impl AddressFilter {
IpAddr::V4(v4) => { IpAddr::V4(v4) => {
// See if we have too many connections from this ip block // See if we have too many connections from this ip block
let cnt = inner.conn_count_by_ip4.entry(v4).or_default(); let cnt = inner.conn_count_by_ip4.entry(v4).or_default();
assert!(*cnt <= self.unlocked_inner.max_connections_per_ip4); assert!(*cnt <= self.max_connections_per_ip4);
if *cnt == self.unlocked_inner.max_connections_per_ip4 { if *cnt == self.max_connections_per_ip4 {
warn!("Address filter count exceeded: {:?}", v4); warn!("Address filter count exceeded: {:?}", v4);
return Err(AddressFilterError::CountExceeded); return Err(AddressFilterError::CountExceeded);
} }
@ -383,8 +343,8 @@ impl AddressFilter {
// keep timestamps that are less than a minute away // keep timestamps that are less than a minute away
ts.saturating_sub(*v) < TimestampDuration::new(60_000_000u64) ts.saturating_sub(*v) < TimestampDuration::new(60_000_000u64)
}); });
assert!(tstamps.len() <= self.unlocked_inner.max_connection_frequency_per_min); assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.unlocked_inner.max_connection_frequency_per_min { if tstamps.len() == self.max_connection_frequency_per_min {
warn!("Address filter rate exceeded: {:?}", v4); warn!("Address filter rate exceeded: {:?}", v4);
return Err(AddressFilterError::RateExceeded); return Err(AddressFilterError::RateExceeded);
} }
@ -396,15 +356,15 @@ impl AddressFilter {
IpAddr::V6(v6) => { IpAddr::V6(v6) => {
// See if we have too many connections from this ip block // See if we have too many connections from this ip block
let cnt = inner.conn_count_by_ip6_prefix.entry(v6).or_default(); let cnt = inner.conn_count_by_ip6_prefix.entry(v6).or_default();
assert!(*cnt <= self.unlocked_inner.max_connections_per_ip6_prefix); assert!(*cnt <= self.max_connections_per_ip6_prefix);
if *cnt == self.unlocked_inner.max_connections_per_ip6_prefix { if *cnt == self.max_connections_per_ip6_prefix {
warn!("Address filter count exceeded: {:?}", v6); warn!("Address filter count exceeded: {:?}", v6);
return Err(AddressFilterError::CountExceeded); return Err(AddressFilterError::CountExceeded);
} }
// See if this ip block has connected too frequently // See if this ip block has connected too frequently
let tstamps = inner.conn_timestamps_by_ip6_prefix.entry(v6).or_default(); let tstamps = inner.conn_timestamps_by_ip6_prefix.entry(v6).or_default();
assert!(tstamps.len() <= self.unlocked_inner.max_connection_frequency_per_min); assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.unlocked_inner.max_connection_frequency_per_min { if tstamps.len() == self.max_connection_frequency_per_min {
warn!("Address filter rate exceeded: {:?}", v6); warn!("Address filter rate exceeded: {:?}", v6);
return Err(AddressFilterError::RateExceeded); return Err(AddressFilterError::RateExceeded);
} }
@ -420,10 +380,7 @@ impl AddressFilter {
pub fn remove_connection(&mut self, addr: IpAddr) -> Result<(), AddressNotInTableError> { pub fn remove_connection(&mut self, addr: IpAddr) -> Result<(), AddressNotInTableError> {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
let ipblock = ip_to_ipblock( let ipblock = ip_to_ipblock(self.max_connections_per_ip6_prefix_size, addr);
self.unlocked_inner.max_connections_per_ip6_prefix_size,
addr,
);
let ts = Timestamp::now(); let ts = Timestamp::now();
self.purge_old_timestamps(&mut inner, ts); self.purge_old_timestamps(&mut inner, ts);

View File

@ -61,7 +61,6 @@ struct ConnectionManagerInner {
} }
struct ConnectionManagerArc { struct ConnectionManagerArc {
network_manager: NetworkManager,
connection_initial_timeout_ms: u32, connection_initial_timeout_ms: u32,
connection_inactivity_timeout_ms: u32, connection_inactivity_timeout_ms: u32,
connection_table: ConnectionTable, connection_table: ConnectionTable,
@ -79,9 +78,12 @@ impl core::fmt::Debug for ConnectionManagerArc {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ConnectionManager { pub struct ConnectionManager {
registry: VeilidComponentRegistry,
arc: Arc<ConnectionManagerArc>, arc: Arc<ConnectionManagerArc>,
} }
impl_veilid_component_registry_accessor!(ConnectionManager);
impl ConnectionManager { impl ConnectionManager {
fn new_inner( fn new_inner(
stop_source: StopSource, stop_source: StopSource,
@ -98,8 +100,8 @@ impl ConnectionManager {
reconnection_processor, reconnection_processor,
} }
} }
fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc { fn new_arc(registry: VeilidComponentRegistry) -> ConnectionManagerArc {
let config = network_manager.config(); let config = registry.config();
let (connection_initial_timeout_ms, connection_inactivity_timeout_ms) = { let (connection_initial_timeout_ms, connection_inactivity_timeout_ms) = {
let c = config.get(); let c = config.get();
( (
@ -107,28 +109,23 @@ impl ConnectionManager {
c.network.connection_inactivity_timeout_ms, c.network.connection_inactivity_timeout_ms,
) )
}; };
let address_filter = network_manager.address_filter();
ConnectionManagerArc { ConnectionManagerArc {
network_manager,
connection_initial_timeout_ms, connection_initial_timeout_ms,
connection_inactivity_timeout_ms, connection_inactivity_timeout_ms,
connection_table: ConnectionTable::new(config, address_filter), connection_table: ConnectionTable::new(registry),
address_lock_table: AsyncTagLockTable::new(), address_lock_table: AsyncTagLockTable::new(),
startup_lock: StartupLock::new(), startup_lock: StartupLock::new(),
inner: Mutex::new(None), inner: Mutex::new(None),
} }
} }
pub fn new(network_manager: NetworkManager) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
Self { Self {
arc: Arc::new(Self::new_arc(network_manager)), arc: Arc::new(Self::new_arc(registry.clone())),
registry,
} }
} }
pub fn network_manager(&self) -> NetworkManager {
self.arc.network_manager.clone()
}
pub fn connection_inactivity_timeout_ms(&self) -> u32 { pub fn connection_inactivity_timeout_ms(&self) -> u32 {
self.arc.connection_inactivity_timeout_ms self.arc.connection_inactivity_timeout_ms
} }
@ -452,13 +449,15 @@ impl ConnectionManager {
// Attempt new connection // Attempt new connection
let mut retry_count = NEW_CONNECTION_RETRY_COUNT; let mut retry_count = NEW_CONNECTION_RETRY_COUNT;
let network_manager = self.network_manager();
let prot_conn = network_result_try!(loop { let prot_conn = network_result_try!(loop {
let address_filter = network_manager.address_filter();
let result_net_res = ProtocolNetworkConnection::connect( let result_net_res = ProtocolNetworkConnection::connect(
preferred_local_address, preferred_local_address,
&dial_info, &dial_info,
self.arc.connection_initial_timeout_ms, self.arc.connection_initial_timeout_ms,
self.network_manager().address_filter(), &*address_filter,
) )
.await; .await;
match result_net_res { match result_net_res {

View File

@ -44,17 +44,20 @@ struct ConnectionTableInner {
protocol_index_by_id: BTreeMap<NetworkConnectionId, usize>, protocol_index_by_id: BTreeMap<NetworkConnectionId, usize>,
id_by_flow: BTreeMap<Flow, NetworkConnectionId>, id_by_flow: BTreeMap<Flow, NetworkConnectionId>,
ids_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnectionId>>, ids_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnectionId>>,
address_filter: AddressFilter,
priority_flows: Vec<LruCache<Flow, ()>>, priority_flows: Vec<LruCache<Flow, ()>>,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectionTable { pub struct ConnectionTable {
inner: Arc<Mutex<ConnectionTableInner>>, registry: VeilidComponentRegistry,
inner: Mutex<ConnectionTableInner>,
} }
impl_veilid_component_registry_accessor!(ConnectionTable);
impl ConnectionTable { impl ConnectionTable {
pub fn new(config: VeilidConfig, address_filter: AddressFilter) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
let config = registry.config();
let max_connections = { let max_connections = {
let c = config.get(); let c = config.get();
vec![ vec![
@ -64,7 +67,8 @@ impl ConnectionTable {
] ]
}; };
Self { Self {
inner: Arc::new(Mutex::new(ConnectionTableInner { registry,
inner: Mutex::new(ConnectionTableInner {
conn_by_id: max_connections conn_by_id: max_connections
.iter() .iter()
.map(|_| LruCache::new_unbounded()) .map(|_| LruCache::new_unbounded())
@ -72,13 +76,12 @@ impl ConnectionTable {
protocol_index_by_id: BTreeMap::new(), protocol_index_by_id: BTreeMap::new(),
id_by_flow: BTreeMap::new(), id_by_flow: BTreeMap::new(),
ids_by_remote: BTreeMap::new(), ids_by_remote: BTreeMap::new(),
address_filter,
priority_flows: max_connections priority_flows: max_connections
.iter() .iter()
.map(|x| LruCache::new(x * PRIORITY_FLOW_PERCENTAGE / 100)) .map(|x| LruCache::new(x * PRIORITY_FLOW_PERCENTAGE / 100))
.collect(), .collect(),
max_connections, max_connections,
})), }),
} }
} }
@ -168,6 +171,7 @@ impl ConnectionTable {
/// when it is getting full while adding a new connection. /// when it is getting full while adding a new connection.
/// Factored out into its own function for clarity. /// Factored out into its own function for clarity.
fn lru_out_connection_inner( fn lru_out_connection_inner(
&self,
inner: &mut ConnectionTableInner, inner: &mut ConnectionTableInner,
protocol_index: usize, protocol_index: usize,
) -> Result<Option<NetworkConnection>, ()> { ) -> Result<Option<NetworkConnection>, ()> {
@ -198,7 +202,7 @@ impl ConnectionTable {
lruk lruk
}; };
let dead_conn = Self::remove_connection_records(inner, dead_k); let dead_conn = self.remove_connection_records_inner(inner, dead_k);
Ok(Some(dead_conn)) Ok(Some(dead_conn))
} }
@ -235,20 +239,20 @@ impl ConnectionTable {
// Filter by ip for connection limits // Filter by ip for connection limits
let ip_addr = flow.remote_address().ip_addr(); let ip_addr = flow.remote_address().ip_addr();
match inner.address_filter.add_connection(ip_addr) { if let Err(e) = self
Ok(()) => {} .network_manager()
Err(e) => { .with_address_filter_mut(|af| af.add_connection(ip_addr))
// Return the connection in the error to be disposed of {
return Err(ConnectionTableAddError::address_filter( // Return the connection in the error to be disposed of
network_connection, return Err(ConnectionTableAddError::address_filter(
e, network_connection,
)); e,
} ));
}; }
// if we have reached the maximum number of connections per protocol type // if we have reached the maximum number of connections per protocol type
// then drop the least recently used connection that is not protected or referenced // then drop the least recently used connection that is not protected or referenced
let out_conn = match Self::lru_out_connection_inner(&mut inner, protocol_index) { let out_conn = match self.lru_out_connection_inner(&mut inner, protocol_index) {
Ok(v) => v, Ok(v) => v,
Err(()) => { Err(()) => {
return Err(ConnectionTableAddError::table_full(network_connection)); return Err(ConnectionTableAddError::table_full(network_connection));
@ -437,7 +441,8 @@ impl ConnectionTable {
} }
#[instrument(level = "trace", skip(inner), ret)] #[instrument(level = "trace", skip(inner), ret)]
fn remove_connection_records( fn remove_connection_records_inner(
&self,
inner: &mut ConnectionTableInner, inner: &mut ConnectionTableInner,
id: NetworkConnectionId, id: NetworkConnectionId,
) -> NetworkConnection { ) -> NetworkConnection {
@ -462,9 +467,8 @@ impl ConnectionTable {
} }
// address_filter // address_filter
let ip_addr = remote.socket_addr().ip(); let ip_addr = remote.socket_addr().ip();
inner self.network_manager()
.address_filter .with_address_filter_mut(|af| af.remove_connection(ip_addr))
.remove_connection(ip_addr)
.expect("Inconsistency in connection table"); .expect("Inconsistency in connection table");
conn conn
} }
@ -477,7 +481,7 @@ impl ConnectionTable {
if !inner.conn_by_id[protocol_index].contains_key(&id) { if !inner.conn_by_id[protocol_index].contains_key(&id) {
return None; return None;
} }
let conn = Self::remove_connection_records(&mut inner, id); let conn = self.remove_connection_records_inner(&mut inner, id);
Some(conn) Some(conn)
} }

View File

@ -42,7 +42,6 @@ use native::*;
pub use native::{MAX_CAPABILITIES, PUBLIC_INTERNET_CAPABILITIES}; pub use native::{MAX_CAPABILITIES, PUBLIC_INTERNET_CAPABILITIES};
use routing_table::*; use routing_table::*;
use rpc_processor::*; use rpc_processor::*;
use storage_manager::*;
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))] #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
use wasm::*; use wasm::*;
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))] #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
@ -65,7 +64,6 @@ pub const HOLE_PUNCH_DELAY_MS: u32 = 100;
struct NetworkComponents { struct NetworkComponents {
net: Network, net: Network,
connection_manager: ConnectionManager, connection_manager: ConnectionManager,
rpc_processor: RPCProcessor,
receipt_manager: ReceiptManager, receipt_manager: ReceiptManager,
} }
@ -134,14 +132,17 @@ struct NetworkManagerInner {
socket_address_change_subscription: Option<EventBusSubscription>, socket_address_change_subscription: Option<EventBusSubscription>,
} }
#[derive(Debug)]
pub(crate) struct NetworkManager { pub(crate) struct NetworkManager {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<Mutex<NetworkManagerInner>>, inner: Mutex<NetworkManagerInner>,
// Address filter
address_filter: RwLock<AddressFilter>,
// Accessors // Accessors
address_filter: RwLock<Option<AddressFilter>>,
components: RwLock<Option<NetworkComponents>>, components: RwLock<Option<NetworkComponents>>,
update_callback: RwLock<Option<UpdateCallback>>,
// Background processes // Background processes
rolling_transfers_task: TickTask<EyreReport>, rolling_transfers_task: TickTask<EyreReport>,
address_filter_task: TickTask<EyreReport>, address_filter_task: TickTask<EyreReport>,
@ -160,7 +161,6 @@ impl fmt::Debug for NetworkManager {
.field("inner", &self.inner) .field("inner", &self.inner)
.field("address_filter", &self.address_filter) .field("address_filter", &self.address_filter)
// .field("components", &self.components) // .field("components", &self.components)
// .field("update_callback", &self.update_callback)
// .field("rolling_transfers_task", &self.rolling_transfers_task) // .field("rolling_transfers_task", &self.rolling_transfers_task)
// .field("address_filter_task", &self.address_filter_task) // .field("address_filter_task", &self.address_filter_task)
.field("network_key", &self.network_key) .field("network_key", &self.network_key)
@ -212,12 +212,14 @@ impl NetworkManager {
network_key network_key
}; };
let inner = Self::new_inner();
let address_filter = AddressFilter::new(registry.clone());
let this = Self { let this = Self {
registry, registry,
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Mutex::new(inner),
address_filter: RwLock::new(None), address_filter: RwLock::new(address_filter),
components: RwLock::new(None), components: RwLock::new(None),
update_callback: RwLock::new(None),
rolling_transfers_task: TickTask::new( rolling_transfers_task: TickTask::new(
"rolling_transfers_task", "rolling_transfers_task",
ROLLING_TRANSFERS_INTERVAL_SECS, ROLLING_TRANSFERS_INTERVAL_SECS,
@ -235,8 +237,24 @@ impl NetworkManager {
this this
} }
pub fn address_filter(&self) -> AddressFilter { pub fn with_address_filter_mut<F, R>(&self, callback: F) -> R
self.address_filter.read().as_ref().unwrap().clone() where
F: FnOnce(&mut AddressFilter) -> R,
{
let mut af = self.address_filter.write();
callback(&mut *af)
}
pub fn with_address_filter<F, R>(&self, callback: F) -> R
where
F: FnOnce(&AddressFilter) -> R,
{
let af = self.address_filter.read();
callback(&*af)
}
pub fn address_filter<'a>(&self) -> RwLockReadGuard<'a, AddressFilter> {
self.address_filter.read()
} }
fn net(&self) -> Network { fn net(&self) -> Network {
@ -277,16 +295,20 @@ impl NetworkManager {
Ok(()) Ok(())
} }
async fn post_init_async(&self) -> EyreResult<()> {} async fn post_init_async(&self) -> EyreResult<()> {
Ok(())
}
async fn pre_terminate_async(&self) {} async fn pre_terminate_async(&self) {}
#[instrument(level = "debug", skip_all)] #[instrument(level = "debug", skip_all)]
async fn terminate_async(&self) {} async fn terminate_async(&self) {
*self.address_filter.write() = None;
}
#[instrument(level = "debug", skip_all, err)] #[instrument(level = "debug", skip_all, err)]
pub async fn internal_startup(&self) -> EyreResult<StartupDisposition> { pub async fn internal_startup(&self) -> EyreResult<StartupDisposition> {
if self.unlocked_inner.components.read().is_some() { if self.components.read().is_some() {
log_net!(debug "NetworkManager::internal_startup already started"); log_net!(debug "NetworkManager::internal_startup already started");
return Ok(StartupDisposition::Success); return Ok(StartupDisposition::Success);
} }
@ -295,26 +317,12 @@ impl NetworkManager {
self.address_filter().restart(); self.address_filter().restart();
// Create network components // Create network components
let connection_manager = ConnectionManager::new(self.clone()); let connection_manager = ConnectionManager::new(self.registry());
let net = Network::new( let net = Network::new(self.registry(), connection_manager.clone());
self.clone(),
self.routing_table(),
connection_manager.clone(),
);
let rpc_processor = RPCProcessor::new(
self.clone(),
self.unlocked_inner
.update_callback
.read()
.as_ref()
.unwrap()
.clone(),
);
let receipt_manager = ReceiptManager::new(); let receipt_manager = ReceiptManager::new();
*self.unlocked_inner.components.write() = Some(NetworkComponents { *self.components.write() = Some(NetworkComponents {
net: net.clone(), net: net.clone(),
connection_manager: connection_manager.clone(), connection_manager: connection_manager.clone(),
rpc_processor: rpc_processor.clone(),
receipt_manager: receipt_manager.clone(), receipt_manager: receipt_manager.clone(),
}); });
@ -327,7 +335,7 @@ impl NetworkManager {
} }
} }
let (detect_address_changes, ip6_prefix_size) = self.with_config(|c| { let (detect_address_changes, ip6_prefix_size) = self.config().with(|c| {
( (
c.network.detect_address_changes, c.network.detect_address_changes,
c.network.max_connections_per_ip6_prefix_size as usize, c.network.max_connections_per_ip6_prefix_size as usize,
@ -353,7 +361,6 @@ impl NetworkManager {
inner.socket_address_change_subscription = Some(socket_address_change_subscription); inner.socket_address_change_subscription = Some(socket_address_change_subscription);
} }
rpc_processor.startup().await?;
receipt_manager.startup().await?; receipt_manager.startup().await?;
log_net!("NetworkManager::internal_startup end"); log_net!("NetworkManager::internal_startup end");
@ -363,7 +370,7 @@ impl NetworkManager {
#[instrument(level = "debug", skip_all, err)] #[instrument(level = "debug", skip_all, err)]
pub async fn startup(&self) -> EyreResult<StartupDisposition> { pub async fn startup(&self) -> EyreResult<StartupDisposition> {
let guard = self.unlocked_inner.startup_lock.startup()?; let guard = self.startup_lock.startup()?;
match self.internal_startup().await { match self.internal_startup().await {
Ok(StartupDisposition::Success) => { Ok(StartupDisposition::Success) => {
@ -406,15 +413,14 @@ impl NetworkManager {
log_net!(debug "shutting down network components"); log_net!(debug "shutting down network components");
{ {
let components = self.unlocked_inner.components.read().clone(); let components = self.components.read().clone();
if let Some(components) = components { if let Some(components) = components {
components.net.shutdown().await; components.net.shutdown().await;
components.rpc_processor.shutdown().await;
components.receipt_manager.shutdown().await; components.receipt_manager.shutdown().await;
components.connection_manager.shutdown().await; components.connection_manager.shutdown().await;
} }
} }
*self.unlocked_inner.components.write() = None; *self.components.write() = None;
// reset the state // reset the state
log_net!(debug "resetting network manager state"); log_net!(debug "resetting network manager state");
@ -427,7 +433,7 @@ impl NetworkManager {
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
log_net!(debug "starting network manager shutdown"); log_net!(debug "starting network manager shutdown");
let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else { let Ok(guard) = self.startup_lock.shutdown().await else {
log_net!(debug "network manager is already shut down"); log_net!(debug "network manager is already shut down");
return; return;
}; };
@ -472,7 +478,9 @@ impl NetworkManager {
} }
pub fn purge_client_allowlist(&self) { pub fn purge_client_allowlist(&self) {
let timeout_ms = self.with_config(|c| c.network.client_allowlist_timeout_ms); let timeout_ms = self
.config()
.with(|c| c.network.client_allowlist_timeout_ms);
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
let cutoff_timestamp = let cutoff_timestamp =
Timestamp::now() - TimestampDuration::new((timeout_ms as u64) * 1000u64); Timestamp::now() - TimestampDuration::new((timeout_ms as u64) * 1000u64);
@ -511,11 +519,12 @@ impl NetworkManager {
extra_data: D, extra_data: D,
callback: impl ReceiptCallback, callback: impl ReceiptCallback,
) -> EyreResult<Vec<u8>> { ) -> EyreResult<Vec<u8>> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
bail!("network is not started"); bail!("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let routing_table = self.routing_table(); let routing_table = self.routing_table();
let crypto = self.crypto();
// Generate receipt and serialized form to return // Generate receipt and serialized form to return
let vcrypto = self.crypto().best(); let vcrypto = self.crypto().best();
@ -532,7 +541,7 @@ impl NetworkManager {
extra_data, extra_data,
)?; )?;
let out = receipt let out = receipt
.to_signed_data(self.crypto(), &node_id_secret) .to_signed_data(&crypto, &node_id_secret)
.wrap_err("failed to generate signed receipt")?; .wrap_err("failed to generate signed receipt")?;
// Record the receipt for later // Record the receipt for later
@ -549,12 +558,13 @@ impl NetworkManager {
expiration_us: TimestampDuration, expiration_us: TimestampDuration,
extra_data: D, extra_data: D,
) -> EyreResult<(Vec<u8>, EventualValueFuture<ReceiptEvent>)> { ) -> EyreResult<(Vec<u8>, EventualValueFuture<ReceiptEvent>)> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
bail!("network is not started"); bail!("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let routing_table = self.routing_table(); let routing_table = self.routing_table();
let crypto = self.crypto();
// Generate receipt and serialized form to return // Generate receipt and serialized form to return
let vcrypto = self.crypto().best(); let vcrypto = self.crypto().best();
@ -571,7 +581,7 @@ impl NetworkManager {
extra_data, extra_data,
)?; )?;
let out = receipt let out = receipt
.to_signed_data(self.crypto(), &node_id_secret) .to_signed_data(&crypto, &node_id_secret)
.wrap_err("failed to generate signed receipt")?; .wrap_err("failed to generate signed receipt")?;
// Record the receipt for later // Record the receipt for later
@ -589,13 +599,14 @@ impl NetworkManager {
&self, &self,
receipt_data: R, receipt_data: R,
) -> NetworkResult<()> { ) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started"); return NetworkResult::service_unavailable("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) { let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => { Err(e) => {
return NetworkResult::invalid_message(e.to_string()); return NetworkResult::invalid_message(e.to_string());
} }
@ -614,13 +625,14 @@ impl NetworkManager {
receipt_data: R, receipt_data: R,
inbound_noderef: FilteredNodeRef, inbound_noderef: FilteredNodeRef,
) -> NetworkResult<()> { ) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started"); return NetworkResult::service_unavailable("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) { let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => { Err(e) => {
return NetworkResult::invalid_message(e.to_string()); return NetworkResult::invalid_message(e.to_string());
} }
@ -638,13 +650,14 @@ impl NetworkManager {
&self, &self,
receipt_data: R, receipt_data: R,
) -> NetworkResult<()> { ) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started"); return NetworkResult::service_unavailable("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) { let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => { Err(e) => {
return NetworkResult::invalid_message(e.to_string()); return NetworkResult::invalid_message(e.to_string());
} }
@ -663,13 +676,14 @@ impl NetworkManager {
receipt_data: R, receipt_data: R,
private_route: PublicKey, private_route: PublicKey,
) -> NetworkResult<()> { ) -> NetworkResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
return NetworkResult::service_unavailable("network is not started"); return NetworkResult::service_unavailable("network is not started");
}; };
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
let crypto = self.crypto();
let receipt = match Receipt::from_signed_data(self.crypto(), receipt_data.as_ref()) { let receipt = match Receipt::from_signed_data(&crypto, receipt_data.as_ref()) {
Err(e) => { Err(e) => {
return NetworkResult::invalid_message(e.to_string()); return NetworkResult::invalid_message(e.to_string());
} }
@ -688,7 +702,7 @@ impl NetworkManager {
signal_flow: Flow, signal_flow: Flow,
signal_info: SignalInfo, signal_info: SignalInfo,
) -> EyreResult<NetworkResult<()>> { ) -> EyreResult<NetworkResult<()>> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
return Ok(NetworkResult::service_unavailable("network is not started")); return Ok(NetworkResult::service_unavailable("network is not started"));
}; };
@ -788,7 +802,8 @@ impl NetworkManager {
) -> EyreResult<Vec<u8>> { ) -> EyreResult<Vec<u8>> {
// DH to get encryption key // DH to get encryption key
let routing_table = self.routing_table(); let routing_table = self.routing_table();
let Some(vcrypto) = self.crypto().get(dest_node_id.kind) else { let crypto = self.crypto();
let Some(vcrypto) = crypto.get(dest_node_id.kind) else {
bail!("should not have a destination with incompatible crypto here"); bail!("should not have a destination with incompatible crypto here");
}; };
@ -809,12 +824,7 @@ impl NetworkManager {
dest_node_id.value, dest_node_id.value,
); );
envelope envelope
.to_encrypted_data( .to_encrypted_data(&crypto, body.as_ref(), &node_id_secret, &self.network_key)
self.crypto(),
body.as_ref(),
&node_id_secret,
&self.unlocked_inner.network_key,
)
.wrap_err("envelope failed to encode") .wrap_err("envelope failed to encode")
} }
@ -829,7 +839,7 @@ impl NetworkManager {
destination_node_ref: Option<NodeRef>, destination_node_ref: Option<NodeRef>,
body: B, body: B,
) -> EyreResult<NetworkResult<SendDataMethod>> { ) -> EyreResult<NetworkResult<SendDataMethod>> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
return Ok(NetworkResult::no_connection_other("network is not started")); return Ok(NetworkResult::no_connection_other("network is not started"));
}; };
@ -870,7 +880,7 @@ impl NetworkManager {
dial_info: DialInfo, dial_info: DialInfo,
rcpt_data: Vec<u8>, rcpt_data: Vec<u8>,
) -> EyreResult<()> { ) -> EyreResult<()> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
log_net!(debug "not sending out-of-band receipt to {} because network is stopped", dial_info); log_net!(debug "not sending out-of-band receipt to {} because network is stopped", dial_info);
return Ok(()); return Ok(());
}; };
@ -897,7 +907,7 @@ impl NetworkManager {
// and passes it to the RPC handler // and passes it to the RPC handler
#[instrument(level = "trace", target = "net", skip_all)] #[instrument(level = "trace", target = "net", skip_all)]
async fn on_recv_envelope(&self, data: &mut [u8], flow: Flow) -> EyreResult<bool> { async fn on_recv_envelope(&self, data: &mut [u8], flow: Flow) -> EyreResult<bool> {
let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else { let Ok(_guard) = self.startup_lock.enter() else {
return Ok(false); return Ok(false);
}; };
@ -947,21 +957,20 @@ impl NetworkManager {
} }
// Decode envelope header (may fail signature validation) // Decode envelope header (may fail signature validation)
let envelope = let crypto = self.crypto();
match Envelope::from_signed_data(self.crypto(), data, &self.unlocked_inner.network_key) let envelope = match Envelope::from_signed_data(&crypto, data, &self.network_key) {
{ Ok(v) => v,
Ok(v) => v, Err(e) => {
Err(e) => { log_net!(debug "envelope failed to decode: {}", e);
log_net!(debug "envelope failed to decode: {}", e); // safe to punish here because relays also check here to ensure they arent forwarding things that don't decode
// safe to punish here because relays also check here to ensure they arent forwarding things that don't decode self.address_filter()
self.address_filter() .punish_ip_addr(remote_addr, PunishmentReason::FailedToDecodeEnvelope);
.punish_ip_addr(remote_addr, PunishmentReason::FailedToDecodeEnvelope); return Ok(false);
return Ok(false); }
} };
};
// Get timestamp range // Get timestamp range
let (tsbehind, tsahead) = self.with_config(|c| { let (tsbehind, tsahead) = self.config().with(|c| {
( (
c.network c.network
.rpc .rpc
@ -1040,7 +1049,10 @@ impl NetworkManager {
// which only performs a lightweight lookup before passing the packet back out // which only performs a lightweight lookup before passing the packet back out
// If our node has the relay capability disabled, we should not be asked to relay // If our node has the relay capability disabled, we should not be asked to relay
if self.with_config(|c| c.capabilities.disable.contains(&CAP_RELAY)) { if self
.config()
.with(|c| c.capabilities.disable.contains(&CAP_RELAY))
{
log_net!(debug "node has relay capability disabled, dropping relayed envelope from {} to {}", sender_id, recipient_id); log_net!(debug "node has relay capability disabled, dropping relayed envelope from {} to {}", sender_id, recipient_id);
return Ok(false); return Ok(false);
} }
@ -1095,12 +1107,8 @@ impl NetworkManager {
let node_id_secret = routing_table.node_id_secret_key(envelope.get_crypto_kind()); let node_id_secret = routing_table.node_id_secret_key(envelope.get_crypto_kind());
// Decrypt the envelope body // Decrypt the envelope body
let body = match envelope.decrypt_body( let crypto = self.crypto();
self.crypto(), let body = match envelope.decrypt_body(&crypto, data, &node_id_secret, &self.network_key) {
data,
&node_id_secret,
&self.unlocked_inner.network_key,
) {
Ok(v) => v, Ok(v) => v,
Err(e) => { Err(e) => {
log_net!(debug "failed to decrypt envelope body: {}", e); log_net!(debug "failed to decrypt envelope body: {}", e);

View File

@ -57,11 +57,7 @@ pub(super) struct DiscoveryContext {
inner: Arc<Mutex<DiscoveryContextInner>>, inner: Arc<Mutex<DiscoveryContextInner>>,
} }
impl VeilidComponentRegistryAccessor for DiscoveryContext { impl_veilid_component_registry_accessor!(DiscoveryContext);
fn registry(&self) -> VeilidComponentRegistry {
self.registry.clone()
}
}
impl DiscoveryContext { impl DiscoveryContext {
pub fn new(registry: VeilidComponentRegistry, config: DiscoveryContextConfig) -> Self { pub fn new(registry: VeilidComponentRegistry, config: DiscoveryContextConfig) -> Self {
@ -136,11 +132,9 @@ impl DiscoveryContext {
// This is done over the normal port using RPC // This is done over the normal port using RPC
#[instrument(level = "trace", skip(self), ret)] #[instrument(level = "trace", skip(self), ret)]
async fn discover_external_addresses(&self) -> bool { async fn discover_external_addresses(&self) -> bool {
let node_count = { let node_count = self
let config = self.registry.config(); .config()
let c = config.get(); .with(|c| c.network.dht.max_find_node_count as usize);
c.network.dht.max_find_node_count as usize
};
let routing_domain = RoutingDomain::PublicInternet; let routing_domain = RoutingDomain::PublicInternet;
let protocol_type = self.unlocked_inner.config.protocol_type; let protocol_type = self.unlocked_inner.config.protocol_type;
@ -213,8 +207,8 @@ impl DiscoveryContext {
async move { async move {
if let Some(address) = this.request_public_address(node.clone()).await { if let Some(address) = this.request_public_address(node.clone()).await {
let dial_info = this let dial_info = this
.unlocked_inner .network_manager()
.net .net()
.make_dial_info(address, protocol_type); .make_dial_info(address, protocol_type);
return Some(ExternalInfo { return Some(ExternalInfo {
dial_info, dial_info,
@ -298,10 +292,9 @@ impl DiscoveryContext {
dial_info: DialInfo, dial_info: DialInfo,
redirect: bool, redirect: bool,
) -> bool { ) -> bool {
let rpc_processor = self.unlocked_inner.routing_table.rpc_processor();
// ask the node to send us a dial info validation receipt // ask the node to send us a dial info validation receipt
match rpc_processor match self
.rpc_processor()
.rpc_call_validate_dial_info(node_ref.clone(), dial_info, redirect) .rpc_call_validate_dial_info(node_ref.clone(), dial_info, redirect)
.await .await
{ {
@ -330,7 +323,12 @@ impl DiscoveryContext {
let external_1 = self.inner.lock().external_info.first().unwrap().clone(); let external_1 = self.inner.lock().external_info.first().unwrap().clone();
let igd_manager = self.unlocked_inner.net.unlocked_inner.igd_manager.clone(); let igd_manager = self
.network_manager()
.net()
.unlocked_inner
.igd_manager
.clone();
let mut tries = 0; let mut tries = 0;
loop { loop {
tries += 1; tries += 1;
@ -346,7 +344,7 @@ impl DiscoveryContext {
.await?; .await?;
// Make dial info from the port mapping // Make dial info from the port mapping
let external_mapped_dial_info = self.unlocked_inner.net.make_dial_info( let external_mapped_dial_info = self.network_manager().net().make_dial_info(
SocketAddress::from_socket_addr(mapped_external_address), SocketAddress::from_socket_addr(mapped_external_address),
protocol_type, protocol_type,
); );
@ -561,10 +559,7 @@ impl DiscoveryContext {
/////////// ///////////
let this = self.clone(); let this = self.clone();
let do_nat_detect_fut: SendPinBoxFuture<Option<DetectionResult>> = Box::pin(async move { let do_nat_detect_fut: SendPinBoxFuture<Option<DetectionResult>> = Box::pin(async move {
let mut retry_count = { let mut retry_count = this.config().with(|c| c.network.restricted_nat_retries);
let c = this.unlocked_inner.net.config.get();
c.network.restricted_nat_retries
};
// Loop for restricted NAT retries // Loop for restricted NAT retries
loop { loop {
@ -681,10 +676,7 @@ impl DiscoveryContext {
&self, &self,
unord: &mut FuturesUnordered<SendPinBoxFuture<Option<DetectionResult>>>, unord: &mut FuturesUnordered<SendPinBoxFuture<Option<DetectionResult>>>,
) { ) {
let enable_upnp = { let enable_upnp = self.config().with(|c| c.network.upnp);
let c = self.unlocked_inner.net.config.get();
c.network.upnp
};
// Do this right away because it's fast and every detection is going to need it // Do this right away because it's fast and every detection is going to need it
// Get our external addresses from two fast nodes // Get our external addresses from two fast nodes

View File

@ -117,10 +117,6 @@ struct NetworkUnlockedInner {
// Startup lock // Startup lock
startup_lock: StartupLock, startup_lock: StartupLock,
// Accessors
routing_table: RoutingTable,
network_manager: NetworkManager,
connection_manager: ConnectionManager,
// Network // Network
interfaces: NetworkInterfaces, interfaces: NetworkInterfaces,
// Background processes // Background processes
@ -135,11 +131,13 @@ struct NetworkUnlockedInner {
#[derive(Clone)] #[derive(Clone)]
pub(super) struct Network { pub(super) struct Network {
config: VeilidConfig, registry: VeilidComponentRegistry,
inner: Arc<Mutex<NetworkInner>>, inner: Arc<Mutex<NetworkInner>>,
unlocked_inner: Arc<NetworkUnlockedInner>, unlocked_inner: Arc<NetworkUnlockedInner>,
} }
impl_veilid_component_registry_accessor!(Network);
impl Network { impl Network {
fn new_inner() -> NetworkInner { fn new_inner() -> NetworkInner {
NetworkInner { NetworkInner {
@ -161,18 +159,11 @@ impl Network {
} }
} }
fn new_unlocked_inner( fn new_unlocked_inner(registry: VeilidComponentRegistry) -> NetworkUnlockedInner {
network_manager: NetworkManager, let config = registry.config();
routing_table: RoutingTable,
connection_manager: ConnectionManager,
) -> NetworkUnlockedInner {
let config = network_manager.config();
let program_name = config.get().program_name.clone(); let program_name = config.get().program_name.clone();
NetworkUnlockedInner { NetworkUnlockedInner {
startup_lock: StartupLock::new(), startup_lock: StartupLock::new(),
network_manager,
routing_table,
connection_manager,
interfaces: NetworkInterfaces::new(), interfaces: NetworkInterfaces::new(),
update_network_class_task: TickTask::new( update_network_class_task: TickTask::new(
"update_network_class_task", "update_network_class_task",
@ -188,19 +179,11 @@ impl Network {
} }
} }
pub fn new( pub fn new(registry: VeilidComponentRegistry) -> Self {
network_manager: NetworkManager,
routing_table: RoutingTable,
connection_manager: ConnectionManager,
) -> Self {
let this = Self { let this = Self {
config: network_manager.config(),
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner( unlocked_inner: Arc::new(Self::new_unlocked_inner(registry.clone())),
network_manager, registry,
routing_table,
connection_manager,
)),
}; };
this.setup_tasks(); this.setup_tasks();
@ -208,18 +191,6 @@ impl Network {
this this
} }
fn network_manager(&self) -> NetworkManager {
self.unlocked_inner.network_manager.clone()
}
fn routing_table(&self) -> RoutingTable {
self.unlocked_inner.routing_table.clone()
}
fn connection_manager(&self) -> ConnectionManager {
self.unlocked_inner.connection_manager.clone()
}
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> { fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
let cvec = certs(&mut BufReader::new(File::open(path)?)) let cvec = certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?; .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?;
@ -249,7 +220,8 @@ impl Network {
} }
fn load_server_config(&self) -> io::Result<ServerConfig> { fn load_server_config(&self) -> io::Result<ServerConfig> {
let c = self.config.get(); let config = self.config();
let c = config.get();
// //
log_net!( log_net!(
"loading certificate from {}", "loading certificate from {}",
@ -356,10 +328,9 @@ impl Network {
dial_info.clone(), dial_info.clone(),
async move { async move {
let data_len = data.len(); let data_len = data.len();
let connect_timeout_ms = { let connect_timeout_ms = self
let c = self.config.get(); .config()
c.network.connection_initial_timeout_ms .with(|c| c.network.connection_initial_timeout_ms);
};
if self if self
.network_manager() .network_manager()
@ -372,10 +343,12 @@ impl Network {
match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
let h = let h = RawUdpProtocolHandler::new_unspecified_bound_handler(
RawUdpProtocolHandler::new_unspecified_bound_handler(&peer_socket_addr) self.registry(),
.await &peer_socket_addr,
.wrap_err("create socket failure")?; )
.await
.wrap_err("create socket failure")?;
let _ = network_result_try!(h let _ = network_result_try!(h
.send_message(data, peer_socket_addr) .send_message(data, peer_socket_addr)
.await .await
@ -433,10 +406,9 @@ impl Network {
dial_info.clone(), dial_info.clone(),
async move { async move {
let data_len = data.len(); let data_len = data.len();
let connect_timeout_ms = { let connect_timeout_ms = self
let c = self.config.get(); .config()
c.network.connection_initial_timeout_ms .with(|c| c.network.connection_initial_timeout_ms);
};
if self if self
.network_manager() .network_manager()
@ -449,10 +421,12 @@ impl Network {
match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
let h = let h = RawUdpProtocolHandler::new_unspecified_bound_handler(
RawUdpProtocolHandler::new_unspecified_bound_handler(&peer_socket_addr) self.registry(),
.await &peer_socket_addr,
.wrap_err("create socket failure")?; )
.await
.wrap_err("create socket failure")?;
network_result_try!(h network_result_try!(h
.send_message(data, peer_socket_addr) .send_message(data, peer_socket_addr)
.await .await
@ -577,7 +551,11 @@ impl Network {
// Handle connection-oriented protocols // Handle connection-oriented protocols
// Try to send to the exact existing connection if one exists // Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(flow) { if let Some(conn) = self
.network_manager()
.connection_manager()
.get_connection(flow)
{
// connection exists, send over it // connection exists, send over it
match conn.send_async(data).await { match conn.send_async(data).await {
ConnectionHandleSendResult::Sent => { ConnectionHandleSendResult::Sent => {
@ -639,7 +617,8 @@ impl Network {
} else { } else {
// Handle connection-oriented protocols // Handle connection-oriented protocols
let conn = network_result_try!( let conn = network_result_try!(
self.connection_manager() self.network_manager()
.connection_manager()
.get_or_create_connection(dial_info.clone()) .get_or_create_connection(dial_info.clone())
.await? .await?
); );
@ -682,14 +661,9 @@ impl Network {
} }
// Start editing routing table // Start editing routing table
let mut editor_public_internet = self let routing_table = self.routing_table();
.unlocked_inner let mut editor_public_internet = routing_table.edit_public_internet_routing_domain();
.routing_table let mut editor_local_network = routing_table.edit_local_network_routing_domain();
.edit_public_internet_routing_domain();
let mut editor_local_network = self
.unlocked_inner
.routing_table
.edit_local_network_routing_domain();
// Setup network // Setup network
editor_local_network.set_local_networks(network_state.local_networks); editor_local_network.set_local_networks(network_state.local_networks);
@ -767,8 +741,8 @@ impl Network {
#[instrument(level = "debug", err, skip_all)] #[instrument(level = "debug", err, skip_all)]
pub(super) async fn register_all_dial_info( pub(super) async fn register_all_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
let Some(protocol_config) = ({ let Some(protocol_config) = ({
let inner = self.inner.lock(); let inner = self.inner.lock();

View File

@ -107,7 +107,8 @@ impl Network {
// Get protocol config // Get protocol config
let protocol_config = { let protocol_config = {
let c = self.config.get(); let config = self.config();
let c = config.get();
let mut inbound = ProtocolTypeSet::new(); let mut inbound = ProtocolTypeSet::new();
if c.network.protocol.udp.enabled { if c.network.protocol.udp.enabled {

View File

@ -121,8 +121,11 @@ impl Network {
} }
}; };
// Check to see if it is punished // Check to see if it is punished
let address_filter = self.network_manager().address_filter(); if self
if address_filter.is_ip_addr_punished(peer_addr.ip()) { .network_manager()
.address_filter()
.is_ip_addr_punished(peer_addr.ip())
{
return; return;
} }
@ -221,13 +224,13 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
async fn spawn_socket_listener(&self, addr: SocketAddr) -> EyreResult<bool> { async fn spawn_socket_listener(&self, addr: SocketAddr) -> EyreResult<bool> {
// Get config // Get config
let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) = { let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) =
let c = self.config.get(); self.config().with(|c| {
( (
c.network.connection_initial_timeout_ms, c.network.connection_initial_timeout_ms,
c.network.tls.connection_initial_timeout_ms, c.network.tls.connection_initial_timeout_ms,
) )
}; });
// Create a shared socket and bind it once we have determined the port is free // Create a shared socket and bind it once we have determined the port is free
let Some(listener) = bind_async_tcp_listener(addr)? else { let Some(listener) = bind_async_tcp_listener(addr)? else {
@ -246,7 +249,7 @@ impl Network {
// Spawn the socket task // Spawn the socket task
let this = self.clone(); let this = self.clone();
let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token(); let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token();
let connection_manager = self.connection_manager(); let connection_manager = self.network_manager().connection_manager();
//////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////
let jh = spawn(&format!("TCP listener {}", addr), async move { let jh = spawn(&format!("TCP listener {}", addr), async move {

View File

@ -5,10 +5,9 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn create_udp_listener_tasks(&self) -> EyreResult<()> { pub(super) async fn create_udp_listener_tasks(&self) -> EyreResult<()> {
// Spawn socket tasks // Spawn socket tasks
let mut task_count = { let mut task_count = self
let c = self.config.get(); .config()
c.network.protocol.udp.socket_pool_size .with(|c| c.network.protocol.udp.socket_pool_size);
};
if task_count == 0 { if task_count == 0 {
task_count = get_concurrency() / 2; task_count = get_concurrency() / 2;
if task_count == 0 { if task_count == 0 {
@ -37,7 +36,6 @@ impl Network {
// Spawn a local async task for each socket // Spawn a local async task for each socket
let mut protocol_handlers_unordered = FuturesUnordered::new(); let mut protocol_handlers_unordered = FuturesUnordered::new();
let network_manager = this.network_manager();
let stop_token = { let stop_token = {
let inner = this.inner.lock(); let inner = this.inner.lock();
if inner.stop_source.is_none() { if inner.stop_source.is_none() {
@ -48,7 +46,7 @@ impl Network {
}; };
for ph in protocol_handlers { for ph in protocol_handlers {
let network_manager = network_manager.clone(); let network_manager = this.network_manager();
let stop_token = stop_token.clone(); let stop_token = stop_token.clone();
let ph_future = async move { let ph_future = async move {
let mut data = vec![0u8; 65536]; let mut data = vec![0u8; 65536];
@ -120,8 +118,7 @@ impl Network {
let socket_arc = Arc::new(udp_socket); let socket_arc = Arc::new(udp_socket);
// Create protocol handler // Create protocol handler
let protocol_handler = let protocol_handler = RawUdpProtocolHandler::new(self.registry(), socket_arc);
RawUdpProtocolHandler::new(socket_arc, Some(self.network_manager().address_filter()));
// Record protocol handler // Record protocol handler
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();

View File

@ -21,7 +21,7 @@ impl ProtocolNetworkConnection {
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: &DialInfo, dial_info: &DialInfo,
timeout_ms: u32, timeout_ms: u32,
address_filter: AddressFilter, address_filter: &AddressFilter,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> { ) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
if address_filter.is_ip_addr_punished(dial_info.address().ip_addr()) { if address_filter.is_ip_addr_punished(dial_info.address().ip_addr()) {
return Ok(NetworkResult::no_connection_other("punished")); return Ok(NetworkResult::no_connection_other("punished"));

View File

@ -2,17 +2,19 @@ use super::*;
#[derive(Clone)] #[derive(Clone)]
pub struct RawUdpProtocolHandler { pub struct RawUdpProtocolHandler {
registry: VeilidComponentRegistry,
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
assembly_buffer: AssemblyBuffer, assembly_buffer: AssemblyBuffer,
address_filter: Option<AddressFilter>,
} }
impl_veilid_component_registry_accessor!(RawUdpProtocolHandler);
impl RawUdpProtocolHandler { impl RawUdpProtocolHandler {
pub fn new(socket: Arc<UdpSocket>, address_filter: Option<AddressFilter>) -> Self { pub fn new(registry: VeilidComponentRegistry, socket: Arc<UdpSocket>) -> Self {
Self { Self {
registry,
socket, socket,
assembly_buffer: AssemblyBuffer::new(), assembly_buffer: AssemblyBuffer::new(),
address_filter,
} }
} }
@ -23,10 +25,12 @@ impl RawUdpProtocolHandler {
let (size, remote_addr) = network_result_value_or_log!(self.socket.recv_from(data).await.into_network_result()? => continue); let (size, remote_addr) = network_result_value_or_log!(self.socket.recv_from(data).await.into_network_result()? => continue);
// Check to see if it is punished // Check to see if it is punished
if let Some(af) = self.address_filter.as_ref() { if self
if af.is_ip_addr_punished(remote_addr.ip()) { .network_manager()
continue; .address_filter()
} .is_ip_addr_punished(remote_addr.ip())
{
continue;
} }
// Insert into assembly buffer // Insert into assembly buffer
@ -90,10 +94,12 @@ impl RawUdpProtocolHandler {
} }
// Check to see if it is punished // Check to see if it is punished
if let Some(af) = self.address_filter.as_ref() { if self
if af.is_ip_addr_punished(remote_addr.ip()) { .network_manager()
return Ok(NetworkResult::no_connection_other("punished")); .address_filter()
} .is_ip_addr_punished(remote_addr.ip())
{
return Ok(NetworkResult::no_connection_other("punished"));
} }
// Fragment and send // Fragment and send
@ -136,12 +142,13 @@ impl RawUdpProtocolHandler {
#[instrument(level = "trace", target = "protocol", err)] #[instrument(level = "trace", target = "protocol", err)]
pub async fn new_unspecified_bound_handler( pub async fn new_unspecified_bound_handler(
registry: VeilidComponentRegistry,
socket_addr: &SocketAddr, socket_addr: &SocketAddr,
) -> io::Result<RawUdpProtocolHandler> { ) -> io::Result<RawUdpProtocolHandler> {
// get local wildcard address for bind // get local wildcard address for bind
let local_socket_addr = compatible_unspecified_socket_addr(socket_addr); let local_socket_addr = compatible_unspecified_socket_addr(socket_addr);
let socket = bind_async_udp_socket(local_socket_addr)? let socket = bind_async_udp_socket(local_socket_addr)?
.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?; .ok_or(io::Error::from(io::ErrorKind::AddrInUse))?;
Ok(RawUdpProtocolHandler::new(Arc::new(socket), None)) Ok(RawUdpProtocolHandler::new(registry, Arc::new(socket)))
} }
} }

View File

@ -140,14 +140,13 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn bind_udp_protocol_handlers(&self) -> EyreResult<StartupDisposition> { pub(super) async fn bind_udp_protocol_handlers(&self) -> EyreResult<StartupDisposition> {
log_net!("UDP: binding protocol handlers"); log_net!("UDP: binding protocol handlers");
let (listen_address, public_address, detect_address_changes) = { let (listen_address, public_address, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.udp.listen_address.clone(), c.network.protocol.udp.listen_address.clone(),
c.network.protocol.udp.public_address.clone(), c.network.protocol.udp.public_address.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// Get the binding parameters from the user-specified listen address // Get the binding parameters from the user-specified listen address
let bind_set = self let bind_set = self
@ -187,18 +186,17 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn register_udp_dial_info( pub(super) async fn register_udp_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
log_net!("UDP: registering dial info"); log_net!("UDP: registering dial info");
let (public_address, detect_address_changes) = { let (public_address, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.udp.public_address.clone(), c.network.protocol.udp.public_address.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
let local_dial_info_list = { let local_dial_info_list = {
let mut out = vec![]; let mut out = vec![];
@ -263,14 +261,13 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn start_ws_listeners(&self) -> EyreResult<StartupDisposition> { pub(super) async fn start_ws_listeners(&self) -> EyreResult<StartupDisposition> {
log_net!("WS: binding protocol handlers"); log_net!("WS: binding protocol handlers");
let (listen_address, url, detect_address_changes) = { let (listen_address, url, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.ws.listen_address.clone(), c.network.protocol.ws.listen_address.clone(),
c.network.protocol.ws.url.clone(), c.network.protocol.ws.url.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// Get the binding parameters from the user-specified listen address // Get the binding parameters from the user-specified listen address
let bind_set = self let bind_set = self
@ -313,18 +310,17 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn register_ws_dial_info( pub(super) async fn register_ws_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
log_net!("WS: registering dial info"); log_net!("WS: registering dial info");
let (url, path, detect_address_changes) = { let (url, path, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.ws.url.clone(), c.network.protocol.ws.url.clone(),
c.network.protocol.ws.path.clone(), c.network.protocol.ws.path.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
let mut registered_addresses: HashSet<IpAddr> = HashSet::new(); let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
@ -409,14 +405,13 @@ impl Network {
pub(super) async fn start_wss_listeners(&self) -> EyreResult<StartupDisposition> { pub(super) async fn start_wss_listeners(&self) -> EyreResult<StartupDisposition> {
log_net!("WSS: binding protocol handlers"); log_net!("WSS: binding protocol handlers");
let (listen_address, url, detect_address_changes) = { let (listen_address, url, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.wss.listen_address.clone(), c.network.protocol.wss.listen_address.clone(),
c.network.protocol.wss.url.clone(), c.network.protocol.wss.url.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// Get the binding parameters from the user-specified listen address // Get the binding parameters from the user-specified listen address
let bind_set = self let bind_set = self
@ -460,18 +455,17 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn register_wss_dial_info( pub(super) async fn register_wss_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
log_net!("WSS: registering dialinfo"); log_net!("WSS: registering dialinfo");
let (url, _detect_address_changes) = { let (url, _detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.wss.url.clone(), c.network.protocol.wss.url.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// NOTE: No interface dial info for WSS, as there is no way to connect to a local dialinfo via TLS // NOTE: No interface dial info for WSS, as there is no way to connect to a local dialinfo via TLS
// If the hostname is specified, it is the public dialinfo via the URL. If no hostname // If the hostname is specified, it is the public dialinfo via the URL. If no hostname
@ -520,14 +514,13 @@ impl Network {
pub(super) async fn start_tcp_listeners(&self) -> EyreResult<StartupDisposition> { pub(super) async fn start_tcp_listeners(&self) -> EyreResult<StartupDisposition> {
log_net!("TCP: binding protocol handlers"); log_net!("TCP: binding protocol handlers");
let (listen_address, public_address, detect_address_changes) = { let (listen_address, public_address, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.tcp.listen_address.clone(), c.network.protocol.tcp.listen_address.clone(),
c.network.protocol.tcp.public_address.clone(), c.network.protocol.tcp.public_address.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
// Get the binding parameters from the user-specified listen address // Get the binding parameters from the user-specified listen address
let bind_set = self let bind_set = self
@ -570,18 +563,17 @@ impl Network {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub(super) async fn register_tcp_dial_info( pub(super) async fn register_tcp_dial_info(
&self, &self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet, editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork, editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> { ) -> EyreResult<()> {
log_net!("TCP: registering dialinfo"); log_net!("TCP: registering dialinfo");
let (public_address, detect_address_changes) = { let (public_address, detect_address_changes) = self.config().with(|c| {
let c = self.config.get();
( (
c.network.protocol.tcp.public_address.clone(), c.network.protocol.tcp.public_address.clone(),
c.network.detect_address_changes, c.network.detect_address_changes,
) )
}; });
let mut registered_addresses: HashSet<IpAddr> = HashSet::new(); let mut registered_addresses: HashSet<IpAddr> = HashSet::new();

View File

@ -37,17 +37,13 @@ impl Network {
} }
// network state has changed // network state has changed
let mut editor_local_network = self let routing_table = self.routing_table();
.unlocked_inner
.routing_table let mut editor_local_network = routing_table.edit_local_network_routing_domain();
.edit_local_network_routing_domain();
editor_local_network.set_local_networks(new_network_state.local_networks); editor_local_network.set_local_networks(new_network_state.local_networks);
editor_local_network.clear_dial_info_details(None, None); editor_local_network.clear_dial_info_details(None, None);
let mut editor_public_internet = self let mut editor_public_internet = routing_table.edit_public_internet_routing_domain();
.unlocked_inner
.routing_table
.edit_public_internet_routing_domain();
// Update protocols // Update protocols
self.register_all_dial_info(&mut editor_public_internet, &mut editor_local_network) self.register_all_dial_info(&mut editor_public_internet, &mut editor_local_network)

View File

@ -125,8 +125,9 @@ impl Network {
}; };
// Save off existing public dial info for change detection later // Save off existing public dial info for change detection later
let existing_public_dial_info: HashSet<DialInfoDetail> = self let routing_table = self.routing_table();
.routing_table()
let existing_public_dial_info: HashSet<DialInfoDetail> = routing_table
.all_filtered_dial_info_details( .all_filtered_dial_info_details(
RoutingDomain::PublicInternet.into(), RoutingDomain::PublicInternet.into(),
&DialInfoFilter::all(), &DialInfoFilter::all(),
@ -135,7 +136,7 @@ impl Network {
.collect(); .collect();
// Set most permissive network config and start from scratch // Set most permissive network config and start from scratch
let mut editor = self.routing_table().edit_public_internet_routing_domain(); let mut editor = routing_table.edit_public_internet_routing_domain();
editor.setup_network( editor.setup_network(
protocol_config.outbound, protocol_config.outbound,
protocol_config.inbound, protocol_config.inbound,
@ -247,22 +248,18 @@ impl Network {
match protocol_type { match protocol_type {
ProtocolType::UDP => DialInfo::udp(addr), ProtocolType::UDP => DialInfo::udp(addr),
ProtocolType::TCP => DialInfo::tcp(addr), ProtocolType::TCP => DialInfo::tcp(addr),
ProtocolType::WS => { ProtocolType::WS => DialInfo::try_ws(
let c = self.config.get(); addr,
DialInfo::try_ws( self.config()
addr, .with(|c| format!("ws://{}/{}", addr, c.network.protocol.ws.path)),
format!("ws://{}/{}", addr, c.network.protocol.ws.path), )
) .unwrap(),
.unwrap() ProtocolType::WSS => DialInfo::try_wss(
} addr,
ProtocolType::WSS => { self.config()
let c = self.config.get(); .with(|c| format!("wss://{}/{}", addr, c.network.protocol.wss.path)),
DialInfo::try_wss( )
addr, .unwrap(),
format!("wss://{}/{}", addr, c.network.protocol.wss.path),
)
.unwrap()
}
} }
} }
} }

View File

@ -5,31 +5,24 @@ use super::*;
impl NetworkManager { impl NetworkManager {
pub fn setup_tasks(&self) { pub fn setup_tasks(&self) {
// Set rolling transfers tick task // Set rolling transfers tick task
{ impl_setup_task!(
let this = self.clone(); self,
self.unlocked_inner Self,
.rolling_transfers_task rolling_transfers_task,
.set_routine(move |s, l, t| { rolling_transfers_task_routine
Box::pin(this.clone().rolling_transfers_task_routine( );
s,
Timestamp::new(l),
Timestamp::new(t),
))
});
}
// Set address filter task // Set address filter task
{ {
let this = self.clone(); let this = self.clone();
self.unlocked_inner self.address_filter_task.set_routine(move |s, l, t| {
.address_filter_task xxx continue here
.set_routine(move |s, l, t| { Box::pin(this.address_filter().address_filter_task_routine(
Box::pin(this.address_filter().address_filter_task_routine( s,
s, Timestamp::new(l),
Timestamp::new(l), Timestamp::new(t),
Timestamp::new(t), ))
)) });
});
} }
} }
@ -40,10 +33,10 @@ impl NetworkManager {
let receipt_manager = self.receipt_manager(); let receipt_manager = self.receipt_manager();
// Run the rolling transfers task // Run the rolling transfers task
self.unlocked_inner.rolling_transfers_task.tick().await?; self.rolling_transfers_task.tick().await?;
// Run the address filter task // Run the address filter task
self.unlocked_inner.address_filter_task.tick().await?; self.address_filter_task.tick().await?;
// Run the routing table tick // Run the routing table tick
routing_table.tick().await?; routing_table.tick().await?;
@ -62,10 +55,15 @@ impl NetworkManager {
pub async fn cancel_tasks(&self) { pub async fn cancel_tasks(&self) {
log_net!(debug "stopping rolling transfers task"); log_net!(debug "stopping rolling transfers task");
if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await { if let Err(e) = self.rolling_transfers_task.stop().await {
warn!("rolling_transfers_task not stopped: {}", e); warn!("rolling_transfers_task not stopped: {}", e);
} }
log_net!(debug "stopping address filter task");
if let Err(e) = self.address_filter_task.stop().await {
warn!("address_filter_task not stopped: {}", e);
}
log_net!(debug "stopping routing table tasks"); log_net!(debug "stopping routing table tasks");
let routing_table = self.routing_table(); let routing_table = self.routing_table();
routing_table.cancel_tasks().await; routing_table.cancel_tasks().await;

View File

@ -93,7 +93,7 @@ pub struct RecentPeersEntry {
pub(crate) struct RoutingTable { pub(crate) struct RoutingTable {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<RwLock<RoutingTableInner>>, inner: RwLock<RoutingTableInner>,
/// The current node's public DHT keys /// The current node's public DHT keys
node_id: TypedKeyGroup, node_id: TypedKeyGroup,
@ -144,7 +144,7 @@ impl RoutingTable {
pub fn new(registry: VeilidComponentRegistry) -> Self { pub fn new(registry: VeilidComponentRegistry) -> Self {
let config = registry.config(); let config = registry.config();
let c = config.get(); let c = config.get();
let inner = Arc::new(RwLock::new(RoutingTableInner::new(registry.clone()))); let inner = RwLock::new(RoutingTableInner::new(registry.clone()));
let this = Self { let this = Self {
registry, registry,
inner, inner,

View File

@ -252,7 +252,7 @@ pub(crate) trait NodeRefCommonTrait: NodeRefAccessorsTrait + NodeRefOperateTrait
else { else {
return false; return false;
}; };
let our_node_ids = rti.unlocked_inner.node_ids(); let our_node_ids = rti.routing_table().node_ids();
our_node_ids.contains_any(&relay_ids) our_node_ids.contains_any(&relay_ids)
}) })
} }

View File

@ -38,7 +38,7 @@ struct RouteSpecStoreInner {
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct RouteSpecStore { pub(crate) struct RouteSpecStore {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<Mutex<RouteSpecStoreInner>>, inner: Mutex<RouteSpecStoreInner>,
/// Maximum number of hops in a route /// Maximum number of hops in a route
max_route_hop_count: usize, max_route_hop_count: usize,
@ -55,10 +55,10 @@ impl RouteSpecStore {
Self { Self {
registry, registry,
inner: Arc::new(Mutex::new(RouteSpecStoreInner { inner: Mutex::new(RouteSpecStoreInner {
content: RouteSpecStoreContent::new(), content: RouteSpecStoreContent::new(),
cache: Default::default(), cache: Default::default(),
})), }),
max_route_hop_count: c.network.rpc.max_route_hop_count.into(), max_route_hop_count: c.network.rpc.max_route_hop_count.into(),
default_route_hop_count: c.network.rpc.default_route_hop_count.into(), default_route_hop_count: c.network.rpc.default_route_hop_count.into(),
} }
@ -901,7 +901,7 @@ impl RouteSpecStore {
}; };
// Remove from hop cache // Remove from hop cache
let rti = &*self.unlocked_inner.routing_table.inner.read(); let rti = &*self.routing_table().inner.read();
if !inner.cache.remove_from_cache(rti, id, &rssd) { if !inner.cache.remove_from_cache(rti, id, &rssd) {
panic!("hop cache should have contained cache key"); panic!("hop cache should have contained cache key");
} }

View File

@ -98,7 +98,7 @@ struct RPCProcessorInner {
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct RPCProcessor { pub(crate) struct RPCProcessor {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<Mutex<RPCProcessorInner>>, inner: Mutex<RPCProcessorInner>,
timeout_us: TimestampDuration, timeout_us: TimestampDuration,
queue_size: u32, queue_size: u32,
concurrency: u32, concurrency: u32,
@ -144,7 +144,7 @@ impl RPCProcessor {
Self { Self {
registry, registry,
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Mutex::new(Self::new_inner()),
timeout_us, timeout_us,
queue_size, queue_size,
concurrency, concurrency,
@ -155,8 +155,6 @@ impl RPCProcessor {
} }
} }
xxx continue here
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
#[instrument(level = "debug", skip_all, err)] #[instrument(level = "debug", skip_all, err)]

View File

@ -82,7 +82,7 @@ impl fmt::Debug for StorageManagerInner {
pub(crate) struct StorageManager { pub(crate) struct StorageManager {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<AsyncMutex<StorageManagerInner>>, inner: AsyncMutex<StorageManagerInner>,
// Background processes // Background processes
flush_record_stores_task: TickTask<EyreReport>, flush_record_stores_task: TickTask<EyreReport>,
@ -137,7 +137,7 @@ impl StorageManager {
let inner = Self::new_inner(); let inner = Self::new_inner();
let this = StorageManager { let this = StorageManager {
registry, registry,
inner: Arc::new(AsyncMutex::new(inner)), inner: AsyncMutex::new(inner),
flush_record_stores_task: TickTask::new( flush_record_stores_task: TickTask::new(
"flush_record_stores_task", "flush_record_stores_task",

View File

@ -87,7 +87,7 @@ impl fmt::Debug for TableStoreInner {
/// Database for storing key value pairs persistently and securely across runs. /// Database for storing key value pairs persistently and securely across runs.
pub struct TableStore { pub struct TableStore {
registry: VeilidComponentRegistry, registry: VeilidComponentRegistry,
inner: Arc<Mutex<TableStoreInner>>, // Sync mutex here because TableDB drops can happen at any time inner: Mutex<TableStoreInner>, // Sync mutex here because TableDB drops can happen at any time
table_store_driver: TableStoreDriver, table_store_driver: TableStoreDriver,
async_lock: Arc<AsyncMutex<()>>, // Async mutex for operations async_lock: Arc<AsyncMutex<()>>, // Async mutex for operations
} }
@ -120,7 +120,7 @@ impl TableStore {
Self { Self {
registry, registry,
inner: Arc::new(Mutex::new(inner)), inner: Mutex::new(inner),
table_store_driver, table_store_driver,
async_lock: Arc::new(AsyncMutex::new(())), async_lock: Arc::new(AsyncMutex::new(())),
} }