From c1644f1015736d927c627847b336182a61b201d2 Mon Sep 17 00:00:00 2001 From: John Smith Date: Thu, 3 Nov 2022 22:02:40 -0400 Subject: [PATCH] bug fixes --- veilid-core/src/routing_table/mod.rs | 2 + veilid-core/src/routing_table/node_ref.rs | 491 ++++++++++-------- .../src/routing_table/node_ref_filter.rs | 61 +++ .../src/routing_table/routing_table_inner.rs | 10 +- veilid-core/src/veilid_api/mod.rs | 2 +- 5 files changed, 344 insertions(+), 222 deletions(-) create mode 100644 veilid-core/src/routing_table/node_ref_filter.rs diff --git a/veilid-core/src/routing_table/mod.rs b/veilid-core/src/routing_table/mod.rs index ec1508d1..b1d5c4fe 100644 --- a/veilid-core/src/routing_table/mod.rs +++ b/veilid-core/src/routing_table/mod.rs @@ -2,6 +2,7 @@ mod bucket; mod bucket_entry; mod debug; mod node_ref; +mod node_ref_filter; mod route_spec_store; mod routing_domain_editor; mod routing_domains; @@ -19,6 +20,7 @@ pub use bucket_entry::*; pub use debug::*; use hashlink::LruCache; pub use node_ref::*; +pub use node_ref_filter::*; pub use route_spec_store::*; pub use routing_domain_editor::*; pub use routing_domains::*; diff --git a/veilid-core/src/routing_table/node_ref.rs b/veilid-core/src/routing_table/node_ref.rs index c555c172..4b660037 100644 --- a/veilid-core/src/routing_table/node_ref.rs +++ b/veilid-core/src/routing_table/node_ref.rs @@ -6,67 +6,9 @@ use alloc::fmt; // We should ping them with some frequency and 30 seconds is typical timeout const CONNECTIONLESS_TIMEOUT_SECS: u32 = 29; -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct NodeRefFilter { - pub routing_domain_set: RoutingDomainSet, - pub dial_info_filter: DialInfoFilter, -} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -impl Default for NodeRefFilter { - fn default() -> Self { - Self::new() - } -} - -impl NodeRefFilter { - pub fn new() -> Self { - Self { - routing_domain_set: RoutingDomainSet::all(), - dial_info_filter: DialInfoFilter::all(), - } - } - - pub fn with_routing_domain(mut self, routing_domain: RoutingDomain) -> Self { - self.routing_domain_set = routing_domain.into(); - self - } - pub fn with_routing_domain_set(mut self, routing_domain_set: RoutingDomainSet) -> Self { - self.routing_domain_set = routing_domain_set; - self - } - pub fn with_dial_info_filter(mut self, dial_info_filter: DialInfoFilter) -> Self { - self.dial_info_filter = dial_info_filter; - self - } - pub fn with_protocol_type(mut self, protocol_type: ProtocolType) -> Self { - self.dial_info_filter = self.dial_info_filter.with_protocol_type(protocol_type); - self - } - pub fn with_protocol_type_set(mut self, protocol_set: ProtocolTypeSet) -> Self { - self.dial_info_filter = self.dial_info_filter.with_protocol_type_set(protocol_set); - self - } - pub fn with_address_type(mut self, address_type: AddressType) -> Self { - self.dial_info_filter = self.dial_info_filter.with_address_type(address_type); - self - } - pub fn with_address_type_set(mut self, address_set: AddressTypeSet) -> Self { - self.dial_info_filter = self.dial_info_filter.with_address_type_set(address_set); - self - } - pub fn filtered(mut self, other_filter: &NodeRefFilter) -> Self { - self.routing_domain_set &= other_filter.routing_domain_set; - self.dial_info_filter = self - .dial_info_filter - .filtered(&other_filter.dial_info_filter); - self - } - pub fn is_dead(&self) -> bool { - self.dial_info_filter.is_dead() || self.routing_domain_set.is_empty() - } -} - -pub struct NodeRef { +pub struct NodeRefBaseCommon { routing_table: RoutingTable, node_id: DHTKey, entry: Arc, @@ -76,104 +18,79 @@ pub struct NodeRef { track_id: usize, } -impl NodeRef { - pub fn new( - routing_table: RoutingTable, - node_id: DHTKey, - entry: Arc, - filter: Option, - ) -> Self { - entry.ref_count.fetch_add(1u32, Ordering::Relaxed); +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - Self { - routing_table, - node_id, - entry, - filter, - sequencing: Sequencing::NoPreference, - #[cfg(feature = "tracking")] - track_id: entry.track(), - } - } +pub trait NodeRefBase: Sized { + // Common field access + fn common(&self) -> &NodeRefBaseCommon; + fn common_mut(&mut self) -> &mut NodeRefBaseCommon; - // Operate on entry accessors - pub(super) fn operate(&self, f: F) -> T + // Implementation-specific operators + fn operate(&self, f: F) -> T where - F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T, - { - let inner = &*self.routing_table.inner.read(); - self.entry.with(inner, f) - } - - pub(super) fn operate_mut(&self, f: F) -> T + F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T; + fn operate_mut(&self, f: F) -> T where - F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T, - { - let inner = &mut *self.routing_table.inner.write(); - self.entry.with_mut(inner, f) - } + F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T; // Filtering - - pub fn filter_ref(&self) -> Option<&NodeRefFilter> { - self.filter.as_ref() + fn filter_ref(&self) -> Option<&NodeRefFilter> { + self.common().filter.as_ref() } - pub fn take_filter(&mut self) -> Option { - self.filter.take() + fn take_filter(&mut self) -> Option { + self.common_mut().filter.take() } - pub fn set_filter(&mut self, filter: Option) { - self.filter = filter + fn set_filter(&mut self, filter: Option) { + self.common_mut().filter = filter } - pub fn set_sequencing(&mut self, sequencing: Sequencing) { - self.sequencing = sequencing; + fn set_sequencing(&mut self, sequencing: Sequencing) { + self.common_mut().sequencing = sequencing; } - pub fn sequencing(&self) -> Sequencing { - self.sequencing + fn sequencing(&self) -> Sequencing { + self.common().sequencing } - pub fn merge_filter(&mut self, filter: NodeRefFilter) { - if let Some(self_filter) = self.filter.take() { - self.filter = Some(self_filter.filtered(&filter)); + fn merge_filter(&mut self, filter: NodeRefFilter) { + let common_mut = self.common_mut(); + if let Some(self_filter) = common_mut.filter.take() { + common_mut.filter = Some(self_filter.filtered(&filter)); } else { - self.filter = Some(filter); + common_mut.filter = Some(filter); } } - pub fn filtered_clone(&self, filter: NodeRefFilter) -> Self { - let mut out = self.clone(); - out.merge_filter(filter); - out - } - - pub fn is_filter_dead(&self) -> bool { - if let Some(filter) = &self.filter { + fn is_filter_dead(&self) -> bool { + if let Some(filter) = &self.common().filter { filter.is_dead() } else { false } } - pub fn routing_domain_set(&self) -> RoutingDomainSet { - self.filter + fn routing_domain_set(&self) -> RoutingDomainSet { + self.common() + .filter .as_ref() .map(|f| f.routing_domain_set) .unwrap_or(RoutingDomainSet::all()) } - pub fn dial_info_filter(&self) -> DialInfoFilter { - self.filter + fn dial_info_filter(&self) -> DialInfoFilter { + self.common() + .filter .as_ref() .map(|f| f.dial_info_filter.clone()) .unwrap_or(DialInfoFilter::all()) } - pub fn best_routing_domain(&self) -> Option { + fn best_routing_domain(&self) -> Option { self.operate(|_rti, e| { e.best_routing_domain( - self.filter + self.common() + .filter .as_ref() .map(|f| f.routing_domain_set) .unwrap_or(RoutingDomainSet::all()), @@ -182,66 +99,66 @@ impl NodeRef { } // Accessors - pub fn routing_table(&self) -> RoutingTable { - self.routing_table.clone() + fn routing_table(&self) -> RoutingTable { + self.common().routing_table.clone() } - pub fn node_id(&self) -> DHTKey { - self.node_id + fn node_id(&self) -> DHTKey { + self.common().node_id } - pub fn has_updated_since_last_network_change(&self) -> bool { + fn has_updated_since_last_network_change(&self) -> bool { self.operate(|_rti, e| e.has_updated_since_last_network_change()) } - pub fn set_updated_since_last_network_change(&self) { + fn set_updated_since_last_network_change(&self) { self.operate_mut(|_rti, e| e.set_updated_since_last_network_change(true)); } - pub fn update_node_status(&self, node_status: NodeStatus) { + fn update_node_status(&self, node_status: NodeStatus) { self.operate_mut(|_rti, e| { e.update_node_status(node_status); }); } - pub fn min_max_version(&self) -> Option<(u8, u8)> { + fn min_max_version(&self) -> Option<(u8, u8)> { self.operate(|_rti, e| e.min_max_version()) } - pub fn set_min_max_version(&self, min_max_version: (u8, u8)) { + fn set_min_max_version(&self, min_max_version: (u8, u8)) { self.operate_mut(|_rti, e| e.set_min_max_version(min_max_version)) } - pub fn state(&self, cur_ts: u64) -> BucketEntryState { + fn state(&self, cur_ts: u64) -> BucketEntryState { self.operate(|_rti, e| e.state(cur_ts)) } - pub fn peer_stats(&self) -> PeerStats { + fn peer_stats(&self) -> PeerStats { self.operate(|_rti, e| e.peer_stats().clone()) } // Per-RoutingDomain accessors - pub fn make_peer_info(&self, routing_domain: RoutingDomain) -> Option { + fn make_peer_info(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rti, e| e.make_peer_info(self.node_id(), routing_domain)) } - pub fn node_info(&self, routing_domain: RoutingDomain) -> Option { + fn node_info(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rti, e| e.node_info(routing_domain).cloned()) } - pub fn signed_node_info_has_valid_signature(&self, routing_domain: RoutingDomain) -> bool { + fn signed_node_info_has_valid_signature(&self, routing_domain: RoutingDomain) -> bool { self.operate(|_rti, e| { e.signed_node_info(routing_domain) .map(|sni| sni.has_valid_signature()) .unwrap_or(false) }) } - pub fn has_seen_our_node_info(&self, routing_domain: RoutingDomain) -> bool { + fn has_seen_our_node_info(&self, routing_domain: RoutingDomain) -> bool { self.operate(|_rti, e| e.has_seen_our_node_info(routing_domain)) } - pub fn set_seen_our_node_info(&self, routing_domain: RoutingDomain) { + fn set_seen_our_node_info(&self, routing_domain: RoutingDomain) { self.operate_mut(|_rti, e| e.set_seen_our_node_info(routing_domain, true)); } - pub fn network_class(&self, routing_domain: RoutingDomain) -> Option { + fn network_class(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rt, e| e.node_info(routing_domain).map(|n| n.network_class)) } - pub fn outbound_protocols(&self, routing_domain: RoutingDomain) -> Option { + fn outbound_protocols(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rt, e| e.node_info(routing_domain).map(|n| n.outbound_protocols)) } - pub fn address_types(&self, routing_domain: RoutingDomain) -> Option { + fn address_types(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rt, e| e.node_info(routing_domain).map(|n| n.address_types)) } - pub fn node_info_outbound_filter(&self, routing_domain: RoutingDomain) -> DialInfoFilter { + fn node_info_outbound_filter(&self, routing_domain: RoutingDomain) -> DialInfoFilter { let mut dif = DialInfoFilter::all(); if let Some(outbound_protocols) = self.outbound_protocols(routing_domain) { dif = dif.with_protocol_type_set(outbound_protocols); @@ -251,34 +168,37 @@ impl NodeRef { } dif } - pub fn relay(&self, routing_domain: RoutingDomain) -> Option { - let target_rpi = self.operate(|_rti, e| { - e.node_info(routing_domain) + fn relay(&self, routing_domain: RoutingDomain) -> Option { + self.operate_mut(|rti, e| { + let opt_target_rpi = e + .node_info(routing_domain) .map(|n| n.relay_peer_info.as_ref().map(|pi| pi.as_ref().clone())) - })?; - target_rpi.and_then(|t| { - // If relay is ourselves, then return None, because we can't relay through ourselves - // and to contact this node we should have had an existing inbound connection - if t.node_id.key == self.routing_table.node_id() { - return None; - } + .flatten(); + opt_target_rpi.and_then(|t| { + // If relay is ourselves, then return None, because we can't relay through ourselves + // and to contact this node we should have had an existing inbound connection + if t.node_id.key == rti.unlocked_inner.node_id { + return None; + } - // Register relay node and return noderef - self.routing_table.register_node_with_signed_node_info( - routing_domain, - t.node_id.key, - t.signed_node_info, - false, - ) + // Register relay node and return noderef + rti.register_node_with_signed_node_info( + self.routing_table(), + routing_domain, + t.node_id.key, + t.signed_node_info, + false, + ) + }) }) } // Filtered accessors - pub fn first_filtered_dial_info_detail(&self) -> Option { + fn first_filtered_dial_info_detail(&self) -> Option { let routing_domain_set = self.routing_domain_set(); let dial_info_filter = self.dial_info_filter(); - let (sort, dial_info_filter) = match self.sequencing { + let (sort, dial_info_filter) = match self.common().sequencing { Sequencing::NoPreference => (None, dial_info_filter), Sequencing::PreferOrdered => ( Some(DialInfoDetail::ordered_sequencing_sort), @@ -305,11 +225,11 @@ impl NodeRef { }) } - pub fn all_filtered_dial_info_details(&self) -> Vec { + fn all_filtered_dial_info_details(&self) -> Vec { let routing_domain_set = self.routing_domain_set(); let dial_info_filter = self.dial_info_filter(); - let (sort, dial_info_filter) = match self.sequencing { + let (sort, dial_info_filter) = match self.common().sequencing { Sequencing::NoPreference => (None, dial_info_filter), Sequencing::PreferOrdered => ( Some(DialInfoDetail::ordered_sequencing_sort), @@ -338,43 +258,47 @@ impl NodeRef { out } - pub fn last_connection(&self) -> Option { + fn last_connection(&self) -> Option { // Get the last connections and the last time we saw anything with this connection // Filtered first and then sorted by most recent - let last_connections = self.operate(|rti, e| e.last_connections(rti, self.filter.clone())); + self.operate(|rti, e| { + let last_connections = e.last_connections(rti, self.common().filter.clone()); - // Do some checks to ensure these are possibly still 'live' - for (last_connection, last_seen) in last_connections { - // Should we check the connection table? - if last_connection.protocol_type().is_connection_oriented() { - // Look the connection up in the connection manager and see if it's still there - let connection_manager = self.routing_table.network_manager().connection_manager(); - if connection_manager.get_connection(last_connection).is_some() { - return Some(last_connection); - } - } else { - // If this is not connection oriented, then we check our last seen time - // to see if this mapping has expired (beyond our timeout) - let cur_ts = intf::get_timestamp(); - if (last_seen + (CONNECTIONLESS_TIMEOUT_SECS as u64 * 1_000_000u64)) >= cur_ts { - return Some(last_connection); + // Do some checks to ensure these are possibly still 'live' + for (last_connection, last_seen) in last_connections { + // Should we check the connection table? + if last_connection.protocol_type().is_connection_oriented() { + // Look the connection up in the connection manager and see if it's still there + let connection_manager = + rti.unlocked_inner.network_manager.connection_manager(); + if connection_manager.get_connection(last_connection).is_some() { + return Some(last_connection); + } + } else { + // If this is not connection oriented, then we check our last seen time + // to see if this mapping has expired (beyond our timeout) + let cur_ts = intf::get_timestamp(); + if (last_seen + (CONNECTIONLESS_TIMEOUT_SECS as u64 * 1_000_000u64)) >= cur_ts { + return Some(last_connection); + } } } - } - None + None + }) } - pub fn clear_last_connections(&self) { + fn clear_last_connections(&self) { self.operate_mut(|_rti, e| e.clear_last_connections()) } - pub fn set_last_connection(&self, connection_descriptor: ConnectionDescriptor, ts: u64) { - self.operate_mut(|_rti, e| e.set_last_connection(connection_descriptor, ts)); - self.routing_table - .touch_recent_peer(self.node_id(), connection_descriptor); + fn set_last_connection(&self, connection_descriptor: ConnectionDescriptor, ts: u64) { + self.operate_mut(|rti, e| { + e.set_last_connection(connection_descriptor, ts); + rti.touch_recent_peer(self.common().node_id, connection_descriptor); + }) } - pub fn has_any_dial_info(&self) -> bool { + fn has_any_dial_info(&self) -> bool { self.operate(|_rti, e| { for rtd in RoutingDomain::all() { if let Some(ni) = e.node_info(rtd) { @@ -387,25 +311,25 @@ impl NodeRef { }) } - pub fn stats_question_sent(&self, ts: u64, bytes: u64, expects_answer: bool) { + fn stats_question_sent(&self, ts: u64, bytes: u64, expects_answer: bool) { self.operate_mut(|rti, e| { rti.transfer_stats_accounting().add_up(bytes); e.question_sent(ts, bytes, expects_answer); }) } - pub fn stats_question_rcvd(&self, ts: u64, bytes: u64) { + fn stats_question_rcvd(&self, ts: u64, bytes: u64) { self.operate_mut(|rti, e| { rti.transfer_stats_accounting().add_down(bytes); e.question_rcvd(ts, bytes); }) } - pub fn stats_answer_sent(&self, bytes: u64) { + fn stats_answer_sent(&self, bytes: u64) { self.operate_mut(|rti, e| { rti.transfer_stats_accounting().add_up(bytes); e.answer_sent(bytes); }) } - pub fn stats_answer_rcvd(&self, send_ts: u64, recv_ts: u64, bytes: u64) { + fn stats_answer_rcvd(&self, send_ts: u64, recv_ts: u64, bytes: u64) { self.operate_mut(|rti, e| { rti.transfer_stats_accounting().add_down(bytes); rti.latency_stats_accounting() @@ -413,54 +337,118 @@ impl NodeRef { e.answer_rcvd(send_ts, recv_ts, bytes); }) } - pub fn stats_question_lost(&self) { + fn stats_question_lost(&self) { self.operate_mut(|_rti, e| { e.question_lost(); }) } - pub fn stats_failed_to_send(&self, ts: u64, expects_answer: bool) { + fn stats_failed_to_send(&self, ts: u64, expects_answer: bool) { self.operate_mut(|_rti, e| { e.failed_to_send(ts, expects_answer); }) } } -impl Clone for NodeRef { - fn clone(&self) -> Self { - self.entry.ref_count.fetch_add(1u32, Ordering::Relaxed); +//////////////////////////////////////////////////////////////////////////////////// + +/// Reference to a routing table entry +/// Keeps entry in the routing table until all references are gone +pub struct NodeRef { + common: NodeRefBaseCommon, +} + +impl NodeRef { + pub fn new( + routing_table: RoutingTable, + node_id: DHTKey, + entry: Arc, + filter: Option, + ) -> Self { + entry.ref_count.fetch_add(1u32, Ordering::Relaxed); Self { - routing_table: self.routing_table.clone(), - node_id: self.node_id, - entry: self.entry.clone(), - filter: self.filter.clone(), - sequencing: self.sequencing, - #[cfg(feature = "tracking")] - track_id: e.track(), + common: NodeRefBaseCommon { + routing_table, + node_id, + entry, + filter, + sequencing: Sequencing::NoPreference, + #[cfg(feature = "tracking")] + track_id: entry.track(), + }, + } + } + + pub fn filtered_clone(&self, filter: NodeRefFilter) -> Self { + let mut out = self.clone(); + out.merge_filter(filter); + out + } + + pub fn locked<'a>(&self, rti: &'a mut RoutingTableInner) -> NodeRefLocked<'a> { + NodeRefLocked::new(rti, self.clone()) + } +} + +impl NodeRefBase for NodeRef { + fn common(&self) -> &NodeRefBaseCommon { + &self.common + } + + fn common_mut(&mut self) -> &mut NodeRefBaseCommon { + &mut self.common + } + + fn operate(&self, f: F) -> T + where + F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T, + { + let inner = &*self.common.routing_table.inner.read(); + self.common.entry.with(inner, f) + } + + fn operate_mut(&self, f: F) -> T + where + F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T, + { + let inner = &mut *self.common.routing_table.inner.write(); + self.common.entry.with_mut(inner, f) + } +} + +impl Clone for NodeRef { + fn clone(&self) -> Self { + self.common + .entry + .ref_count + .fetch_add(1u32, Ordering::Relaxed); + + Self { + common: NodeRefBaseCommon { + routing_table: self.common.routing_table.clone(), + node_id: self.common.node_id, + entry: self.common.entry.clone(), + filter: self.common.filter.clone(), + sequencing: self.common.sequencing, + #[cfg(feature = "tracking")] + track_id: self.common.entry.write().track(), + }, } } } -// impl PartialEq for NodeRef { -// fn eq(&self, other: &Self) -> bool { -// self.node_id == other.node_id -// } -// } - -// impl Eq for NodeRef {} - impl fmt::Display for NodeRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.node_id.encode()) + write!(f, "{}", self.common.node_id.encode()) } } impl fmt::Debug for NodeRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("NodeRef") - .field("node_id", &self.node_id) - .field("filter", &self.filter) - .field("sequencing", &self.sequencing) + .field("node_id", &self.common.node_id) + .field("filter", &self.common.filter) + .field("sequencing", &self.common.sequencing) .finish() } } @@ -468,12 +456,79 @@ impl fmt::Debug for NodeRef { impl Drop for NodeRef { fn drop(&mut self) { #[cfg(feature = "tracking")] - self.operate(|e| e.untrack(self.track_id)); + self.common.entry.write().untrack(self.track_id); // drop the noderef and queue a bucket kick if it was the last one - let new_ref_count = self.entry.ref_count.fetch_sub(1u32, Ordering::Relaxed) - 1; + let new_ref_count = self + .common + .entry + .ref_count + .fetch_sub(1u32, Ordering::Relaxed) + - 1; if new_ref_count == 0 { - self.routing_table.queue_bucket_kick(self.node_id); + self.common + .routing_table + .queue_bucket_kick(self.common.node_id); } } } + +//////////////////////////////////////////////////////////////////////////////////// + +/// Locked reference to a routing table entry +/// For internal use inside the RoutingTable module where you have +/// already locked a RoutingTableInner +/// Keeps entry in the routing table until all references are gone +pub struct NodeRefLocked<'a> { + inner: Mutex<&'a mut RoutingTableInner>, + nr: NodeRef, +} + +impl<'a> NodeRefLocked<'a> { + pub fn new(inner: &'a mut RoutingTableInner, nr: NodeRef) -> Self { + Self { + inner: Mutex::new(inner), + nr, + } + } +} + +impl<'a> NodeRefBase for NodeRefLocked<'a> { + fn common(&self) -> &NodeRefBaseCommon { + &self.nr.common + } + + fn common_mut(&mut self) -> &mut NodeRefBaseCommon { + &mut self.nr.common + } + + fn operate(&self, f: F) -> T + where + F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T, + { + let inner = &*self.inner.lock(); + self.nr.common.entry.with(inner, f) + } + + fn operate_mut(&self, f: F) -> T + where + F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T, + { + let inner = &mut *self.inner.lock(); + self.nr.common.entry.with_mut(inner, f) + } +} + +impl<'a> fmt::Display for NodeRefLocked<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.nr) + } +} + +impl<'a> fmt::Debug for NodeRefLocked<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NodeRefLocked") + .field("nr", &self.nr) + .finish() + } +} diff --git a/veilid-core/src/routing_table/node_ref_filter.rs b/veilid-core/src/routing_table/node_ref_filter.rs new file mode 100644 index 00000000..934d93a1 --- /dev/null +++ b/veilid-core/src/routing_table/node_ref_filter.rs @@ -0,0 +1,61 @@ +use super::*; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct NodeRefFilter { + pub routing_domain_set: RoutingDomainSet, + pub dial_info_filter: DialInfoFilter, +} + +impl Default for NodeRefFilter { + fn default() -> Self { + Self::new() + } +} + +impl NodeRefFilter { + pub fn new() -> Self { + Self { + routing_domain_set: RoutingDomainSet::all(), + dial_info_filter: DialInfoFilter::all(), + } + } + + pub fn with_routing_domain(mut self, routing_domain: RoutingDomain) -> Self { + self.routing_domain_set = routing_domain.into(); + self + } + pub fn with_routing_domain_set(mut self, routing_domain_set: RoutingDomainSet) -> Self { + self.routing_domain_set = routing_domain_set; + self + } + pub fn with_dial_info_filter(mut self, dial_info_filter: DialInfoFilter) -> Self { + self.dial_info_filter = dial_info_filter; + self + } + pub fn with_protocol_type(mut self, protocol_type: ProtocolType) -> Self { + self.dial_info_filter = self.dial_info_filter.with_protocol_type(protocol_type); + self + } + pub fn with_protocol_type_set(mut self, protocol_set: ProtocolTypeSet) -> Self { + self.dial_info_filter = self.dial_info_filter.with_protocol_type_set(protocol_set); + self + } + pub fn with_address_type(mut self, address_type: AddressType) -> Self { + self.dial_info_filter = self.dial_info_filter.with_address_type(address_type); + self + } + pub fn with_address_type_set(mut self, address_set: AddressTypeSet) -> Self { + self.dial_info_filter = self.dial_info_filter.with_address_type_set(address_set); + self + } + pub fn filtered(mut self, other_filter: &NodeRefFilter) -> Self { + self.routing_domain_set &= other_filter.routing_domain_set; + self.dial_info_filter = self + .dial_info_filter + .filtered(&other_filter.dial_info_filter); + self + } + pub fn is_dead(&self) -> bool { + self.dial_info_filter.is_dead() || self.routing_domain_set.is_empty() + } +} diff --git a/veilid-core/src/routing_table/routing_table_inner.rs b/veilid-core/src/routing_table/routing_table_inner.rs index 36c7240e..cab84df6 100644 --- a/veilid-core/src/routing_table/routing_table_inner.rs +++ b/veilid-core/src/routing_table/routing_table_inner.rs @@ -720,7 +720,7 @@ impl RoutingTableInner { }); if let Some(nr) = &out { // set the most recent node address for connection finding and udp replies - nr.set_last_connection(descriptor, timestamp); + nr.locked(self).set_last_connection(descriptor, timestamp); } out } @@ -841,12 +841,16 @@ impl RoutingTableInner { Vec::<(DHTKey, Option>)>::with_capacity(self.bucket_entry_count + 1); // add our own node (only one of there with the None entry) + let mut filtered = false; for filter in &mut filters { - if filter(self, self.unlocked_inner.node_id, None) { - nodes.push((self.unlocked_inner.node_id, None)); + if !filter(self, self.unlocked_inner.node_id, None) { + filtered = true; break; } } + if !filtered { + nodes.push((self.unlocked_inner.node_id, None)); + } // add all nodes from buckets self.with_entries(cur_ts, BucketEntryState::Unreliable, |rti, k, v| { diff --git a/veilid-core/src/veilid_api/mod.rs b/veilid-core/src/veilid_api/mod.rs index e30ef957..b28bc49b 100644 --- a/veilid-core/src/veilid_api/mod.rs +++ b/veilid-core/src/veilid_api/mod.rs @@ -25,7 +25,7 @@ pub use intf::BlockStore; pub use intf::ProtectedStore; pub use intf::TableStore; pub use network_manager::NetworkManager; -pub use routing_table::RoutingTable; +pub use routing_table::{NodeRef, NodeRefBase, RoutingTable}; use core::fmt; use core_context::{api_shutdown, VeilidCoreContext};