diff --git a/veilid-core/src/connection_manager.rs b/veilid-core/src/connection_manager.rs index f8dc964f..c654743a 100644 --- a/veilid-core/src/connection_manager.rs +++ b/veilid-core/src/connection_manager.rs @@ -89,7 +89,7 @@ impl ConnectionManager { inner.connection_table.get_connection(descriptor) } - // Internal routine to register new connection + // Internal routine to register new connection atomically async fn on_new_connection_internal( &self, mut inner: AsyncMutexGuard<'_, ConnectionManagerInner>, @@ -136,18 +136,14 @@ impl ConnectionManager { // If connection exists, then return it let inner = self.arc.inner.lock().await; - if let Some(conn) = inner - .connection_table - .get_connection(&descriptor) - .map(|e| e.conn) - { + if let Some(conn) = inner.connection_table.get_connection(&descriptor) { return Ok(conn); } // If not, attempt new connection let conn = NetworkConnection::connect(local_addr, dial_info).await?; - self.on_new_connection_internal(inner, conn).await; + self.on_new_connection_internal(inner, conn.clone()).await?; Ok(conn) } @@ -160,41 +156,31 @@ impl ConnectionManager { let network_manager = this.network_manager(); Box::pin(async move { // - let exit_value: Result, ()> = Err(()); let descriptor = conn.connection_descriptor(); loop { - let res = match select( - entry.stopper.clone().instance_clone(exit_value.clone()), - Box::pin(conn.clone().recv()), - ) - .await - { - Either::Left((_x, _b)) => break, - Either::Right((y, _a)) => y, - }; + let res = conn.clone().recv().await; let message = match res { Ok(v) => v, Err(_) => break, }; - match network_manager - .on_recv_envelope(message.as_slice(), &descriptor) + if let Err(e) = network_manager + .on_recv_envelope(message.as_slice(), descriptor) .await { - Ok(_) => (), - Err(e) => { - error!("{}", e); - break; - } - }; + log_net!(error e); + break; + } } - if let Err(err) = this + if let Err(e) = this + .arc .inner .lock() + .await .connection_table .remove_connection(&descriptor) { - error!("{}", err); + log_net!(error e); } }) } diff --git a/veilid-core/src/connection_table.rs b/veilid-core/src/connection_table.rs index 5d18ecae..6baff899 100644 --- a/veilid-core/src/connection_table.rs +++ b/veilid-core/src/connection_table.rs @@ -1,4 +1,3 @@ -use crate::intf::*; use crate::network_connection::*; use crate::xx::*; use crate::*; @@ -17,38 +16,23 @@ impl ConnectionTable { pub fn add_connection(&mut self, conn: NetworkConnection) -> Result<(), String> { let descriptor = conn.connection_descriptor(); - assert_ne!( descriptor.protocol_type(), ProtocolType::UDP, "Only connection oriented protocols go in the table!" ); - if self.conn_by_addr.contains_key(&descriptor) { return Err(format!( "Connection already added to table: {:?}", descriptor )); } - - let timestamp = get_timestamp(); - - let entry = ConnectionTableEntry { - conn, - established_time: timestamp, - last_message_sent_time: None, - last_message_recv_time: None, - stopper: Eventual::new(), - }; - let res = self.conn_by_addr.insert(descriptor, entry.clone()); + let res = self.conn_by_addr.insert(descriptor, conn); assert!(res.is_none()); - Ok(entry) + Ok(()) } - pub fn get_connection( - &self, - descriptor: &ConnectionDescriptor, - ) -> Option { + pub fn get_connection(&self, descriptor: &ConnectionDescriptor) -> Option { self.conn_by_addr.get(descriptor).cloned() } @@ -59,7 +43,7 @@ impl ConnectionTable { pub fn remove_connection( &mut self, descriptor: &ConnectionDescriptor, - ) -> Result { + ) -> Result { self.conn_by_addr .remove(descriptor) .ok_or_else(|| format!("Connection not in table: {:?}", descriptor)) diff --git a/veilid-core/src/intf/native/network/network_udp.rs b/veilid-core/src/intf/native/network/network_udp.rs index cb54d06b..1b4d169d 100644 --- a/veilid-core/src/intf/native/network/network_udp.rs +++ b/veilid-core/src/intf/native/network/network_udp.rs @@ -54,7 +54,7 @@ impl Network { log_net!("UDP packet: {:?}", descriptor); if let Err(e) = network_manager - .on_recv_envelope(&data[..size], &descriptor) + .on_recv_envelope(&data[..size], descriptor) .await { log_net!(error "failed to process received udp envelope: {}", e); diff --git a/veilid-core/src/intf/native/network/protocol/mod.rs b/veilid-core/src/intf/native/network/protocol/mod.rs index e9b71684..6f63af29 100644 --- a/veilid-core/src/intf/native/network/protocol/mod.rs +++ b/veilid-core/src/intf/native/network/protocol/mod.rs @@ -56,6 +56,16 @@ impl ProtocolNetworkConnection { } } + pub async fn close(&mut self) -> Result<(), String> { + match self { + Self::Dummy(d) => d.close(), + Self::RawTcp(t) => t.close().await, + Self::WsAccepted(w) => w.close().await, + Self::Ws(w) => w.close().await, + Self::Wss(w) => w.close().await, + } + } + pub async fn send(&mut self, message: Vec) -> Result<(), String> { match self { Self::Dummy(d) => d.send(message), diff --git a/veilid-core/src/intf/native/network/protocol/tcp.rs b/veilid-core/src/intf/native/network/protocol/tcp.rs index db5c77d6..7d53faf8 100644 --- a/veilid-core/src/intf/native/network/protocol/tcp.rs +++ b/veilid-core/src/intf/native/network/protocol/tcp.rs @@ -3,9 +3,9 @@ use crate::intf::native::utils::async_peek_stream::*; use crate::intf::*; use crate::network_manager::MAX_MESSAGE_SIZE; use crate::*; -use async_std::net::*; -use async_std::prelude::*; -use std::fmt; +use async_std::net::TcpStream; +use core::fmt; +use futures_util::{AsyncReadExt, AsyncWriteExt}; pub struct RawTcpNetworkConnection { stream: AsyncPeekStream, @@ -22,6 +22,14 @@ impl RawTcpNetworkConnection { Self { stream } } + pub async fn close(&mut self) -> Result<(), String> { + self.stream + .close() + .await + .map_err(map_to_string) + .map_err(logthru_net!()) + } + pub async fn send(&mut self, message: Vec) -> Result<(), String> { if message.len() > MAX_MESSAGE_SIZE { return Err("sending too large TCP message".to_owned()); @@ -183,7 +191,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler { &self, stream: AsyncPeekStream, peer_addr: SocketAddr, - ) -> SystemPinBoxFuture, String>> { + ) -> SystemPinBoxFuture, String>> { Box::pin(self.clone().on_accept_async(stream, peer_addr)) } } diff --git a/veilid-core/src/intf/native/network/protocol/ws.rs b/veilid-core/src/intf/native/network/protocol/ws.rs index d43bd7a0..b236c17e 100644 --- a/veilid-core/src/intf/native/network/protocol/ws.rs +++ b/veilid-core/src/intf/native/network/protocol/ws.rs @@ -3,16 +3,16 @@ use crate::intf::native::utils::async_peek_stream::*; use crate::intf::*; use crate::network_manager::MAX_MESSAGE_SIZE; use crate::*; +use alloc::sync::Arc; use async_std::io; use async_std::net::*; use async_tls::TlsConnector; use async_tungstenite::tungstenite::protocol::Message; use async_tungstenite::{accept_async, client_async, WebSocketStream}; +use core::fmt; +use core::time::Duration; use futures_util::sink::SinkExt; use futures_util::stream::StreamExt; -use std::fmt; -use std::sync::Arc; -use std::time::Duration; pub type WebSocketNetworkConnectionAccepted = WebsocketNetworkConnection; pub type WebsocketNetworkConnectionWSS = @@ -68,6 +68,16 @@ where } } + pub async fn close(&self) -> Result<(), String> { + let mut inner = self.inner.lock().await; + inner + .ws_stream + .close(None) + .await + .map_err(map_to_string) + .map_err(logthru_net!(error "failed to close websocket")) + } + pub async fn send(&self, message: Vec) -> Result<(), String> { if message.len() > MAX_MESSAGE_SIZE { return Err("received too large WS message".to_owned()); diff --git a/veilid-core/src/intf/wasm/network/protocol/mod.rs b/veilid-core/src/intf/wasm/network/protocol/mod.rs index 2dc3d3bd..2e414710 100644 --- a/veilid-core/src/intf/wasm/network/protocol/mod.rs +++ b/veilid-core/src/intf/wasm/network/protocol/mod.rs @@ -8,7 +8,7 @@ use crate::xx::*; #[derive(Debug)] pub enum ProtocolNetworkConnection { Dummy(DummyNetworkConnection), - WS(ws::WebsocketNetworkConnection), + Ws(ws::WebsocketNetworkConnection), //WebRTC(wrtc::WebRTCNetworkConnection), } @@ -46,18 +46,23 @@ impl ProtocolNetworkConnection { } } } - + pub async fn close(&mut self) -> Result<(), String> { + match self { + Self::Dummy(d) => d.close(), + Self::Ws(w) => w.close().await, + } + } pub async fn send(&mut self, message: Vec) -> Result<(), String> { match self { Self::Dummy(d) => d.send(message), - Self::WS(w) => w.send(message), + Self::Ws(w) => w.send(message).await, } } pub async fn recv(&mut self) -> Result, String> { match self { Self::Dummy(d) => d.recv(), - Self::WS(w) => w.recv(), + Self::Ws(w) => w.recv().await, } } } diff --git a/veilid-core/src/intf/wasm/network/protocol/ws.rs b/veilid-core/src/intf/wasm/network/protocol/ws.rs index de2f1cd9..a510e821 100644 --- a/veilid-core/src/intf/wasm/network/protocol/ws.rs +++ b/veilid-core/src/intf/wasm/network/protocol/ws.rs @@ -7,14 +7,14 @@ use web_sys::WebSocket; use ws_stream_wasm::*; struct WebsocketNetworkConnectionInner { + ws_meta: WsMeta, ws_stream: WsStream, - ws: WebSocket, } #[derive(Clone)] pub struct WebsocketNetworkConnection { tls: bool, - inner: Arc>, + inner: Arc>, } impl fmt::Debug for WebsocketNetworkConnection { @@ -24,33 +24,34 @@ impl fmt::Debug for WebsocketNetworkConnection { } impl WebsocketNetworkConnection { - pub fn new(tls: bool, ws_stream: WsStream) -> Self { - let ws = ws_stream.wrapped().clone(); + pub fn new(tls: bool, ws_meta: WsMeta, ws_stream: WsStream) -> Self { Self { tls, inner: Arc::new(Mutex::new(WebsocketNetworkConnectionInner { + ws_meta, ws_stream, - ws, })), } } -xxx convert this to async and use stream api not low level websocket -xxx implement close() everywhere and skip using eventual for loop shutdown + pub async fn close(&self) -> Result<(), String> { + let inner = self.inner.lock().await; + inner.ws_meta.close().await; + } - pub fn send(&self, message: Vec) -> Result<(), String> { + pub async fn send(&self, message: Vec) -> Result<(), String> { if message.len() > MAX_MESSAGE_SIZE { return Err("sending too large WS message".to_owned()).map_err(logthru_net!(error)); } - self.inner - .lock() - .ws - .send_with_u8_array(&message) + let mut inner = self.inner.lock().await; + inner.ws_stream + .send(WsMessage::Binary(message)).await .map_err(|_| "failed to send to websocket".to_owned()) .map_err(logthru_net!(error)) } - pub fn recv(&self) -> Result, String> { - let out = match self.inner.lock().ws_stream.next().await { + pub async fn recv(&self) -> Result, String> { + let mut inner = self.inner.lock().await; + let out = match inner.ws_stream.next().await { Some(WsMessage::Binary(v)) => v, Some(_) => { return Err("Unexpected WS message type".to_owned()) @@ -112,7 +113,7 @@ impl WebsocketProtocolHandler { remote: dial_info.to_peer_address(), }; - Ok(NetworkConnection::from_protocol(descriptor,ProtocolNetworkConnection::WS(WebsocketNetworkConnection::new(tls, wsio)))) + Ok(NetworkConnection::from_protocol(descriptor,ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(tls, wsio)))) } pub async fn send_unbound_message(dial_info: &DialInfo, data: Vec) -> Result<(), String> { diff --git a/veilid-core/src/network_connection.rs b/veilid-core/src/network_connection.rs index 8ca70b5d..927d6d55 100644 --- a/veilid-core/src/network_connection.rs +++ b/veilid-core/src/network_connection.rs @@ -48,8 +48,8 @@ cfg_if! { pub struct DummyNetworkConnection {} impl DummyNetworkConnection { - pub fn new(descriptor: ConnectionDescriptor) -> NetworkConnection { - NetworkConnection::from_protocol(descriptor, ProtocolNetworkConnection::Dummy(Self {})) + pub fn close(&self) -> Result<(), String> { + Ok(()) } pub fn send(&self, _message: Vec) -> Result<(), String> { Ok(()) @@ -73,7 +73,6 @@ struct NetworkConnectionInner { struct NetworkConnectionArc { descriptor: ConnectionDescriptor, established_time: u64, - stopper: Eventual, inner: AsyncMutex, } @@ -81,6 +80,13 @@ struct NetworkConnectionArc { pub struct NetworkConnection { arc: Arc, } +impl PartialEq for NetworkConnection { + fn eq(&self, other: &Self) -> bool { + Arc::as_ptr(&self.arc) == Arc::as_ptr(&other.arc) + } +} + +impl Eq for NetworkConnection {} impl NetworkConnection { fn new_inner(protocol_connection: ProtocolNetworkConnection) -> NetworkConnectionInner { @@ -101,6 +107,13 @@ impl NetworkConnection { } } + pub fn dummy(descriptor: ConnectionDescriptor) -> Self { + NetworkConnection::from_protocol( + descriptor, + ProtocolNetworkConnection::Dummy(DummyNetworkConnection {}), + ) + } + pub fn from_protocol( descriptor: ConnectionDescriptor, protocol_connection: ProtocolNetworkConnection, @@ -121,6 +134,11 @@ impl NetworkConnection { self.arc.descriptor } + pub async fn close(&self) -> Result<(), String> { + let mut inner = self.arc.inner.lock().await; + inner.protocol_connection.close().await + } + pub async fn send(&self, message: Vec) -> Result<(), String> { let mut inner = self.arc.inner.lock().await; let out = inner.protocol_connection.send(message).await; diff --git a/veilid-core/src/network_manager.rs b/veilid-core/src/network_manager.rs index 9a76110c..2e3b7995 100644 --- a/veilid-core/src/network_manager.rs +++ b/veilid-core/src/network_manager.rs @@ -431,7 +431,7 @@ impl NetworkManager { pub async fn on_recv_envelope( &self, data: &[u8], - descriptor: &ConnectionDescriptor, + descriptor: ConnectionDescriptor, ) -> Result { // Is this an out-of-band receipt instead of an envelope? if data[0..4] == *RECEIPT_MAGIC { @@ -522,11 +522,7 @@ impl NetworkManager { // Cache the envelope information in the routing table let source_noderef = routing_table - .register_node_with_existing_connection( - envelope.get_sender_id(), - descriptor.clone(), - ts, - ) + .register_node_with_existing_connection(envelope.get_sender_id(), descriptor, ts) .map_err(|e| format!("node id registration failed: {}", e))?; source_noderef.operate(|e| e.set_min_max_version(envelope.get_min_max_version())); diff --git a/veilid-core/src/routing_table/bucket_entry.rs b/veilid-core/src/routing_table/bucket_entry.rs index 39621aa2..00827e80 100644 --- a/veilid-core/src/routing_table/bucket_entry.rs +++ b/veilid-core/src/routing_table/bucket_entry.rs @@ -101,7 +101,7 @@ impl BucketEntry { } pub fn last_connection(&self) -> Option { - self.last_connection.as_ref().map(|x| x.0.clone()) + self.last_connection.as_ref().map(|x| x.0) } pub fn set_min_max_version(&mut self, min_max_version: (u8, u8)) { diff --git a/veilid-core/src/rpc_processor/mod.rs b/veilid-core/src/rpc_processor/mod.rs index 452a4a5c..17335839 100644 --- a/veilid-core/src/rpc_processor/mod.rs +++ b/veilid-core/src/rpc_processor/mod.rs @@ -801,14 +801,10 @@ impl RPCProcessor { } fn generate_sender_info(&self, rpcreader: &RPCMessageReader) -> SenderInfo { - let socket_address = - rpcreader - .header - .peer_noderef - .operate(|entry| match entry.last_connection() { - None => None, - Some(c) => Some(c.remote.socket_address), - }); + let socket_address = rpcreader + .header + .peer_noderef + .operate(|entry| entry.last_connection().map(|c| c.remote.socket_address)); SenderInfo { socket_address } } diff --git a/veilid-core/src/tests/common/test_connection_table.rs b/veilid-core/src/tests/common/test_connection_table.rs index deda9616..1529e265 100644 --- a/veilid-core/src/tests/common/test_connection_table.rs +++ b/veilid-core/src/tests/common/test_connection_table.rs @@ -10,7 +10,7 @@ pub async fn test_add_get_remove() { SocketAddress::new(Address::IPV4(Ipv4Addr::new(127, 0, 0, 1)), 8080), ProtocolType::TCP, )); - let a2 = a1.clone(); + let a2 = a1; let a3 = ConnectionDescriptor::new( PeerAddress::new( SocketAddress::new(Address::IPV6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8090), @@ -48,11 +48,11 @@ pub async fn test_add_get_remove() { ))), ); - let c1 = DummyNetworkConnection::new(a1.clone()); - let c2 = DummyNetworkConnection::new(a2.clone()); - let c3 = DummyNetworkConnection::new(a3.clone()); - let c4 = DummyNetworkConnection::new(a4.clone()); - let c5 = DummyNetworkConnection::new(a5); + let c1 = NetworkConnection::dummy(a1); + let c2 = NetworkConnection::dummy(a2); + let c3 = NetworkConnection::dummy(a3); + let c4 = NetworkConnection::dummy(a4); + let c5 = NetworkConnection::dummy(a5); assert_eq!(a1, c2.connection_descriptor()); assert_ne!(a3, c4.connection_descriptor()); @@ -60,36 +60,36 @@ pub async fn test_add_get_remove() { assert_eq!(table.connection_count(), 0); assert_eq!(table.get_connection(&a1), None); - let entry1 = table.add_connection(c1.clone()).unwrap(); + table.add_connection(c1.clone()).unwrap(); assert_eq!(table.connection_count(), 1); assert_err!(table.remove_connection(&a3)); assert_err!(table.remove_connection(&a4)); assert_eq!(table.connection_count(), 1); - assert_eq!(table.get_connection(&a1), Some(entry1.clone())); - assert_eq!(table.get_connection(&a1), Some(entry1.clone())); + assert_eq!(table.get_connection(&a1), Some(c1.clone())); + assert_eq!(table.get_connection(&a1), Some(c1.clone())); assert_eq!(table.connection_count(), 1); assert_err!(table.add_connection(c1.clone())); assert_err!(table.add_connection(c2.clone())); assert_eq!(table.connection_count(), 1); - assert_eq!(table.get_connection(&a1), Some(entry1.clone())); - assert_eq!(table.get_connection(&a1), Some(entry1.clone())); + assert_eq!(table.get_connection(&a1), Some(c1.clone())); + assert_eq!(table.get_connection(&a1), Some(c1.clone())); assert_eq!(table.connection_count(), 1); - assert_eq!(table.remove_connection(&a2), Ok(entry1)); + assert_eq!(table.remove_connection(&a2), Ok(c1.clone())); assert_eq!(table.connection_count(), 0); assert_err!(table.remove_connection(&a2)); assert_eq!(table.connection_count(), 0); assert_eq!(table.get_connection(&a2), None); assert_eq!(table.get_connection(&a1), None); assert_eq!(table.connection_count(), 0); - let entry2 = table.add_connection(c1).unwrap(); + table.add_connection(c1.clone()).unwrap(); assert_err!(table.add_connection(c2)); - let entry3 = table.add_connection(c3).unwrap(); - let entry4 = table.add_connection(c4).unwrap(); + table.add_connection(c3.clone()).unwrap(); + table.add_connection(c4.clone()).unwrap(); assert_eq!(table.connection_count(), 3); - assert_eq!(table.remove_connection(&a2), Ok(entry2)); - assert_eq!(table.remove_connection(&a3), Ok(entry3)); - assert_eq!(table.remove_connection(&a4), Ok(entry4)); + assert_eq!(table.remove_connection(&a2), Ok(c1)); + assert_eq!(table.remove_connection(&a3), Ok(c3)); + assert_eq!(table.remove_connection(&a4), Ok(c4)); assert_eq!(table.connection_count(), 0); }