wasm dht test passes

This commit is contained in:
Christien Rioux 2023-10-30 20:17:07 -04:00
parent 11c19d1bad
commit d750b7c5c3
10 changed files with 151 additions and 102 deletions

View File

@ -1,4 +1,5 @@
use super::*; use super::*;
pub(crate) use connection_table::ConnectionRefKind;
use connection_table::*; use connection_table::*;
use network_connection::*; use network_connection::*;
use stop_token::future::FutureExt; use stop_token::future::FutureExt;
@ -12,6 +13,40 @@ enum ConnectionManagerEvent {
Dead(NetworkConnection), Dead(NetworkConnection),
} }
#[derive(Debug)]
pub(crate) struct ConnectionRefScope {
connection_manager: ConnectionManager,
descriptor: ConnectionDescriptor,
protect: bool,
}
impl ConnectionRefScope {
pub fn new(
connection_manager: ConnectionManager,
descriptor: ConnectionDescriptor,
protect: bool,
) -> Self {
connection_manager.connection_ref(descriptor, ConnectionRefKind::AddRef, protect);
Self {
connection_manager,
descriptor,
protect,
}
}
}
impl Drop for ConnectionRefScope {
fn drop(&mut self) {
if !self.protect {
self.connection_manager.connection_ref(
self.descriptor,
ConnectionRefKind::RemoveRef,
false,
);
}
}
}
#[derive(Debug)] #[derive(Debug)]
struct ConnectionManagerInner { struct ConnectionManagerInner {
next_id: NetworkConnectionId, next_id: NetworkConnectionId,
@ -134,36 +169,6 @@ impl ConnectionManager {
debug!("finished connection manager shutdown"); debug!("finished connection manager shutdown");
} }
// Internal routine to see if we should keep this connection
// from being LRU removed. Used on our initiated relay connections and allocated routes
fn should_protect_connection(&self, conn: &NetworkConnection) -> bool {
let netman = self.network_manager();
let routing_table = netman.routing_table();
// See if this is a relay connection
let remote_address = conn.connection_descriptor().remote_address().address();
let Some(routing_domain) = routing_table.routing_domain_for_address(remote_address) else {
return false;
};
let Some(rn) = routing_table.relay_node(routing_domain) else {
return false;
};
let relay_nr = rn.filtered_clone(
NodeRefFilter::new()
.with_routing_domain(routing_domain)
.with_address_type(conn.connection_descriptor().address_type())
.with_protocol_type(conn.connection_descriptor().protocol_type()),
);
let dids = relay_nr.all_filtered_dial_info_details();
for did in dids {
if did.dial_info.address() == remote_address {
return true;
}
}
false
}
// Internal routine to register new connection atomically. // Internal routine to register new connection atomically.
// Registers connection in the connection table for later access // Registers connection in the connection table for later access
// and spawns a message processing loop for the connection // and spawns a message processing loop for the connection
@ -188,16 +193,9 @@ impl ConnectionManager {
None => bail!("not creating connection because we are stopping"), None => bail!("not creating connection because we are stopping"),
}; };
let mut conn = NetworkConnection::from_protocol(self.clone(), stop_token, prot_conn, id); let conn = NetworkConnection::from_protocol(self.clone(), stop_token, prot_conn, id);
let handle = conn.get_handle(); let handle = conn.get_handle();
// See if this should be a protected connection
let protect = self.should_protect_connection(&conn);
if protect {
log_net!(debug "== PROTECTING connection: {} -> {}", id, conn.debug_print(get_aligned_timestamp()));
conn.protect();
}
// Add to the connection table // Add to the connection table
match self.arc.connection_table.add_connection(conn) { match self.arc.connection_table.add_connection(conn) {
Ok(None) => { Ok(None) => {
@ -227,6 +225,15 @@ impl ConnectionManager {
desc desc
))); )));
} }
Err(ConnectionTableAddError::TableFull(conn)) => {
// Connection table is full
let desc = conn.connection_descriptor();
let _ = inner.sender.send(ConnectionManagerEvent::Dead(conn));
return Ok(NetworkResult::no_connection_other(format!(
"connection table is full: {:?}",
desc
)));
}
}; };
Ok(NetworkResult::Value(handle)) Ok(NetworkResult::Value(handle))
} }
@ -239,10 +246,22 @@ impl ConnectionManager {
} }
// Protects a network connection if one already is established // Protects a network connection if one already is established
pub fn protect_connection(&self, descriptor: ConnectionDescriptor) -> bool { fn connection_ref(
&self,
descriptor: ConnectionDescriptor,
kind: ConnectionRefKind,
protect: bool,
) {
self.arc self.arc
.connection_table .connection_table
.protect_connection_by_descriptor(descriptor) .ref_connection_by_descriptor(descriptor, kind, protect);
}
pub fn connection_ref_scope(
&self,
descriptor: ConnectionDescriptor,
protect: bool,
) -> ConnectionRefScope {
ConnectionRefScope::new(self.clone(), descriptor, protect)
} }
/// Called when we want to create a new connection or get the current one that already exists /// Called when we want to create a new connection or get the current one that already exists

View File

@ -9,6 +9,8 @@ pub(in crate::network_manager) enum ConnectionTableAddError {
AlreadyExists(NetworkConnection), AlreadyExists(NetworkConnection),
#[error("Connection address was filtered")] #[error("Connection address was filtered")]
AddressFilter(NetworkConnection, AddressFilterError), AddressFilter(NetworkConnection, AddressFilterError),
#[error("Connection table is full")]
TableFull(NetworkConnection),
} }
impl ConnectionTableAddError { impl ConnectionTableAddError {
@ -18,6 +20,14 @@ impl ConnectionTableAddError {
pub fn address_filter(conn: NetworkConnection, err: AddressFilterError) -> Self { pub fn address_filter(conn: NetworkConnection, err: AddressFilterError) -> Self {
ConnectionTableAddError::AddressFilter(conn, err) ConnectionTableAddError::AddressFilter(conn, err)
} }
pub fn table_full(conn: NetworkConnection) -> Self {
ConnectionTableAddError::TableFull(conn)
}
}
pub(crate) enum ConnectionRefKind {
AddRef,
RemoveRef,
} }
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -176,31 +186,35 @@ impl ConnectionTable {
} }
}; };
// if we have reached the maximum number of connections per protocol type
// then drop the least recently used connection that is not protected or referenced
let mut out_conn = None;
if inner.conn_by_id[protocol_index].len() >= inner.max_connections[protocol_index] {
// Find a free connection to terminate to make room
let dead_k = {
let Some(lruk) = inner.conn_by_id[protocol_index].iter().find_map(|(k, v)| {
if !v.is_in_use() {
Some(*k)
} else {
None
}
}) else {
// Can't make room, connection table is full
return Err(ConnectionTableAddError::table_full(network_connection));
};
lruk
};
let dead_conn = Self::remove_connection_records(&mut inner, dead_k);
log_net!(debug "== LRU Connection Killed: {} -> {}", dead_k, dead_conn.debug_print(get_aligned_timestamp()));
out_conn = Some(dead_conn);
}
// Add the connection to the table // Add the connection to the table
let res = inner.conn_by_id[protocol_index].insert(id, network_connection); let res = inner.conn_by_id[protocol_index].insert(id, network_connection);
assert!(res.is_none()); assert!(res.is_none());
// if we have reached the maximum number of connections per protocol type
// then drop the least recently used connection
let mut out_conn = None;
if inner.conn_by_id[protocol_index].len() > inner.max_connections[protocol_index] {
while let Some((lruk, lru_conn)) = inner.conn_by_id[protocol_index].peek_lru() {
let lruk = *lruk;
// Don't LRU protected connections
if lru_conn.is_protected() {
// Mark as recently used
log_net!(debug "== No LRU Out for PROTECTED connection: {} -> {}", lruk, lru_conn.debug_print(get_aligned_timestamp()));
inner.conn_by_id[protocol_index].get(&lruk);
continue;
}
log_net!(debug "== LRU Connection Killed: {} -> {}", lruk, lru_conn.debug_print(get_aligned_timestamp()));
out_conn = Some(Self::remove_connection_records(&mut inner, lruk));
break;
}
}
// add connection records // add connection records
inner.protocol_index_by_id.insert(id, protocol_index); inner.protocol_index_by_id.insert(id, protocol_index);
inner.id_by_descriptor.insert(descriptor, id); inner.id_by_descriptor.insert(descriptor, id);
@ -218,18 +232,6 @@ impl ConnectionTable {
Some(out.get_handle()) Some(out.get_handle())
} }
//#[instrument(level = "trace", skip(self), ret)]
#[allow(dead_code)]
pub fn protect_connection_by_id(&self, id: NetworkConnectionId) -> bool {
let mut inner = self.inner.lock();
let Some(protocol_index) = inner.protocol_index_by_id.get(&id).copied() else {
return false;
};
let out = inner.conn_by_id[protocol_index].get_mut(&id).unwrap();
out.protect();
true
}
//#[instrument(level = "trace", skip(self), ret)] //#[instrument(level = "trace", skip(self), ret)]
pub fn get_connection_by_descriptor( pub fn get_connection_by_descriptor(
&self, &self,
@ -244,7 +246,12 @@ impl ConnectionTable {
} }
//#[instrument(level = "trace", skip(self), ret)] //#[instrument(level = "trace", skip(self), ret)]
pub fn protect_connection_by_descriptor(&self, descriptor: ConnectionDescriptor) -> bool { pub fn ref_connection_by_descriptor(
&self,
descriptor: ConnectionDescriptor,
ref_type: ConnectionRefKind,
protect: bool,
) -> bool {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
let Some(id) = inner.id_by_descriptor.get(&descriptor).copied() else { let Some(id) = inner.id_by_descriptor.get(&descriptor).copied() else {
@ -252,7 +259,14 @@ impl ConnectionTable {
}; };
let protocol_index = Self::protocol_to_index(descriptor.protocol_type()); let protocol_index = Self::protocol_to_index(descriptor.protocol_type());
let out = inner.conn_by_id[protocol_index].get_mut(&id).unwrap(); let out = inner.conn_by_id[protocol_index].get_mut(&id).unwrap();
out.protect(); match ref_type {
ConnectionRefKind::AddRef => out.change_ref_count(true),
ConnectionRefKind::RemoveRef => out.change_ref_count(false),
}
if protect {
out.protect();
}
true true
} }

View File

@ -95,8 +95,18 @@ pub struct NetworkConnection {
sender: flume::Sender<(Option<Id>, Vec<u8>)>, sender: flume::Sender<(Option<Id>, Vec<u8>)>,
stop_source: Option<StopSource>, stop_source: Option<StopSource>,
protected: bool, protected: bool,
ref_count: usize,
} }
impl Drop for NetworkConnection {
fn drop(&mut self) {
if self.ref_count != 0 && self.stop_source.is_some() {
log_net!(error "ref_count for network connection should be zero: {:?}", self.ref_count);
}
}
}
impl NetworkConnection { impl NetworkConnection {
pub(super) fn dummy(id: NetworkConnectionId, descriptor: ConnectionDescriptor) -> Self { pub(super) fn dummy(id: NetworkConnectionId, descriptor: ConnectionDescriptor) -> Self {
// Create handle for sending (dummy is immediately disconnected) // Create handle for sending (dummy is immediately disconnected)
@ -114,6 +124,7 @@ impl NetworkConnection {
sender, sender,
stop_source: None, stop_source: None,
protected: false, protected: false,
ref_count: 0,
} }
} }
@ -160,6 +171,7 @@ impl NetworkConnection {
sender, sender,
stop_source: Some(stop_source), stop_source: Some(stop_source),
protected: false, protected: false,
ref_count: 0,
} }
} }
@ -175,14 +187,22 @@ impl NetworkConnection {
ConnectionHandle::new(self.connection_id, self.descriptor, self.sender.clone()) ConnectionHandle::new(self.connection_id, self.descriptor, self.sender.clone())
} }
pub fn is_protected(&self) -> bool { pub fn is_in_use(&self) -> bool {
self.protected self.protected || self.ref_count > 0
} }
pub fn protect(&mut self) { pub fn protect(&mut self) {
self.protected = true; self.protected = true;
} }
pub fn change_ref_count(&mut self, add: bool) {
if add {
self.ref_count += 1;
} else {
self.ref_count -= 1;
}
}
pub fn close(&mut self) { pub fn close(&mut self) {
if let Some(stop_source) = self.stop_source.take() { if let Some(stop_source) = self.stop_source.take() {
// drop the stopper // drop the stopper

View File

@ -308,17 +308,6 @@ pub(crate) trait NodeRefBase: Sized {
}) })
} }
fn protect_last_connection(&self) -> bool {
if let Some(descriptor) = self.last_connection() {
self.routing_table()
.network_manager()
.connection_manager()
.protect_connection(descriptor)
} else {
false
}
}
fn has_any_dial_info(&self) -> bool { fn has_any_dial_info(&self) -> bool {
self.operate(|_rti, e| { self.operate(|_rti, e| {
for rtd in RoutingDomain::all() { for rtd in RoutingDomain::all() {

View File

@ -189,6 +189,7 @@ struct WaitableReply {
safety_route: Option<PublicKey>, safety_route: Option<PublicKey>,
remote_private_route: Option<PublicKey>, remote_private_route: Option<PublicKey>,
reply_private_route: Option<PublicKey>, reply_private_route: Option<PublicKey>,
_connection_ref_scope: ConnectionRefScope,
} }
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
@ -1147,7 +1148,8 @@ impl RPCProcessor {
dest: Destination, dest: Destination,
question: RPCQuestion, question: RPCQuestion,
context: Option<QuestionContext>, context: Option<QuestionContext>,
) ->RPCNetworkResult<WaitableReply> { protect: bool,
) -> RPCNetworkResult<WaitableReply> {
// Get sender peer info if we should send that // Get sender peer info if we should send that
let spi = self.get_sender_peer_info(&dest); let spi = self.get_sender_peer_info(&dest);
@ -1157,7 +1159,7 @@ impl RPCProcessor {
// Log rpc send // Log rpc send
#[cfg(feature = "verbose-tracing")] #[cfg(feature = "verbose-tracing")]
debug!(target: "rpc_message", dir = "send", kind = "question", op_id = op_id.as_u64(), desc = operation.kind().desc(), ?dest); debug!(target: "rpc_message", dir = "send", kind = "question", op_id = op_id.as_u64(), desc = operation.kind().desc(), ?dest, protect);
// Produce rendered operation // Produce rendered operation
let RenderedOperation { let RenderedOperation {
@ -1221,6 +1223,16 @@ impl RPCProcessor {
remote_private_route, remote_private_route,
); );
// Ref the connection so it doesn't go away until we're done with the waitable reply
let connection_ref_scope = self
.network_manager()
.connection_manager()
.connection_ref_scope(
send_data_method.connection_descriptor,
protect,
);
// Pass back waitable reply completion // Pass back waitable reply completion
Ok(NetworkResult::value(WaitableReply { Ok(NetworkResult::value(WaitableReply {
handle, handle,
@ -1231,6 +1243,7 @@ impl RPCProcessor {
safety_route, safety_route,
remote_private_route, remote_private_route,
reply_private_route, reply_private_route,
_connection_ref_scope: connection_ref_scope,
})) }))
} }

View File

@ -21,7 +21,7 @@ impl RPCProcessor {
); );
// Send the app call question // Send the app call question
let waitable_reply = network_result_try!(self.question(dest, question, None).await?); let waitable_reply = network_result_try!(self.question(dest, question, None, false).await?);
// Wait for reply // Wait for reply
let (msg, latency) = match self.wait_for_reply(waitable_reply, debug_string).await? { let (msg, latency) = match self.wait_for_reply(waitable_reply, debug_string).await? {

View File

@ -41,7 +41,8 @@ impl RPCProcessor {
let debug_string = format!("FindNode(node_id={}) => {}", node_id, dest); let debug_string = format!("FindNode(node_id={}) => {}", node_id, dest);
// Send the find_node request // Send the find_node request
let waitable_reply = network_result_try!(self.question(dest, find_node_q, None).await?); let waitable_reply =
network_result_try!(self.question(dest, find_node_q, None, false).await?);
// Wait for reply // Wait for reply
let (msg, latency) = match self.wait_for_reply(waitable_reply, debug_string).await? { let (msg, latency) = match self.wait_for_reply(waitable_reply, debug_string).await? {

View File

@ -78,7 +78,7 @@ impl RPCProcessor {
log_rpc!(debug "{}", debug_string); log_rpc!(debug "{}", debug_string);
let waitable_reply = network_result_try!( let waitable_reply = network_result_try!(
self.question(dest.clone(), question, Some(question_context)) self.question(dest.clone(), question, Some(question_context), false)
.await? .await?
); );

View File

@ -92,7 +92,7 @@ impl RPCProcessor {
log_rpc!(debug "{}", debug_string); log_rpc!(debug "{}", debug_string);
let waitable_reply = network_result_try!( let waitable_reply = network_result_try!(
self.question(dest.clone(), question, Some(question_context)) self.question(dest.clone(), question, Some(question_context), false)
.await? .await?
); );

View File

@ -109,14 +109,7 @@ impl RPCProcessor {
// Send the info request // Send the info request
let waitable_reply = let waitable_reply =
network_result_try!(self.question(dest.clone(), question, None).await?); network_result_try!(self.question(dest.clone(), question, None, protect).await?);
// Optionally protect the connection in the event this for a relay or route keepalive
if protect {
self.network_manager()
.connection_manager()
.protect_connection(waitable_reply.send_data_method.connection_descriptor);
}
// Note what kind of ping this was and to what peer scope // Note what kind of ping this was and to what peer scope
let send_data_method = waitable_reply.send_data_method.clone(); let send_data_method = waitable_reply.send_data_method.clone();