mirror of
https://gitlab.com/veilid/veilid.git
synced 2025-01-12 07:49:49 -05:00
massive network refactor
This commit is contained in:
parent
8148c37708
commit
cfcf430a99
@ -3,7 +3,6 @@ use core::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
|
||||
use core::convert::{TryFrom, TryInto};
|
||||
use core::fmt;
|
||||
use core::hash::{Hash, Hasher};
|
||||
use hex;
|
||||
|
||||
use crate::veilid_rng::*;
|
||||
use ed25519_dalek::{Keypair, PublicKey, Signature};
|
||||
|
38
veilid-core/src/network_manager/connection_handle.rs
Normal file
38
veilid-core/src/network_manager/connection_handle.rs
Normal file
@ -0,0 +1,38 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ConnectionHandle {
|
||||
descriptor: ConnectionDescriptor,
|
||||
channel: flume::Sender<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl ConnectionHandle {
|
||||
pub(super) fn new(descriptor: ConnectionDescriptor, channel: flume::Sender<Vec<u8>>) -> Self {
|
||||
Self {
|
||||
descriptor,
|
||||
channel,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
pub fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
self.channel.send(message).map_err(map_to_string)
|
||||
}
|
||||
pub async fn send_async(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
self.channel
|
||||
.send_async(message)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for ConnectionHandle {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.descriptor == other.descriptor
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for ConnectionHandle {}
|
@ -1,7 +1,5 @@
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
use super::*;
|
||||
use alloc::collections::btree_map::Entry;
|
||||
use core::fmt;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AddressFilterError {
|
||||
|
@ -3,8 +3,6 @@ use crate::xx::*;
|
||||
use connection_table::*;
|
||||
use network_connection::*;
|
||||
|
||||
const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize;
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
// Connection manager
|
||||
|
||||
@ -59,8 +57,7 @@ impl ConnectionManager {
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) {
|
||||
// xxx close all connections in the connection table
|
||||
|
||||
// Drops connection table, which drops all connections in it
|
||||
*self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config());
|
||||
}
|
||||
|
||||
@ -68,47 +65,48 @@ impl ConnectionManager {
|
||||
pub async fn get_connection(
|
||||
&self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Option<NetworkConnection> {
|
||||
) -> Option<ConnectionHandle> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
inner.connection_table.get_connection(descriptor)
|
||||
}
|
||||
|
||||
// Internal routine to register new connection atomically
|
||||
fn on_new_connection_internal(
|
||||
// Internal routine to register new connection atomically.
|
||||
// Registers connection in the connection table for later access
|
||||
// and spawns a message processing loop for the connection
|
||||
fn on_new_protocol_network_connection(
|
||||
&self,
|
||||
inner: &mut ConnectionManagerInner,
|
||||
conn: NetworkConnection,
|
||||
) -> Result<(), String> {
|
||||
log_net!("on_new_connection_internal: {:?}", conn);
|
||||
let tx = inner
|
||||
.connection_add_channel_tx
|
||||
.as_ref()
|
||||
.ok_or_else(fn_string!("connection channel isn't open yet"))?
|
||||
.clone();
|
||||
conn: ProtocolNetworkConnection,
|
||||
) -> Result<ConnectionHandle, String> {
|
||||
log_net!("on_new_protocol_network_connection: {:?}", conn);
|
||||
|
||||
let receiver_loop_future = Self::process_connection(self.clone(), conn.clone());
|
||||
tx.try_send(receiver_loop_future)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to start receiver loop"))?;
|
||||
|
||||
// If the receiver loop started successfully,
|
||||
// add the new connection to the table
|
||||
inner.connection_table.add_connection(conn)
|
||||
// Wrap with NetworkConnection object to start the connection processing loop
|
||||
let conn = NetworkConnection::from_protocol(self.clone(), conn);
|
||||
let handle = conn.get_handle();
|
||||
// Add to the connection table
|
||||
inner.connection_table.add_connection(conn)?;
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
// Called by low-level network when any connection-oriented protocol connection appears
|
||||
// either from incoming or outgoing connections. Registers connection in the connection table for later access
|
||||
// and spawns a message processing loop for the connection
|
||||
pub async fn on_new_connection(&self, conn: NetworkConnection) -> Result<(), String> {
|
||||
// either from incoming connections.
|
||||
pub(super) async fn on_accepted_protocol_network_connection(
|
||||
&self,
|
||||
conn: ProtocolNetworkConnection,
|
||||
) -> Result<(), String> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
self.on_new_connection_internal(&mut *inner, conn)
|
||||
self.on_new_protocol_network_connection(&mut *inner, conn)
|
||||
.map(drop)
|
||||
}
|
||||
|
||||
// Called when we want to create a new connection or get the current one that already exists
|
||||
// This will kill off any connections that are in conflict with the new connection to be made
|
||||
// in order to make room for the new connection in the system's connection table
|
||||
pub async fn get_or_create_connection(
|
||||
&self,
|
||||
local_addr: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
) -> Result<ConnectionHandle, String> {
|
||||
log_net!(
|
||||
"== get_or_create_connection local_addr={:?} dial_info={:?}",
|
||||
local_addr.green(),
|
||||
@ -146,8 +144,10 @@ impl ConnectionManager {
|
||||
if local_addr.port() != 0 {
|
||||
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
|
||||
let pa = PeerAddress::new(descriptor.remote.socket_address, pt);
|
||||
for conn in inner.connection_table.get_connections_by_remote(pa) {
|
||||
let desc = conn.connection_descriptor();
|
||||
for desc in inner
|
||||
.connection_table
|
||||
.get_connection_descriptors_by_remote(pa)
|
||||
{
|
||||
let mut kill = false;
|
||||
if let Some(conn_local) = desc.local {
|
||||
if (local_addr.ip().is_unspecified()
|
||||
@ -163,7 +163,9 @@ impl ConnectionManager {
|
||||
local_addr.green(),
|
||||
pa.green()
|
||||
);
|
||||
conn.close().await?;
|
||||
if let Err(e) = inner.connection_table.remove_connection(descriptor) {
|
||||
log_net!(error e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -171,73 +173,17 @@ impl ConnectionManager {
|
||||
}
|
||||
|
||||
// Attempt new connection
|
||||
let conn = NetworkConnection::connect(local_addr, dial_info).await?;
|
||||
let conn = ProtocolNetworkConnection::connect(local_addr, dial_info).await?;
|
||||
|
||||
self.on_new_connection_internal(&mut *inner, conn.clone())?;
|
||||
|
||||
Ok(conn)
|
||||
self.on_new_protocol_network_connection(&mut *inner, conn)
|
||||
}
|
||||
|
||||
// Connection receiver loop
|
||||
fn process_connection(
|
||||
this: ConnectionManager,
|
||||
conn: NetworkConnection,
|
||||
) -> SystemPinBoxFuture<()> {
|
||||
log_net!("Starting process_connection loop for {:?}", conn.green());
|
||||
let network_manager = this.network_manager();
|
||||
Box::pin(async move {
|
||||
//
|
||||
let descriptor = conn.connection_descriptor();
|
||||
let inactivity_timeout = this
|
||||
.network_manager()
|
||||
.config()
|
||||
.get()
|
||||
.network
|
||||
.connection_inactivity_timeout_ms;
|
||||
loop {
|
||||
// process inactivity timeout on receives only
|
||||
// if you want a keepalive, it has to be requested from the other side
|
||||
let message = select! {
|
||||
res = conn.recv().fuse() => {
|
||||
match res {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
log_net!(debug e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = intf::sleep(inactivity_timeout).fuse()=> {
|
||||
// timeout
|
||||
log_net!("connection timeout on {:?}", descriptor.green());
|
||||
break;
|
||||
}
|
||||
};
|
||||
if let Err(e) = network_manager
|
||||
.on_recv_envelope(message.as_slice(), descriptor)
|
||||
.await
|
||||
{
|
||||
log_net!(error e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
log_net!(
|
||||
"== Connection loop finished local_addr={:?} remote={:?}",
|
||||
descriptor.local.green(),
|
||||
descriptor.remote.green()
|
||||
);
|
||||
|
||||
if let Err(e) = this
|
||||
.arc
|
||||
.inner
|
||||
.lock()
|
||||
.await
|
||||
.connection_table
|
||||
.remove_connection(descriptor)
|
||||
{
|
||||
// Callback from network connection receive loop when it exits
|
||||
// cleans up the entry in the connection table
|
||||
pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
if let Err(e) = inner.connection_table.remove_connection(descriptor) {
|
||||
log_net!(error e);
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,4 @@
|
||||
use super::connection_limits::*;
|
||||
use super::network_connection::*;
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
use super::*;
|
||||
use alloc::collections::btree_map::Entry;
|
||||
use hashlink::LruCache;
|
||||
|
||||
@ -9,7 +6,7 @@ use hashlink::LruCache;
|
||||
pub struct ConnectionTable {
|
||||
max_connections: Vec<usize>,
|
||||
conn_by_descriptor: Vec<LruCache<ConnectionDescriptor, NetworkConnection>>,
|
||||
conns_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnection>>,
|
||||
descriptors_by_remote: BTreeMap<PeerAddress, Vec<ConnectionDescriptor>>,
|
||||
address_filter: ConnectionLimits,
|
||||
}
|
||||
|
||||
@ -39,7 +36,7 @@ impl ConnectionTable {
|
||||
LruCache::new_unbounded(),
|
||||
LruCache::new_unbounded(),
|
||||
],
|
||||
conns_by_remote: BTreeMap::new(),
|
||||
descriptors_by_remote: BTreeMap::new(),
|
||||
address_filter: ConnectionLimits::new(config),
|
||||
}
|
||||
}
|
||||
@ -60,7 +57,7 @@ impl ConnectionTable {
|
||||
self.address_filter.add(ip_addr).map_err(map_to_string)?;
|
||||
|
||||
// Add the connection to the table
|
||||
let res = self.conn_by_descriptor[index].insert(descriptor, conn.clone());
|
||||
let res = self.conn_by_descriptor[index].insert(descriptor.clone(), conn);
|
||||
assert!(res.is_none());
|
||||
|
||||
// if we have reached the maximum number of connections per protocol type
|
||||
@ -73,49 +70,54 @@ impl ConnectionTable {
|
||||
}
|
||||
|
||||
// add connection records
|
||||
let conns = self.conns_by_remote.entry(descriptor.remote).or_default();
|
||||
let descriptors = self
|
||||
.descriptors_by_remote
|
||||
.entry(descriptor.remote)
|
||||
.or_default();
|
||||
|
||||
warn!("add_connection: {:?}", conn);
|
||||
conns.push(conn);
|
||||
warn!("add_connection: {:?}", descriptor);
|
||||
descriptors.push(descriptor);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_connection(
|
||||
&mut self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Option<NetworkConnection> {
|
||||
pub fn get_connection(&mut self, descriptor: ConnectionDescriptor) -> Option<ConnectionHandle> {
|
||||
warn!("get_connection: {:?}", descriptor);
|
||||
let index = protocol_to_index(descriptor.protocol_type());
|
||||
let out = self.conn_by_descriptor[index].get(&descriptor).cloned();
|
||||
warn!("get_connection: {:?} -> {:?}", descriptor, out);
|
||||
out
|
||||
let out = self.conn_by_descriptor[index].get(&descriptor);
|
||||
out.map(|c| c.get_handle())
|
||||
}
|
||||
|
||||
pub fn get_last_connection_by_remote(
|
||||
&mut self,
|
||||
remote: PeerAddress,
|
||||
) -> Option<NetworkConnection> {
|
||||
let out = self
|
||||
.conns_by_remote
|
||||
) -> Option<ConnectionHandle> {
|
||||
warn!("get_last_connection_by_remote: {:?}", remote);
|
||||
let descriptor = self
|
||||
.descriptors_by_remote
|
||||
.get(&remote)
|
||||
.map(|v| v[(v.len() - 1)].clone());
|
||||
warn!("get_last_connection_by_remote: {:?} -> {:?}", remote, out);
|
||||
if let Some(connection) = &out {
|
||||
if let Some(descriptor) = descriptor {
|
||||
// lru bump
|
||||
let index = protocol_to_index(connection.connection_descriptor().protocol_type());
|
||||
let _ = self.conn_by_descriptor[index].get(&connection.connection_descriptor());
|
||||
let index = protocol_to_index(descriptor.protocol_type());
|
||||
let handle = self.conn_by_descriptor[index]
|
||||
.get(&descriptor)
|
||||
.map(|c| c.get_handle());
|
||||
handle
|
||||
} else {
|
||||
None
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn get_connections_by_remote(&mut self, remote: PeerAddress) -> Vec<NetworkConnection> {
|
||||
let out = self
|
||||
.conns_by_remote
|
||||
pub fn get_connection_descriptors_by_remote(
|
||||
&mut self,
|
||||
remote: PeerAddress,
|
||||
) -> Vec<ConnectionDescriptor> {
|
||||
warn!("get_connection_descriptors_by_remote: {:?}", remote);
|
||||
self.descriptors_by_remote
|
||||
.get(&remote)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
warn!("get_connections_by_remote: {:?} -> {:?}", remote, out);
|
||||
out
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn connection_count(&self) -> usize {
|
||||
@ -126,7 +128,7 @@ impl ConnectionTable {
|
||||
let ip_addr = descriptor.remote.socket_address.to_ip_addr();
|
||||
|
||||
// conns_by_remote
|
||||
match self.conns_by_remote.entry(descriptor.remote) {
|
||||
match self.descriptors_by_remote.entry(descriptor.remote) {
|
||||
Entry::Vacant(_) => {
|
||||
panic!("inconsistency in connection table")
|
||||
}
|
||||
@ -135,7 +137,7 @@ impl ConnectionTable {
|
||||
|
||||
// Remove one matching connection from the list
|
||||
for (n, elem) in v.iter().enumerate() {
|
||||
if elem.connection_descriptor() == descriptor {
|
||||
if *elem == descriptor {
|
||||
v.remove(n);
|
||||
break;
|
||||
}
|
||||
@ -151,18 +153,14 @@ impl ConnectionTable {
|
||||
.expect("Inconsistency in connection table");
|
||||
}
|
||||
|
||||
pub fn remove_connection(
|
||||
&mut self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
pub fn remove_connection(&mut self, descriptor: ConnectionDescriptor) -> Result<(), String> {
|
||||
warn!("remove_connection: {:?}", descriptor);
|
||||
let index = protocol_to_index(descriptor.protocol_type());
|
||||
let out = self.conn_by_descriptor[index]
|
||||
let _ = self.conn_by_descriptor[index]
|
||||
.remove(&descriptor)
|
||||
.ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;
|
||||
|
||||
self.remove_connection_records(descriptor);
|
||||
|
||||
Ok(out)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ mod native;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
mod wasm;
|
||||
|
||||
mod connection_handle;
|
||||
mod connection_limits;
|
||||
mod connection_manager;
|
||||
mod connection_table;
|
||||
@ -17,8 +18,9 @@ pub mod tests;
|
||||
pub use network_connection::*;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
use connection_limits::*;
|
||||
use connection_manager::*;
|
||||
use connection_handle::*;
|
||||
use dht::*;
|
||||
use hashlink::LruCache;
|
||||
use intf::*;
|
||||
@ -1034,7 +1036,7 @@ impl NetworkManager {
|
||||
// Called when a packet potentially containing an RPC envelope is received by a low-level
|
||||
// network protocol handler. Processes the envelope, authenticates and decrypts the RPC message
|
||||
// and passes it to the RPC handler
|
||||
pub async fn on_recv_envelope(
|
||||
async fn on_recv_envelope(
|
||||
&self,
|
||||
data: &[u8],
|
||||
descriptor: ConnectionDescriptor,
|
||||
|
@ -341,7 +341,7 @@ impl Network {
|
||||
log_net!("send_data_to_existing_connection to {:?}", descriptor);
|
||||
|
||||
// connection exists, send over it
|
||||
conn.send(data).await.map_err(logthru_net!())?;
|
||||
conn.send_async(data).await.map_err(logthru_net!())?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@ -389,7 +389,7 @@ impl Network {
|
||||
.get_or_create_connection(Some(local_addr), dial_info.clone())
|
||||
.await?;
|
||||
|
||||
let res = conn.send(data).await.map_err(logthru_net!(error));
|
||||
let res = conn.send_async(data).await.map_err(logthru_net!(error));
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
|
@ -7,7 +7,7 @@ use sockets::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ListenerState {
|
||||
pub protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub protocol_accept_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub tls_acceptor: Option<TlsAcceptor>,
|
||||
}
|
||||
@ -15,7 +15,7 @@ pub struct ListenerState {
|
||||
impl ListenerState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
protocol_handlers: Vec::new(),
|
||||
protocol_accept_handlers: Vec::new(),
|
||||
tls_protocol_handlers: Vec::new(),
|
||||
tls_acceptor: None,
|
||||
}
|
||||
@ -46,7 +46,7 @@ impl Network {
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
tls_connection_initial_timeout: u64,
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
let ts = tls_acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
@ -76,9 +76,9 @@ impl Network {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
for ah in protocol_handlers.iter() {
|
||||
protocol_accept_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
for ah in protocol_accept_handlers.iter() {
|
||||
if let Some(nc) = ah
|
||||
.on_accept(stream.clone(), tcp_stream.clone(), addr)
|
||||
.await
|
||||
@ -185,7 +185,7 @@ impl Network {
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_handlers)
|
||||
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
|
||||
.await
|
||||
};
|
||||
|
||||
@ -207,7 +207,10 @@ impl Network {
|
||||
};
|
||||
|
||||
// Register the new connection in the connection manager
|
||||
if let Err(e) = connection_manager.on_new_connection(conn).await {
|
||||
if let Err(e) = connection_manager
|
||||
.on_accepted_protocol_network_connection(conn)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to register new connection: {}", e);
|
||||
}
|
||||
})
|
||||
@ -270,7 +273,7 @@ impl Network {
|
||||
));
|
||||
} else {
|
||||
ls.write()
|
||||
.protocol_handlers
|
||||
.protocol_accept_handlers
|
||||
.push(new_protocol_accept_handler(
|
||||
self.network_manager().config(),
|
||||
false,
|
||||
|
@ -21,7 +21,7 @@ impl ProtocolNetworkConnection {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("Should not connect to UDP dialinfo");
|
||||
@ -55,6 +55,16 @@ impl ProtocolNetworkConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
match self {
|
||||
Self::Dummy(d) => d.descriptor(),
|
||||
Self::RawTcp(t) => t.descriptor(),
|
||||
Self::WsAccepted(w) => w.descriptor(),
|
||||
Self::Ws(w) => w.descriptor(),
|
||||
Self::Wss(w) => w.descriptor(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.close(),
|
||||
|
@ -3,6 +3,7 @@ use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||
use sockets::*;
|
||||
|
||||
pub struct RawTcpNetworkConnection {
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
}
|
||||
@ -14,8 +15,20 @@ impl fmt::Debug for RawTcpNetworkConnection {
|
||||
}
|
||||
|
||||
impl RawTcpNetworkConnection {
|
||||
pub fn new(stream: AsyncPeekStream, tcp_stream: TcpStream) -> Self {
|
||||
Self { stream, tcp_stream }
|
||||
pub fn new(
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
) -> Self {
|
||||
Self {
|
||||
descriptor,
|
||||
stream,
|
||||
tcp_stream,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
@ -33,7 +46,7 @@ impl RawTcpNetworkConnection {
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
|
||||
log_net!("sending TCP message of size {}", message.len());
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large TCP message".to_owned());
|
||||
@ -41,7 +54,6 @@ impl RawTcpNetworkConnection {
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
let mut stream = self.stream.clone();
|
||||
stream
|
||||
.write_all(&header)
|
||||
.await
|
||||
@ -54,6 +66,11 @@ impl RawTcpNetworkConnection {
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
let stream = self.stream.clone();
|
||||
Self::send_internal(stream, message).await
|
||||
}
|
||||
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let mut header = [0u8; 4];
|
||||
|
||||
@ -108,7 +125,7 @@ impl RawTcpProtocolHandler {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
log_net!("TCP: on_accept_async: enter");
|
||||
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
|
||||
let peeklen = stream
|
||||
@ -123,10 +140,11 @@ impl RawTcpProtocolHandler {
|
||||
ProtocolType::TCP,
|
||||
);
|
||||
let local_address = self.inner.lock().local_address;
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
|
||||
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
|
||||
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream, tcp_stream)),
|
||||
);
|
||||
stream,
|
||||
tcp_stream,
|
||||
));
|
||||
|
||||
log_net!(debug "TCP: on_accept_async from: {}", socket_addr);
|
||||
|
||||
@ -136,7 +154,7 @@ impl RawTcpProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
// Get remote socket address to connect to
|
||||
let remote_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
@ -161,13 +179,15 @@ impl RawTcpProtocolHandler {
|
||||
let ps = AsyncPeekStream::new(ts.clone());
|
||||
|
||||
// Wrap the stream in a network connection and return it
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
|
||||
ConnectionDescriptor {
|
||||
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
|
||||
remote: dial_info.to_peer_address(),
|
||||
},
|
||||
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
|
||||
);
|
||||
ps,
|
||||
ts,
|
||||
));
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
@ -194,24 +214,15 @@ impl RawTcpProtocolHandler {
|
||||
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
let actual_local_address = ts
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
// let actual_local_address = ts
|
||||
// .local_addr()
|
||||
// .map_err(map_to_string)
|
||||
// .map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
let ps = AsyncPeekStream::new(ts.clone());
|
||||
|
||||
// Wrap the stream in a network connection and return it
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
ConnectionDescriptor {
|
||||
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
|
||||
remote: PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(socket_addr),
|
||||
ProtocolType::TCP,
|
||||
),
|
||||
},
|
||||
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
|
||||
);
|
||||
conn.send(data).await
|
||||
// Send directly from the raw network connection
|
||||
// this builds the connection and tears it down immediately after the send
|
||||
RawTcpNetworkConnection::send_internal(ps, data).await
|
||||
}
|
||||
}
|
||||
|
||||
@ -221,7 +232,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<core::result::Result<Option<NetworkConnection>, String>> {
|
||||
) -> SystemPinBoxFuture<core::result::Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ pub struct WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: CloneStream<WebSocketStream<T>>,
|
||||
tcp_stream: TcpStream,
|
||||
}
|
||||
@ -32,13 +33,22 @@ impl<T> WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
pub fn new(stream: WebSocketStream<T>, tcp_stream: TcpStream) -> Self {
|
||||
pub fn new(
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: WebSocketStream<T>,
|
||||
tcp_stream: TcpStream,
|
||||
) -> Self {
|
||||
Self {
|
||||
descriptor,
|
||||
stream: CloneStream::new(stream),
|
||||
tcp_stream,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream
|
||||
@ -132,7 +142,7 @@ impl WebsocketProtocolHandler {
|
||||
ps: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
log_net!("WS: on_accept_async: enter");
|
||||
let request_path_len = self.arc.request_path.len() + 2;
|
||||
|
||||
@ -179,25 +189,24 @@ impl WebsocketProtocolHandler {
|
||||
let peer_addr =
|
||||
PeerAddress::new(SocketAddress::from_socket_addr(socket_addr), protocol_type);
|
||||
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
let conn = ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
|
||||
ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(self.arc.local_address),
|
||||
),
|
||||
ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
|
||||
ws_stream, tcp_stream,
|
||||
)),
|
||||
);
|
||||
ws_stream,
|
||||
tcp_stream,
|
||||
));
|
||||
|
||||
log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr);
|
||||
|
||||
Ok(Some(conn))
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
async fn connect_internal(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
// Split dial info up
|
||||
let (tls, scheme) = match &dial_info {
|
||||
DialInfo::WS(_) => (false, "ws"),
|
||||
@ -251,26 +260,27 @@ impl WebsocketProtocolHandler {
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
Ok(NetworkConnection::from_protocol(
|
||||
descriptor,
|
||||
ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new(
|
||||
ws_stream, tcp_stream,
|
||||
)),
|
||||
Ok(ProtocolNetworkConnection::Wss(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
|
||||
))
|
||||
} else {
|
||||
let (ws_stream, _response) = client_async(request, tcp_stream.clone())
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
Ok(NetworkConnection::from_protocol(
|
||||
descriptor,
|
||||
ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(
|
||||
ws_stream, tcp_stream,
|
||||
)),
|
||||
Ok(ProtocolNetworkConnection::Ws(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
Self::connect_internal(local_address, dial_info).await
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
@ -281,11 +291,11 @@ impl WebsocketProtocolHandler {
|
||||
dial_info,
|
||||
);
|
||||
|
||||
let conn = Self::connect(None, dial_info.clone())
|
||||
let protconn = Self::connect_internal(None, dial_info.clone())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
|
||||
|
||||
conn.send(data).await
|
||||
protconn.send(data).await
|
||||
}
|
||||
}
|
||||
|
||||
@ -295,7 +305,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
|
||||
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use super::*;
|
||||
use crate::xx::*;
|
||||
use futures_util::{FutureExt, StreamExt};
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
@ -16,7 +16,7 @@ cfg_if::cfg_if! {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>>;
|
||||
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>>;
|
||||
}
|
||||
|
||||
pub trait ProtocolAcceptHandlerClone {
|
||||
@ -45,9 +45,14 @@ cfg_if::cfg_if! {
|
||||
// Dummy protocol network connection for testing
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DummyNetworkConnection {}
|
||||
pub struct DummyNetworkConnection {
|
||||
descriptor: ConnectionDescriptor,
|
||||
}
|
||||
|
||||
impl DummyNetworkConnection {
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
pub fn close(&self) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
@ -62,6 +67,14 @@ impl DummyNetworkConnection {
|
||||
///////////////////////////////////////////////////////////
|
||||
// Top-level protocol independent network connection object
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
enum RecvLoopAction {
|
||||
Send,
|
||||
Recv,
|
||||
Finish,
|
||||
Timeout,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NetworkConnectionStats {
|
||||
last_message_sent_time: Option<u64>,
|
||||
@ -69,107 +82,249 @@ pub struct NetworkConnectionStats {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NetworkConnectionInner {
|
||||
stats: NetworkConnectionStats,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NetworkConnectionArc {
|
||||
descriptor: ConnectionDescriptor,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
established_time: u64,
|
||||
inner: Mutex<NetworkConnectionInner>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NetworkConnection {
|
||||
arc: Arc<NetworkConnectionArc>,
|
||||
descriptor: ConnectionDescriptor,
|
||||
_processor: Option<JoinHandle<()>>,
|
||||
established_time: u64,
|
||||
stats: Arc<Mutex<NetworkConnectionStats>>,
|
||||
sender: flume::Sender<Vec<u8>>,
|
||||
}
|
||||
impl PartialEq for NetworkConnection {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
Arc::as_ptr(&self.arc) == Arc::as_ptr(&other.arc)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for NetworkConnection {}
|
||||
|
||||
impl NetworkConnection {
|
||||
fn new_inner() -> NetworkConnectionInner {
|
||||
NetworkConnectionInner {
|
||||
stats: NetworkConnectionStats {
|
||||
pub(super) fn dummy(descriptor: ConnectionDescriptor) -> Self {
|
||||
// Create handle for sending (dummy is immediately disconnected)
|
||||
let (sender, _receiver) = flume::bounded(intf::get_concurrency() as usize);
|
||||
|
||||
Self {
|
||||
descriptor,
|
||||
_processor: None,
|
||||
established_time: intf::get_timestamp(),
|
||||
stats: Arc::new(Mutex::new(NetworkConnectionStats {
|
||||
last_message_sent_time: None,
|
||||
last_message_recv_time: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
fn new_arc(
|
||||
descriptor: ConnectionDescriptor,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
) -> NetworkConnectionArc {
|
||||
NetworkConnectionArc {
|
||||
descriptor,
|
||||
protocol_connection,
|
||||
established_time: intf::get_timestamp(),
|
||||
inner: Mutex::new(Self::new_inner()),
|
||||
})),
|
||||
sender,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dummy(descriptor: ConnectionDescriptor) -> Self {
|
||||
NetworkConnection::from_protocol(
|
||||
descriptor,
|
||||
ProtocolNetworkConnection::Dummy(DummyNetworkConnection {}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn from_protocol(
|
||||
descriptor: ConnectionDescriptor,
|
||||
pub(super) fn from_protocol(
|
||||
connection_manager: ConnectionManager,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
) -> Self {
|
||||
Self {
|
||||
arc: Arc::new(Self::new_arc(descriptor, protocol_connection)),
|
||||
}
|
||||
}
|
||||
// Get timeout
|
||||
let network_manager = connection_manager.network_manager();
|
||||
let inactivity_timeout = network_manager
|
||||
.config()
|
||||
.get()
|
||||
.network
|
||||
.connection_inactivity_timeout_ms;
|
||||
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
ProtocolNetworkConnection::connect(local_address, dial_info).await
|
||||
// Get descriptor
|
||||
let descriptor = protocol_connection.descriptor();
|
||||
|
||||
// Create handle for sending
|
||||
let (sender, receiver) = flume::bounded(intf::get_concurrency() as usize);
|
||||
|
||||
// Create stats
|
||||
let stats = Arc::new(Mutex::new(NetworkConnectionStats {
|
||||
last_message_sent_time: None,
|
||||
last_message_recv_time: None,
|
||||
}));
|
||||
|
||||
// Spawn connection processor and pass in protocol connection
|
||||
let processor = intf::spawn_local(Self::process_connection(
|
||||
connection_manager,
|
||||
descriptor.clone(),
|
||||
receiver,
|
||||
protocol_connection,
|
||||
inactivity_timeout,
|
||||
stats.clone(),
|
||||
));
|
||||
|
||||
// Return the connection
|
||||
Self {
|
||||
descriptor,
|
||||
_processor: Some(processor),
|
||||
established_time: intf::get_timestamp(),
|
||||
stats,
|
||||
sender,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
|
||||
self.arc.descriptor
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
self.arc.protocol_connection.close().await
|
||||
pub fn get_handle(&self) -> ConnectionHandle {
|
||||
ConnectionHandle::new(self.descriptor.clone(), self.sender.clone())
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
async fn send_internal(
|
||||
protocol_connection: &ProtocolNetworkConnection,
|
||||
stats: Arc<Mutex<NetworkConnectionStats>>,
|
||||
message: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
let ts = intf::get_timestamp();
|
||||
let out = self.arc.protocol_connection.send(message).await;
|
||||
let out = protocol_connection.send(message).await;
|
||||
if out.is_ok() {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
inner.stats.last_message_sent_time.max_assign(Some(ts));
|
||||
let mut stats = stats.lock();
|
||||
stats.last_message_sent_time.max_assign(Some(ts));
|
||||
}
|
||||
out
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
async fn recv_internal(
|
||||
protocol_connection: &ProtocolNetworkConnection,
|
||||
stats: Arc<Mutex<NetworkConnectionStats>>,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
let ts = intf::get_timestamp();
|
||||
let out = self.arc.protocol_connection.recv().await;
|
||||
let out = protocol_connection.recv().await;
|
||||
if out.is_ok() {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
inner.stats.last_message_recv_time.max_assign(Some(ts));
|
||||
let mut stats = stats.lock();
|
||||
stats.last_message_recv_time.max_assign(Some(ts));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn stats(&self) -> NetworkConnectionStats {
|
||||
let inner = self.arc.inner.lock();
|
||||
inner.stats.clone()
|
||||
let stats = self.stats.lock();
|
||||
stats.clone()
|
||||
}
|
||||
|
||||
pub fn established_time(&self) -> u64 {
|
||||
self.arc.established_time
|
||||
self.established_time
|
||||
}
|
||||
|
||||
// Connection receiver loop
|
||||
fn process_connection(
|
||||
connection_manager: ConnectionManager,
|
||||
descriptor: ConnectionDescriptor,
|
||||
receiver: flume::Receiver<Vec<u8>>,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
connection_inactivity_timeout_ms: u32,
|
||||
stats: Arc<Mutex<NetworkConnectionStats>>,
|
||||
) -> SystemPinBoxFuture<()> {
|
||||
Box::pin(async move {
|
||||
log_net!(
|
||||
"Starting process_connection loop for {:?}",
|
||||
descriptor.green()
|
||||
);
|
||||
|
||||
let network_manager = connection_manager.network_manager();
|
||||
let mut unord = FuturesUnordered::new();
|
||||
let mut need_receiver = true;
|
||||
let mut need_sender = true;
|
||||
|
||||
// Push mutable timer so we can reset it
|
||||
// Normally we would use an io::timeout here, but WASM won't support that, so we use a mutable sleep future
|
||||
let new_timer = || {
|
||||
intf::sleep(connection_inactivity_timeout_ms).then(|_| async {
|
||||
// timeout
|
||||
log_net!("connection timeout on {:?}", descriptor.green());
|
||||
RecvLoopAction::Timeout
|
||||
})
|
||||
};
|
||||
let timer = MutableFuture::new(new_timer());
|
||||
unord.push(timer.clone().boxed());
|
||||
|
||||
loop {
|
||||
// Add another message sender future if necessary
|
||||
if need_sender {
|
||||
need_sender = false;
|
||||
unord.push(
|
||||
receiver
|
||||
.recv_async()
|
||||
.then(|res| async {
|
||||
match res {
|
||||
Ok(message) => {
|
||||
// send the packet
|
||||
if let Err(e) = Self::send_internal(
|
||||
&protocol_connection,
|
||||
stats.clone(),
|
||||
message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
// Sending the packet along can fail, if so, this connection is dead
|
||||
log_net!(debug e);
|
||||
RecvLoopAction::Finish
|
||||
} else {
|
||||
RecvLoopAction::Send
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// All senders gone, shouldn't happen since we store one alongside the join handle
|
||||
log_net!(warn e);
|
||||
RecvLoopAction::Finish
|
||||
}
|
||||
}
|
||||
})
|
||||
.boxed(),
|
||||
);
|
||||
}
|
||||
|
||||
// Add another message receiver future if necessary
|
||||
if need_receiver {
|
||||
need_sender = false;
|
||||
unord.push(
|
||||
Self::recv_internal(&protocol_connection, stats.clone())
|
||||
.then(|res| async {
|
||||
match res {
|
||||
Ok(message) => {
|
||||
// Pass received messages up to the network manager for processing
|
||||
if let Err(e) = network_manager
|
||||
.on_recv_envelope(message.as_slice(), descriptor)
|
||||
.await
|
||||
{
|
||||
log_net!(error e);
|
||||
RecvLoopAction::Finish
|
||||
} else {
|
||||
RecvLoopAction::Recv
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Connection unable to receive, closed
|
||||
log_net!(warn e);
|
||||
RecvLoopAction::Finish
|
||||
}
|
||||
}
|
||||
})
|
||||
.boxed(),
|
||||
);
|
||||
}
|
||||
|
||||
// Process futures
|
||||
match unord.next().await {
|
||||
Some(RecvLoopAction::Send) => {
|
||||
// Don't reset inactivity timer if we're only sending
|
||||
|
||||
need_sender = true;
|
||||
}
|
||||
Some(RecvLoopAction::Recv) => {
|
||||
// Reset inactivity timer since we got something from this connection
|
||||
timer.set(new_timer());
|
||||
|
||||
need_receiver = true;
|
||||
}
|
||||
Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout) => {
|
||||
break;
|
||||
}
|
||||
|
||||
None => {
|
||||
// Should not happen
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log_net!(
|
||||
"== Connection loop finished local_addr={:?} remote={:?}",
|
||||
descriptor.local.green(),
|
||||
descriptor.remote.green()
|
||||
);
|
||||
|
||||
connection_manager
|
||||
.report_connection_finished(descriptor)
|
||||
.await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -52,10 +52,15 @@ pub async fn test_add_get_remove() {
|
||||
);
|
||||
|
||||
let c1 = NetworkConnection::dummy(a1);
|
||||
let c1h = c1.get_handle();
|
||||
let c2 = NetworkConnection::dummy(a2);
|
||||
//let c2h = c2.get_handle();
|
||||
let c3 = NetworkConnection::dummy(a3);
|
||||
//let c3h = c3.get_handle();
|
||||
let c4 = NetworkConnection::dummy(a4);
|
||||
//let c4h = c4.get_handle();
|
||||
let c5 = NetworkConnection::dummy(a5);
|
||||
//let c5h = c5.get_handle();
|
||||
|
||||
assert_eq!(a1, c2.connection_descriptor());
|
||||
assert_ne!(a3, c4.connection_descriptor());
|
||||
@ -63,36 +68,39 @@ pub async fn test_add_get_remove() {
|
||||
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
assert_eq!(table.get_connection(a1), None);
|
||||
table.add_connection(c1.clone()).unwrap();
|
||||
table.add_connection(c1).unwrap();
|
||||
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_err!(table.remove_connection(a3));
|
||||
assert_err!(table.remove_connection(a4));
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_eq!(table.get_connection(a1), Some(c1.clone()));
|
||||
assert_eq!(table.get_connection(a1), Some(c1.clone()));
|
||||
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
|
||||
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_err!(table.add_connection(c1.clone()));
|
||||
assert_err!(table.add_connection(c2.clone()));
|
||||
assert_err!(table.add_connection(c2));
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_eq!(table.get_connection(a1), Some(c1.clone()));
|
||||
assert_eq!(table.get_connection(a1), Some(c1.clone()));
|
||||
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
|
||||
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_eq!(table.remove_connection(a2), Ok(c1.clone()));
|
||||
assert_eq!(table.remove_connection(a2), Ok(()));
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
assert_err!(table.remove_connection(a2));
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
assert_eq!(table.get_connection(a2), None);
|
||||
assert_eq!(table.get_connection(a1), None);
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
table.add_connection(c1.clone()).unwrap();
|
||||
let c1 = NetworkConnection::dummy(a1);
|
||||
//let c1h = c1.get_handle();
|
||||
table.add_connection(c1).unwrap();
|
||||
let c2 = NetworkConnection::dummy(a2);
|
||||
//let c2h = c2.get_handle();
|
||||
assert_err!(table.add_connection(c2));
|
||||
table.add_connection(c3.clone()).unwrap();
|
||||
table.add_connection(c4.clone()).unwrap();
|
||||
table.add_connection(c3).unwrap();
|
||||
table.add_connection(c4).unwrap();
|
||||
assert_eq!(table.connection_count(), 3);
|
||||
assert_eq!(table.remove_connection(a2), Ok(c1));
|
||||
assert_eq!(table.remove_connection(a3), Ok(c3));
|
||||
assert_eq!(table.remove_connection(a4), Ok(c4));
|
||||
assert_eq!(table.remove_connection(a2), Ok(()));
|
||||
assert_eq!(table.remove_connection(a3), Ok(()));
|
||||
assert_eq!(table.remove_connection(a4), Ok(()));
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,7 @@ struct WebsocketNetworkConnectionInner {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WebsocketNetworkConnection {
|
||||
descriptor: ConnectionDescriptor,
|
||||
inner: Arc<WebsocketNetworkConnectionInner>,
|
||||
}
|
||||
|
||||
@ -23,8 +24,11 @@ impl fmt::Debug for WebsocketNetworkConnection {
|
||||
}
|
||||
|
||||
impl WebsocketNetworkConnection {
|
||||
pub fn new(ws_meta: WsMeta, ws_stream: WsStream) -> Self {
|
||||
pub fn new(
|
||||
descriptor: ConnectionDescriptor,
|
||||
ws_meta: WsMeta, ws_stream: WsStream) -> Self {
|
||||
Self {
|
||||
descriptor,
|
||||
inner: Arc::new(WebsocketNetworkConnectionInner {
|
||||
ws_meta,
|
||||
ws_stream: CloneStream::new(ws_stream),
|
||||
@ -32,6 +36,10 @@ impl WebsocketNetworkConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
self.inner.ws_meta.close().await.map_err(map_to_string).map(drop)
|
||||
}
|
||||
@ -73,7 +81,7 @@ impl WebsocketProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
|
||||
assert!(local_address.is_none());
|
||||
|
||||
@ -96,10 +104,10 @@ impl WebsocketProtocolHandler {
|
||||
|
||||
// Make our connection descriptor
|
||||
|
||||
Ok(NetworkConnection::from_protocol(ConnectionDescriptor {
|
||||
Ok(ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(ConnectionDescriptor {
|
||||
local: None,
|
||||
remote: dial_info.to_peer_address(),
|
||||
},ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(wsmeta, wsio))))
|
||||
}, wsmeta, wsio)))
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
|
@ -8,6 +8,7 @@ mod eventual_value_clone;
|
||||
mod ip_addr_port;
|
||||
mod ip_extra;
|
||||
mod log_thru;
|
||||
mod mutable_future;
|
||||
mod single_future;
|
||||
mod single_shot_eventual;
|
||||
mod split_url;
|
||||
@ -104,6 +105,7 @@ pub use eventual_value::*;
|
||||
pub use eventual_value_clone::*;
|
||||
pub use ip_addr_port::*;
|
||||
pub use ip_extra::*;
|
||||
pub use mutable_future::*;
|
||||
pub use single_future::*;
|
||||
pub use single_shot_eventual::*;
|
||||
pub use tick_task::*;
|
||||
|
33
veilid-core/src/xx/mutable_future.rs
Normal file
33
veilid-core/src/xx/mutable_future.rs
Normal file
@ -0,0 +1,33 @@
|
||||
use super::*;
|
||||
|
||||
pub struct MutableFuture<O, T: Future<Output = O>> {
|
||||
inner: Arc<Mutex<Pin<Box<T>>>>,
|
||||
}
|
||||
|
||||
impl<O, T: Future<Output = O>> MutableFuture<O, T> {
|
||||
pub fn new(inner: T) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(Box::pin(inner))),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set(&self, inner: T) {
|
||||
*self.inner.lock() = Box::pin(inner);
|
||||
}
|
||||
}
|
||||
|
||||
impl<O, T: Future<Output = O>> Clone for MutableFuture<O, T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<O, T: Future<Output = O>> Future for MutableFuture<O, T> {
|
||||
type Output = O;
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
|
||||
let mut inner = self.inner.lock();
|
||||
T::poll(inner.as_mut(), cx)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user