network refactor for connection manager

This commit is contained in:
John Smith 2022-01-02 23:49:01 -05:00
parent c2c5e3c299
commit 94772094c5
19 changed files with 900 additions and 735 deletions

View File

@ -0,0 +1,246 @@
use crate::connection_table::*;
use crate::intf::*;
use crate::network_manager::*;
use crate::xx::*;
use crate::*;
use futures_util::future::{select, Either};
use futures_util::stream::{FuturesUnordered, StreamExt};
const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize;
type ProtocolConnectHandler = fn(Option<SocketAddr>, DialInfo) -> Result<NetworkConnection, String>;
type ProtocolConnectorMap = BTreeMap<ProtocolType, ProtocolConnectHandler>;
cfg_if! {
if #[cfg(not(target_arch = "wasm32"))] {
use async_std::net::*;
use utils::async_peek_stream::*;
pub trait ProtocolAcceptHandler: ProtocolAcceptHandlerClone + Send + Sync {
fn on_accept(
&self,
stream: AsyncPeekStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>>;
}
pub trait ProtocolAcceptHandlerClone {
fn clone_box(&self) -> Box<dyn ProtocolAcceptHandler>;
}
impl<T> ProtocolAcceptHandlerClone for T
where
T: 'static + ProtocolAcceptHandler + Clone,
{
fn clone_box(&self) -> Box<dyn ProtocolAcceptHandler> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn ProtocolAcceptHandler> {
fn clone(&self) -> Box<dyn ProtocolAcceptHandler> {
self.clone_box()
}
}
pub type NewProtocolAcceptHandler =
dyn Fn(ConnectionManager, bool, SocketAddr) -> Box<dyn ProtocolAcceptHandler> + Send;
}
}
pub struct ConnectionManagerInner {
network_manager: NetworkManager,
connection_table: ConnectionTable,
connection_processor_jh: Option<JoinHandle<()>>,
connection_add_channel_tx: Option<utils::channel::Sender<SystemPinBoxFuture<()>>>,
}
impl core::fmt::Debug for ConnectionManagerInner {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("ConnectionManagerInner")
.field("connection_table", &self.connection_table)
.finish()
}
}
#[derive(Clone)]
pub struct ConnectionManager {
inner: Arc<Mutex<ConnectionManagerInner>>,
}
impl core::fmt::Debug for ConnectionManager {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("ConnectionManager")
.field("inner", &*self.inner.lock())
.finish()
}
}
impl ConnectionManager {
fn new_inner(network_manager: NetworkManager) -> ConnectionManagerInner {
ConnectionManagerInner {
network_manager,
connection_table: ConnectionTable::new(),
connection_processor_jh: None,
connection_add_channel_tx: None,
}
}
pub fn new(network_manager: NetworkManager) -> Self {
Self {
inner: Arc::new(Mutex::new(Self::new_inner(network_manager))),
}
}
pub fn network_manager(&self) -> NetworkManager {
self.inner.lock().network_manager.clone()
}
pub fn config(&self) -> VeilidConfig {
self.network_manager().config()
}
pub async fn startup(&self) {
let cac = utils::channel::channel(CONNECTION_PROCESSOR_CHANNEL_SIZE); // xxx move to config
self.inner.lock().connection_add_channel_tx = Some(cac.0);
let rx = cac.1.clone();
let this = self.clone();
self.inner.lock().connection_processor_jh = Some(spawn(this.connection_processor(rx)));
}
pub async fn shutdown(&self) {
*self.inner.lock() = Self::new_inner(self.network_manager());
}
// Returns a network connection if one already is established
pub fn get_connection(&self, descriptor: &ConnectionDescriptor) -> Option<NetworkConnection> {
self.inner
.lock()
.connection_table
.get_connection(descriptor)
.map(|e| e.conn)
}
// Called by low-level network when any connection-oriented protocol connection appears
// either from incoming or outgoing connections. Registers connection in the connection table for later access
// and spawns a message processing loop for the connection
pub async fn on_new_connection(&self, conn: NetworkConnection) -> Result<(), String> {
let tx = self
.inner
.lock()
.connection_add_channel_tx
.as_ref()
.ok_or_else(fn_string!("connection channel isn't open yet"))?
.clone();
let receiver_loop_future = Self::process_connection(self.clone(), conn);
tx.try_send(receiver_loop_future)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to start receiver loop"))
}
// Connection receiver loop
fn process_connection(
this: ConnectionManager,
conn: NetworkConnection,
) -> SystemPinBoxFuture<()> {
let network_manager = this.network_manager();
Box::pin(async move {
// Add new connections to the table
let entry = match this
.inner
.lock()
.connection_table
.add_connection(conn.clone())
{
Ok(e) => e,
Err(err) => {
error!(target: "net", "{}", err);
return;
}
};
//
let exit_value: Result<Vec<u8>, ()> = Err(());
let descriptor = conn.connection_descriptor();
loop {
let res = match select(
entry.stopper.clone().instance_clone(exit_value.clone()),
Box::pin(conn.clone().recv()),
)
.await
{
Either::Left((_x, _b)) => break,
Either::Right((y, _a)) => y,
};
let message = match res {
Ok(v) => v,
Err(_) => break,
};
match network_manager
.on_recv_envelope(message.as_slice(), &descriptor)
.await
{
Ok(_) => (),
Err(e) => {
error!("{}", e);
break;
}
};
}
if let Err(err) = this
.inner
.lock()
.connection_table
.remove_connection(&descriptor)
{
error!("{}", err);
}
})
}
// Process connection oriented sockets in the background
// This never terminates and must have its task cancelled once started
// Task cancellation is performed by shutdown() by dropping the join handle
async fn connection_processor(self, rx: utils::channel::Receiver<SystemPinBoxFuture<()>>) {
let mut connection_futures: FuturesUnordered<SystemPinBoxFuture<()>> =
FuturesUnordered::new();
loop {
// Either process an existing connection, or receive a new one to add to our list
match select(connection_futures.next(), Box::pin(rx.recv())).await {
Either::Left((x, _)) => {
// Processed some connection to completion, or there are none left
match x {
Some(()) => {
// Processed some connection to completion
}
None => {
// No connections to process, wait for one
match rx.recv().await {
Ok(v) => {
connection_futures.push(v);
}
Err(e) => {
log_net!(error "connection processor error: {:?}", e);
// xxx: do something here?? should the network be restarted if this happens?
}
};
}
}
}
Either::Right((x, _)) => {
// Got a new connection future
match x {
Ok(v) => {
connection_futures.push(v);
}
Err(e) => {
log_net!(error "connection processor error: {:?}", e);
// xxx: do something here?? should the network be restarted if this happens?
}
};
}
}
}
}
}

View File

@ -30,42 +30,22 @@ impl PartialEq for ConnectionTableEntry {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectionTableInner { pub struct ConnectionTable {
conn_by_addr: BTreeMap<ConnectionDescriptor, ConnectionTableEntry>, conn_by_addr: BTreeMap<ConnectionDescriptor, ConnectionTableEntry>,
} }
#[derive(Clone)]
pub struct ConnectionTable {
inner: Arc<Mutex<ConnectionTableInner>>,
}
impl core::fmt::Debug for ConnectionTable {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("ConnectionTable")
.field("inner", &*self.inner.lock())
.finish()
}
}
impl Default for ConnectionTable {
fn default() -> Self {
Self::new()
}
}
impl ConnectionTable { impl ConnectionTable {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
inner: Arc::new(Mutex::new(ConnectionTableInner { conn_by_addr: BTreeMap::new(),
conn_by_addr: BTreeMap::new(),
})),
} }
} }
pub fn add_connection( pub fn add_connection(
&self, &mut self,
descriptor: ConnectionDescriptor,
conn: NetworkConnection, conn: NetworkConnection,
) -> Result<ConnectionTableEntry, String> { ) -> Result<ConnectionTableEntry, String> {
trace!("descriptor: {:?}", descriptor); let descriptor = conn.connection_descriptor();
assert_ne!( assert_ne!(
descriptor.protocol_type(), descriptor.protocol_type(),
@ -73,8 +53,7 @@ impl ConnectionTable {
"Only connection oriented protocols go in the table!" "Only connection oriented protocols go in the table!"
); );
let mut inner = self.inner.lock(); if self.conn_by_addr.contains_key(&descriptor) {
if inner.conn_by_addr.contains_key(&descriptor) {
return Err(format!( return Err(format!(
"Connection already added to table: {:?}", "Connection already added to table: {:?}",
descriptor descriptor
@ -90,7 +69,7 @@ impl ConnectionTable {
last_message_recv_time: None, last_message_recv_time: None,
stopper: Eventual::new(), stopper: Eventual::new(),
}; };
let res = inner.conn_by_addr.insert(descriptor, entry.clone()); let res = self.conn_by_addr.insert(descriptor, entry.clone());
assert!(res.is_none()); assert!(res.is_none());
Ok(entry) Ok(entry)
} }
@ -99,27 +78,19 @@ impl ConnectionTable {
&self, &self,
descriptor: &ConnectionDescriptor, descriptor: &ConnectionDescriptor,
) -> Option<ConnectionTableEntry> { ) -> Option<ConnectionTableEntry> {
let inner = self.inner.lock(); self.conn_by_addr.get(descriptor).cloned()
inner.conn_by_addr.get(descriptor).cloned()
} }
pub fn connection_count(&self) -> usize { pub fn connection_count(&self) -> usize {
let inner = self.inner.lock(); self.conn_by_addr.len()
inner.conn_by_addr.len()
} }
pub fn remove_connection( pub fn remove_connection(
&self, &mut self,
descriptor: &ConnectionDescriptor, descriptor: &ConnectionDescriptor,
) -> Result<ConnectionTableEntry, String> { ) -> Result<ConnectionTableEntry, String> {
trace!("descriptor: {:?}", descriptor); self.conn_by_addr
.remove(descriptor)
let mut inner = self.inner.lock(); .ok_or_else(|| format!("Connection not in table: {:?}", descriptor))
let res = inner.conn_by_addr.remove(descriptor);
match res {
Some(v) => Ok(v),
None => Err(format!("Connection not in table: {:?}", descriptor)),
}
} }
} }

View File

@ -1,7 +1,7 @@
mod table_db; mod table_db;
mod user_secret;
use crate::xx::*; use crate::xx::*;
use data_encoding::BASE64URL_NOPAD; pub use user_secret::*;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
mod wasm; mod wasm;
@ -11,44 +11,3 @@ pub use wasm::*;
mod native; mod native;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
pub use native::*; pub use native::*;
pub async fn save_user_secret(namespace: &str, key: &str, value: &[u8]) -> Result<bool, String> {
let mut s = BASE64URL_NOPAD.encode(value);
s.push('!');
save_user_secret_string(namespace, key, s.as_str()).await
}
pub async fn load_user_secret(namespace: &str, key: &str) -> Result<Option<Vec<u8>>, String> {
let mut s = match load_user_secret_string(namespace, key).await? {
Some(s) => s,
None => {
return Ok(None);
}
};
if s.pop() != Some('!') {
return Err("User secret is not a buffer".to_owned());
}
let mut bytes = Vec::<u8>::new();
let res = BASE64URL_NOPAD.decode_len(s.len());
match res {
Ok(l) => {
bytes.resize(l, 0u8);
}
Err(_) => {
return Err("Failed to decode".to_owned());
}
}
let res = BASE64URL_NOPAD.decode_mut(s.as_bytes(), &mut bytes);
match res {
Ok(_) => Ok(Some(bytes)),
Err(_) => Err("Failed to decode".to_owned()),
}
}
pub async fn remove_user_secret(namespace: &str, key: &str) -> Result<bool, String> {
remove_user_secret_string(namespace, key).await
}

View File

@ -1,54 +0,0 @@
use crate::intf::*;
use crate::network_manager::*;
use utils::async_peek_stream::*;
use async_std::net::*;
use async_tls::TlsAcceptor;
pub trait TcpProtocolHandler: TcpProtocolHandlerClone + Send + Sync {
fn on_accept(
&self,
stream: AsyncPeekStream,
peer_addr: SocketAddr,
) -> SendPinBoxFuture<Result<bool, String>>;
}
pub trait TcpProtocolHandlerClone {
fn clone_box(&self) -> Box<dyn TcpProtocolHandler>;
}
impl<T> TcpProtocolHandlerClone for T
where
T: 'static + TcpProtocolHandler + Clone,
{
fn clone_box(&self) -> Box<dyn TcpProtocolHandler> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn TcpProtocolHandler> {
fn clone(&self) -> Box<dyn TcpProtocolHandler> {
self.clone_box()
}
}
pub type NewTcpProtocolHandler =
dyn Fn(NetworkManager, bool, SocketAddr) -> Box<dyn TcpProtocolHandler> + Send;
/////////////////////////////////////////////////////////////////
#[derive(Clone)]
pub struct ListenerState {
pub protocol_handlers: Vec<Box<dyn TcpProtocolHandler + 'static>>,
pub tls_protocol_handlers: Vec<Box<dyn TcpProtocolHandler + 'static>>,
pub tls_acceptor: Option<TlsAcceptor>,
}
impl ListenerState {
pub fn new() -> Self {
Self {
protocol_handlers: Vec::new(),
tls_protocol_handlers: Vec::new(),
tls_acceptor: None,
}
}
}

View File

@ -1,15 +1,15 @@
mod listener_state;
mod network_tcp; mod network_tcp;
mod network_udp; mod network_udp;
mod protocol; mod protocol;
mod public_dialinfo_discovery; mod public_dialinfo_discovery;
mod start_protocols; mod start_protocols;
use crate::connection_manager::*;
use crate::intf::*; use crate::intf::*;
use crate::network_manager::*; use crate::network_manager::*;
use crate::routing_table::*; use crate::routing_table::*;
use crate::*; use crate::*;
use listener_state::*; use network_tcp::*;
use protocol::tcp::RawTcpProtocolHandler; use protocol::tcp::RawTcpProtocolHandler;
use protocol::udp::RawUdpProtocolHandler; use protocol::udp::RawUdpProtocolHandler;
use protocol::ws::WebsocketProtocolHandler; use protocol::ws::WebsocketProtocolHandler;
@ -136,10 +136,18 @@ impl Network {
this this
} }
fn network_manager(&self) -> NetworkManager {
self.inner.lock().network_manager.clone()
}
fn routing_table(&self) -> RoutingTable { fn routing_table(&self) -> RoutingTable {
self.inner.lock().routing_table.clone() self.inner.lock().routing_table.clone()
} }
fn connection_manager(&self) -> ConnectionManager {
self.inner.lock().network_manager.connection_manager()
}
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> { fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
let cvec = certs(&mut BufReader::new(File::open(path)?)) let cvec = certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?; .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?;
@ -223,63 +231,28 @@ impl Network {
} }
} }
fn get_preferred_local_address( fn get_preferred_local_address(&self, dial_info: &DialInfo) -> SocketAddr {
&self, let inner = self.inner.lock();
local_port: u16,
peer_socket_addr: &SocketAddr, let local_port = match dial_info.protocol_type() {
) -> SocketAddr { ProtocolType::UDP => inner.udp_port,
match peer_socket_addr { ProtocolType::TCP => inner.tcp_port,
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), local_port), ProtocolType::WS => inner.ws_port,
SocketAddr::V6(_) => SocketAddr::new( ProtocolType::WSS => inner.wss_port,
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), };
local_port,
), match dial_info.address_type() {
AddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), local_port),
AddressType::IPV6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), local_port),
} }
} }
//////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////
async fn send_data_to_existing_connection( // Send data to a dial info, unbound, using a new connection from a random port
&self, // This creates a short-lived connection in the case of connection-oriented protocols
descriptor: &ConnectionDescriptor, // for the purpose of sending this one message.
data: Vec<u8>, // This bypasses the connection table as it is not a 'node to node' connection.
) -> Result<Option<Vec<u8>>, String> {
match descriptor.protocol_type() {
ProtocolType::UDP => {
// send over the best udp socket we have bound since UDP is not connection oriented
let peer_socket_addr = descriptor.remote.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(
&peer_socket_addr,
&descriptor.local.map(|sa| sa.to_socket_addr()),
) {
ph.clone()
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!())?;
// Data was consumed
return Ok(None);
}
}
ProtocolType::TCP | ProtocolType::WS | ProtocolType::WSS => {
// find an existing connection in the connection table if one exists
let network_manager = self.inner.lock().network_manager.clone();
if let Some(entry) = network_manager
.connection_table()
.get_connection(descriptor)
{
// connection exists, send over it
entry.conn.send(data).await.map_err(logthru_net!())?;
// Data was consumed
return Ok(None);
}
}
}
// connection or local socket didn't exist, we'll need to use dialinfo to create one
// Pass the data back out so we don't own it any more
Ok(Some(data))
}
pub async fn send_data_unbound_to_dial_info( pub async fn send_data_unbound_to_dial_info(
&self, &self,
dial_info: &DialInfo, dial_info: &DialInfo,
@ -305,61 +278,113 @@ impl Network {
} }
} }
// Initiate a new low-level protocol connection to a node
pub async fn connect_to_dial_info(
&self,
local_addr: Option<SocketAddr>,
dial_info: &DialInfo,
) -> Result<NetworkConnection, String> {
let connection_manager = self.connection_manager();
let peer_socket_addr = dial_info.to_socket_addr();
Ok(match &dial_info {
DialInfo::UDP(_) => {
panic!("Do not attempt to connect to UDP dial info")
}
DialInfo::TCP(_) => {
let local_addr =
self.get_preferred_local_address(self.inner.lock().tcp_port, &peer_socket_addr);
RawTcpProtocolHandler::connect(connection_manager, local_addr, dial_info)
.await
.map_err(logthru_net!())?
}
DialInfo::WS(_) => {
let local_addr =
self.get_preferred_local_address(self.inner.lock().ws_port, &peer_socket_addr);
WebsocketProtocolHandler::connect(connection_manager, local_addr, dial_info)
.await
.map_err(logthru_net!(error))?
}
DialInfo::WSS(_) => {
let local_addr =
self.get_preferred_local_address(self.inner.lock().wss_port, &peer_socket_addr);
WebsocketProtocolHandler::connect(connection_manager, local_addr, dial_info)
.await
.map_err(logthru_net!(error))?
}
})
}
async fn send_data_to_existing_connection(
&self,
descriptor: &ConnectionDescriptor,
data: Vec<u8>,
) -> Result<Option<Vec<u8>>, String> {
match descriptor.protocol_type() {
ProtocolType::UDP => {
// send over the best udp socket we have bound since UDP is not connection oriented
let peer_socket_addr = descriptor.remote.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(
&peer_socket_addr,
&descriptor.local.map(|sa| sa.to_socket_addr()),
) {
ph.clone()
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!())?;
// Data was consumed
return Ok(None);
}
}
ProtocolType::TCP | ProtocolType::WS | ProtocolType::WSS => {
// find an existing connection in the connection table if one exists
if let Some(conn) = self.connection_manager().get_connection(descriptor) {
// connection exists, send over it
conn.send(data).await.map_err(logthru_net!())?;
// Data was consumed
return Ok(None);
}
}
}
// connection or local socket didn't exist, we'll need to use dialinfo to create one
// Pass the data back out so we don't own it any more
Ok(Some(data))
}
// Send data directly to a dial info, possibly without knowing which node it is going to
pub async fn send_data_to_dial_info( pub async fn send_data_to_dial_info(
&self, &self,
dial_info: &DialInfo, dial_info: &DialInfo,
data: Vec<u8>, data: Vec<u8>,
) -> Result<(), String> { ) -> Result<(), String> {
let network_manager = self.inner.lock().network_manager.clone(); // Handle connectionless protocol
if dial_info.protocol_type() == ProtocolType::UDP {
let peer_socket_addr = dial_info.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
return ph
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!());
}
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
.map_err(logthru_net!(error));
}
let conn = match &dial_info { // Handle connection-oriented protocols
DialInfo::UDP(_) => { let conn = self.connect_to_dial_info(dial_info).await?;
let peer_socket_addr = dial_info.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
return ph
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!());
} else {
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
.map_err(logthru_net!(error));
}
}
DialInfo::TCP(_) => {
let peer_socket_addr = dial_info.to_socket_addr();
let local_addr =
self.get_preferred_local_address(self.inner.lock().tcp_port, &peer_socket_addr);
RawTcpProtocolHandler::connect(network_manager, local_addr, peer_socket_addr)
.await
.map_err(logthru_net!())?
}
DialInfo::WS(_) => {
let peer_socket_addr = dial_info.to_socket_addr();
let local_addr =
self.get_preferred_local_address(self.inner.lock().ws_port, &peer_socket_addr);
WebsocketProtocolHandler::connect(network_manager, local_addr, dial_info)
.await
.map_err(logthru_net!(error))?
}
DialInfo::WSS(_) => {
let peer_socket_addr = dial_info.to_socket_addr();
let local_addr =
self.get_preferred_local_address(self.inner.lock().wss_port, &peer_socket_addr);
WebsocketProtocolHandler::connect(network_manager, local_addr, dial_info)
.await
.map_err(logthru_net!(error))?
}
};
conn.send(data).await.map_err(logthru_net!(error)) conn.send(data).await.map_err(logthru_net!(error))
} }
// Send data to node
// We may not have dial info for a node, but have an existing connection for it
// because an inbound connection happened first, and no FindNodeQ has happened to that
// node yet to discover its dial info. The existing connection should be tried first
// in this case.
pub async fn send_data(&self, node_ref: NodeRef, data: Vec<u8>) -> Result<(), String> { pub async fn send_data(&self, node_ref: NodeRef, data: Vec<u8>) -> Result<(), String> {
let dial_info = node_ref.best_dial_info();
let descriptor = node_ref.last_connection();
// First try to send data to the last socket we've seen this peer on // First try to send data to the last socket we've seen this peer on
let di_data = if let Some(descriptor) = descriptor { let data = if let Some(descriptor) = node_ref.last_connection() {
match self match self
.clone() .clone()
.send_data_to_existing_connection(&descriptor, data) .send_data_to_existing_connection(&descriptor, data)
@ -375,11 +400,30 @@ impl Network {
}; };
// If that fails, try to make a connection or reach out to the peer via its dial info // If that fails, try to make a connection or reach out to the peer via its dial info
if let Some(di) = dial_info { let dial_info = node_ref
self.clone().send_data_to_dial_info(&di, di_data).await .best_dial_info()
} else { .ok_or_else(|| "couldn't send data, no dial info or peer address".to_owned())?;
Err("couldn't send data, no dial info or peer address".to_owned())
// Handle connectionless protocol
if dial_info.protocol_type() == ProtocolType::UDP {
let peer_socket_addr = dial_info.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
return ph
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!());
}
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
.map_err(logthru_net!(error));
} }
// Handle connection-oriented protocols
let local_addr = self.get_preferred_local_address(&dial_info);
let conn = self
.connection_manager()
.get_or_create_connection(dial_info, Some(local_addr)); xxx implement this and pass thru to NetworkConnection::connect
conn.send(data).await.map_err(logthru_net!(error))
} }
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
@ -437,9 +481,10 @@ impl Network {
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
info!("stopping network"); info!("stopping network");
let network_manager = self.network_manager();
let routing_table = self.routing_table();
// Reset state // Reset state
let network_manager = self.inner.lock().network_manager.clone();
let routing_table = network_manager.routing_table();
// Drop all dial info // Drop all dial info
routing_table.clear_dial_info_details(); routing_table.clear_dial_info_details();
@ -453,8 +498,6 @@ impl Network {
////////////////////////////////////////// //////////////////////////////////////////
pub fn get_network_class(&self) -> Option<NetworkClass> { pub fn get_network_class(&self) -> Option<NetworkClass> {
let inner = self.inner.lock(); let inner = self.inner.lock();
let routing_table = inner.routing_table.clone();
if !inner.network_started { if !inner.network_started {
return None; return None;
} }
@ -466,7 +509,7 @@ impl Network {
// Go through our global dialinfo and see what our best network class is // Go through our global dialinfo and see what our best network class is
let mut network_class = NetworkClass::Invalid; let mut network_class = NetworkClass::Invalid;
for did in routing_table.global_dial_info_details() { for did in inner.routing_table.global_dial_info_details() {
if let Some(nc) = did.network_class { if let Some(nc) = did.network_class {
if nc < network_class { if nc < network_class {
network_class = nc; network_class = nc;
@ -488,7 +531,7 @@ impl Network {
) = { ) = {
let inner = self.inner.lock(); let inner = self.inner.lock();
( (
inner.network_manager.routing_table(), inner.routing_table.clone(),
inner.protocol_config.unwrap_or_default(), inner.protocol_config.unwrap_or_default(),
inner.udp_static_public_dialinfo, inner.udp_static_public_dialinfo,
inner.tcp_static_public_dialinfo, inner.tcp_static_public_dialinfo,

View File

@ -1,6 +1,31 @@
use super::*; use super::*;
use crate::connection_manager::*;
use crate::intf::*;
use utils::clone_stream::*; use utils::clone_stream::*;
use async_tls::TlsAcceptor;
/////////////////////////////////////////////////////////////////
#[derive(Clone)]
pub struct ListenerState {
pub protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_acceptor: Option<TlsAcceptor>,
}
impl ListenerState {
pub fn new() -> Self {
Self {
protocol_handlers: Vec::new(),
tls_protocol_handlers: Vec::new(),
tls_acceptor: None,
}
}
}
/////////////////////////////////////////////////////////////////
impl Network { impl Network {
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> { fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() { if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
@ -20,46 +45,44 @@ impl Network {
tls_acceptor: &TlsAcceptor, tls_acceptor: &TlsAcceptor,
stream: AsyncPeekStream, stream: AsyncPeekStream,
addr: SocketAddr, addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>], protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
tls_connection_initial_timeout: u64, tls_connection_initial_timeout: u64,
) { ) -> Result<Option<NetworkConnection>, String> {
match tls_acceptor.accept(stream).await { let ts = tls_acceptor
Ok(ts) => { .accept(stream)
let ps = AsyncPeekStream::new(CloneStream::new(ts)); .await
let mut first_packet = [0u8; PEEK_DETECT_LEN]; .map_err(map_to_string)
.map_err(logthru_net!(debug "TLS stream failed handshake"))?;
let ps = AsyncPeekStream::new(CloneStream::new(ts));
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// Try the handlers but first get a chunk of data for them to process // Try the handlers but first get a chunk of data for them to process
// Don't waste more than N seconds getting it though, in case someone // Don't waste more than N seconds getting it though, in case someone
// is trying to DoS us with a bunch of connections or something // is trying to DoS us with a bunch of connections or something
// read a chunk of the stream // read a chunk of the stream
match io::timeout( io::timeout(
Duration::from_micros(tls_connection_initial_timeout), Duration::from_micros(tls_connection_initial_timeout),
ps.peek_exact(&mut first_packet), ps.peek_exact(&mut first_packet),
) )
.await .await
{ .map_err(map_to_string)
Ok(()) => (), .map_err(logthru_net!())?;
Err(_) => return,
} self.try_handlers(ps, addr, protocol_handlers).await
self.clone().try_handlers(ps, addr, protocol_handlers).await;
}
Err(e) => {
debug!("TLS stream failed handshake: {}", e);
}
}
} }
async fn try_handlers( async fn try_handlers(
&self, &self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
addr: SocketAddr, addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>], protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
) { ) -> Result<Option<NetworkConnection>, String> {
for ah in protocol_handlers.iter() { for ah in protocol_handlers.iter() {
if ah.on_accept(stream.clone(), addr).await == Ok(true) { if let Some(nc) = ah.on_accept(stream.clone(), addr).await? {
return; return Ok(Some(nc));
} }
} }
Ok(None)
} }
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> { async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
@ -73,7 +96,7 @@ impl Network {
}; };
// Create a reusable socket with no linger time, and no delay // Create a reusable socket with no linger time, and no delay
let socket = new_shared_tcp_socket(addr)?; let socket = new_bound_shared_tcp_socket(addr)?;
// Listen on the socket // Listen on the socket
socket socket
.listen(128) .listen(128)
@ -94,6 +117,7 @@ impl Network {
// Spawn the socket task // Spawn the socket task
let this = self.clone(); let this = self.clone();
let connection_manager = self.connection_manager();
//////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////
let jh = spawn(async move { let jh = spawn(async move {
@ -104,10 +128,7 @@ impl Network {
.for_each_concurrent(None, |tcp_stream| async { .for_each_concurrent(None, |tcp_stream| async {
let tcp_stream = tcp_stream.unwrap(); let tcp_stream = tcp_stream.unwrap();
let listener_state = listener_state.clone(); let listener_state = listener_state.clone();
// match tcp_stream.set_nodelay(true) { let connection_manager = connection_manager.clone();
// Ok(_) => (),
// _ => continue,
// };
// Limit the number of connections from the same IP address // Limit the number of connections from the same IP address
// and the number of total connections // and the number of total connections
@ -129,7 +150,6 @@ impl Network {
let mut first_packet = [0u8; PEEK_DETECT_LEN]; let mut first_packet = [0u8; PEEK_DETECT_LEN];
// read a chunk of the stream // read a chunk of the stream
trace!("reading chunk");
if io::timeout( if io::timeout(
Duration::from_micros(connection_initial_timeout), Duration::from_micros(connection_initial_timeout),
ps.peek_exact(&mut first_packet), ps.peek_exact(&mut first_packet),
@ -143,26 +163,35 @@ impl Network {
} }
// Run accept handlers on accepted stream // Run accept handlers on accepted stream
trace!("packet ready");
// Check is this could be TLS // Check is this could be TLS
let ls = listener_state.read().clone(); let ls = listener_state.read().clone();
if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 { let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
trace!("trying TLS"); this.try_tls_handlers(
this.clone() ls.tls_acceptor.as_ref().unwrap(),
.try_tls_handlers( ps,
ls.tls_acceptor.as_ref().unwrap(), addr,
ps, &ls.tls_protocol_handlers,
addr, tls_connection_initial_timeout,
&ls.tls_protocol_handlers, )
tls_connection_initial_timeout, .await
)
.await;
} else { } else {
trace!("not TLS"); this.try_handlers(ps, addr, &ls.protocol_handlers).await
this.clone() };
.try_handlers(ps, addr, &ls.protocol_handlers) let conn = match conn {
.await; Ok(Some(c)) => c,
} Ok(None) => {
// No protocol handlers matched? drop it.
return;
}
Err(_) => {
// Failed to negotiate connection? drop it.
return;
}
};
// Register the new connection in the connection manager
connection_manager.on_new_connection(conn).await;
}) })
.await; .await;
trace!("exited incoming loop for {}", addr); trace!("exited incoming loop for {}", addr);
@ -189,7 +218,7 @@ impl Network {
&self, &self,
address: String, address: String,
is_tls: bool, is_tls: bool,
new_tcp_protocol_handler: Box<NewTcpProtocolHandler>, new_protocol_accept_handler: Box<NewProtocolAcceptHandler>,
) -> Result<Vec<SocketAddress>, String> { ) -> Result<Vec<SocketAddress>, String> {
let mut out = Vec::<SocketAddress>::new(); let mut out = Vec::<SocketAddress>::new();
// convert to socketaddrs // convert to socketaddrs
@ -218,17 +247,19 @@ impl Network {
} }
ls.write() ls.write()
.tls_protocol_handlers .tls_protocol_handlers
.push(new_tcp_protocol_handler( .push(new_protocol_accept_handler(
self.inner.lock().network_manager.clone(), self.connection_manager(),
true, true,
addr, addr,
)); ));
} else { } else {
ls.write().protocol_handlers.push(new_tcp_protocol_handler( ls.write()
self.inner.lock().network_manager.clone(), .protocol_handlers
false, .push(new_protocol_accept_handler(
addr, self.connection_manager(),
)); false,
addr,
));
} }
// Return local dial infos we listen on // Return local dial infos we listen on

View File

@ -68,7 +68,7 @@ impl Network {
let mut port = inner.udp_port; let mut port = inner.udp_port;
// v4 // v4
let socket_addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); let socket_addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v4) { if let Ok(socket) = new_bound_shared_udp_socket(socket_addr_v4) {
// Pull the port if we randomly bound, so v6 can be on the same port // Pull the port if we randomly bound, so v6 can be on the same port
port = socket port = socket
.local_addr() .local_addr()
@ -91,7 +91,7 @@ impl Network {
//v6 //v6
let socket_addr_v6 = let socket_addr_v6 =
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), port); SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v6) { if let Ok(socket) = new_bound_shared_udp_socket(socket_addr_v6) {
// Make an async UdpSocket from the socket2 socket // Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into(); let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket); let udp_socket = UdpSocket::from(std_udp_socket);
@ -111,7 +111,7 @@ impl Network {
log_net!("create_udp_inbound_socket on {:?}", &addr); log_net!("create_udp_inbound_socket on {:?}", &addr);
// Create a reusable socket // Create a reusable socket
let socket = new_shared_udp_socket(addr)?; let socket = new_bound_shared_udp_socket(addr)?;
// Make an async UdpSocket from the socket2 socket // Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into(); let std_udp_socket: std::net::UdpSocket = socket.into();

View File

@ -3,7 +3,6 @@ pub mod udp;
pub mod wrtc; pub mod wrtc;
pub mod ws; pub mod ws;
use super::listener_state::*;
use crate::xx::*; use crate::xx::*;
use crate::*; use crate::*;
use socket2::{Domain, Protocol, Socket, Type}; use socket2::{Domain, Protocol, Socket, Type};
@ -12,11 +11,17 @@ use socket2::{Domain, Protocol, Socket, Type};
pub struct DummyNetworkConnection {} pub struct DummyNetworkConnection {}
impl DummyNetworkConnection { impl DummyNetworkConnection {
pub fn send(&self, _message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> { pub fn connection_descriptor(&self) -> ConnectionDescriptor {
Box::pin(async { Ok(()) }) ConnectionDescriptor::new_no_local(PeerAddress::new(
SocketAddress::default(),
ProtocolType::UDP,
))
} }
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> { pub async fn send(&self, _message: Vec<u8>) -> Result<(), String> {
Box::pin(async { Ok(Vec::new()) }) Ok(())
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
Ok(Vec::new())
} }
} }
@ -31,28 +36,53 @@ pub enum NetworkConnection {
} }
impl NetworkConnection { impl NetworkConnection {
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> { pub async fn connect(
match self { local_address: Option<SocketAddr>,
Self::Dummy(d) => d.send(message), dial_info: DialInfo,
Self::RawTcp(t) => t.send(message), ) -> Result<NetworkConnection, String> {
Self::WsAccepted(w) => w.send(message), match dial_info.protocol_type() {
Self::Ws(w) => w.send(message), ProtocolType::UDP => {
Self::Wss(w) => w.send(message), panic!("Should not connect to UDP dialinfo");
}
ProtocolType::TCP => {
tcp::RawTcpProtocolHandler::connect(local_address, dial_info).await
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
}
} }
} }
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
match self { match self {
Self::Dummy(d) => d.recv(), Self::Dummy(d) => d.connection_descriptor(),
Self::RawTcp(t) => t.recv(), Self::RawTcp(t) => t.connection_descriptor(),
Self::WsAccepted(w) => w.recv(), Self::WsAccepted(w) => w.connection_descriptor(),
Self::Ws(w) => w.recv(), Self::Ws(w) => w.connection_descriptor(),
Self::Wss(w) => w.recv(), Self::Wss(w) => w.connection_descriptor(),
}
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
match self {
Self::Dummy(d) => d.send(message).await,
Self::RawTcp(t) => t.send(message).await,
Self::WsAccepted(w) => w.send(message).await,
Self::Ws(w) => w.send(message).await,
Self::Wss(w) => w.send(message).await,
}
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
match self {
Self::Dummy(d) => d.recv().await,
Self::RawTcp(t) => t.recv().await,
Self::WsAccepted(w) => w.recv().await,
Self::Ws(w) => w.recv().await,
Self::Wss(w) => w.recv().await,
} }
} }
} }
pub fn new_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socket, String> { pub fn new_unbound_shared_udp_socket(domain: Domain) -> Result<socket2::Socket, String> {
let domain = Domain::for_address(local_address);
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP)) let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?; .map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
@ -66,7 +96,12 @@ pub fn new_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socke
} }
} }
} }
Ok(socket)
}
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socket, String> {
let domain = Domain::for_address(local_address);
let socket = new_unbound_shared_udp_socket(domain)?;
let socket2_addr = socket2::SockAddr::from(local_address); let socket2_addr = socket2::SockAddr::from(local_address);
socket socket
.bind(&socket2_addr) .bind(&socket2_addr)
@ -77,8 +112,7 @@ pub fn new_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socke
Ok(socket) Ok(socket)
} }
pub fn new_shared_tcp_socket(local_address: SocketAddr) -> Result<socket2::Socket, String> { pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<socket2::Socket, String> {
let domain = Domain::for_address(local_address);
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)) let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!("failed to create TCP socket"))?; .map_err(logthru_net!("failed to create TCP socket"))?;
@ -98,13 +132,18 @@ pub fn new_shared_tcp_socket(local_address: SocketAddr) -> Result<socket2::Socke
} }
} }
} }
Ok(socket)
}
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> Result<socket2::Socket, String> {
let domain = Domain::for_address(local_address);
let socket = new_unbound_shared_tcp_socket(domain)?;
let socket2_addr = socket2::SockAddr::from(local_address); let socket2_addr = socket2::SockAddr::from(local_address);
socket socket
.bind(&socket2_addr) .bind(&socket2_addr)
.map_err(|e| format!("failed to bind TCP socket: {}", e))?; .map_err(|e| format!("failed to bind TCP socket: {}", e))?;
log_net!("created shared tcp socket on {:?}", &local_address);
Ok(socket) Ok(socket)
} }

View File

@ -1,7 +1,8 @@
use super::*; use super::*;
use crate::connection_manager::*;
use crate::intf::native::utils::async_peek_stream::*; use crate::intf::native::utils::async_peek_stream::*;
use crate::intf::*; use crate::intf::*;
use crate::network_manager::{NetworkManager, MAX_MESSAGE_SIZE}; use crate::network_manager::MAX_MESSAGE_SIZE;
use crate::*; use crate::*;
use async_std::net::*; use async_std::net::*;
use async_std::prelude::*; use async_std::prelude::*;
@ -15,11 +16,14 @@ struct RawTcpNetworkConnectionInner {
#[derive(Clone)] #[derive(Clone)]
pub struct RawTcpNetworkConnection { pub struct RawTcpNetworkConnection {
inner: Arc<AsyncMutex<RawTcpNetworkConnectionInner>>, inner: Arc<AsyncMutex<RawTcpNetworkConnectionInner>>,
connection_descriptor: ConnectionDescriptor,
} }
impl fmt::Debug for RawTcpNetworkConnection { impl fmt::Debug for RawTcpNetworkConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", std::any::type_name::<Self>()) f.debug_struct("RawTCPNetworkConnection")
.field("connection_descriptor", &self.connection_descriptor)
.finish()
} }
} }
@ -36,68 +40,63 @@ impl RawTcpNetworkConnection {
RawTcpNetworkConnectionInner { stream } RawTcpNetworkConnectionInner { stream }
} }
pub fn new(stream: AsyncPeekStream) -> Self { pub fn new(stream: AsyncPeekStream, connection_descriptor: ConnectionDescriptor) -> Self {
Self { Self {
inner: Arc::new(AsyncMutex::new(Self::new_inner(stream))), inner: Arc::new(AsyncMutex::new(Self::new_inner(stream))),
connection_descriptor,
} }
} }
}
impl RawTcpNetworkConnection { pub fn connection_descriptor(&self) -> ConnectionDescriptor {
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> { self.connection_descriptor.clone()
let inner = self.inner.clone();
Box::pin(async move {
if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large TCP message".to_owned());
}
let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
let mut inner = inner.lock().await;
inner
.stream
.write_all(&header)
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
inner
.stream
.write_all(&message)
.await
.map_err(map_to_string)
.map_err(logthru_net!())
})
} }
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> { pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
let inner = self.inner.clone(); if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large TCP message".to_owned());
}
let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
Box::pin(async move { let mut inner = self.inner.lock().await;
let mut header = [0u8; 4]; inner
let mut inner = inner.lock().await; .stream
.write_all(&header)
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
inner
.stream
.write_all(&message)
.await
.map_err(map_to_string)
.map_err(logthru_net!())
}
inner pub async fn recv(&self) -> Result<Vec<u8>, String> {
.stream let mut header = [0u8; 4];
.read_exact(&mut header) let mut inner = self.inner.lock().await;
.await
.map_err(|e| format!("TCP recv error: {}", e))?;
if header[0] != b'V' || header[1] != b'L' {
return Err("received invalid TCP frame header".to_owned());
}
let len = ((header[3] as usize) << 8) | (header[2] as usize);
if len > MAX_MESSAGE_SIZE {
return Err("received too large TCP frame".to_owned());
}
let mut out: Vec<u8> = vec![0u8; len]; inner
inner .stream
.stream .read_exact(&mut header)
.read_exact(&mut out) .await
.await .map_err(|e| format!("TCP recv error: {}", e))?;
.map_err(map_to_string)?; if header[0] != b'V' || header[1] != b'L' {
Ok(out) return Err("received invalid TCP frame header".to_owned());
}) }
let len = ((header[3] as usize) << 8) | (header[2] as usize);
if len > MAX_MESSAGE_SIZE {
return Err("received too large TCP frame".to_owned());
}
let mut out: Vec<u8> = vec![0u8; len];
inner
.stream
.read_exact(&mut out)
.await
.map_err(map_to_string)?;
Ok(out)
} }
} }
@ -105,32 +104,35 @@ impl RawTcpNetworkConnection {
/// ///
struct RawTcpProtocolHandlerInner { struct RawTcpProtocolHandlerInner {
network_manager: NetworkManager, connection_manager: ConnectionManager,
local_address: SocketAddr, local_address: SocketAddr,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct RawTcpProtocolHandler pub struct RawTcpProtocolHandler
where where
Self: TcpProtocolHandler, Self: ProtocolAcceptHandler,
{ {
inner: Arc<Mutex<RawTcpProtocolHandlerInner>>, inner: Arc<Mutex<RawTcpProtocolHandlerInner>>,
} }
impl RawTcpProtocolHandler { impl RawTcpProtocolHandler {
fn new_inner( fn new_inner(
network_manager: NetworkManager, connection_manager: ConnectionManager,
local_address: SocketAddr, local_address: SocketAddr,
) -> RawTcpProtocolHandlerInner { ) -> RawTcpProtocolHandlerInner {
RawTcpProtocolHandlerInner { RawTcpProtocolHandlerInner {
network_manager, connection_manager,
local_address, local_address,
} }
} }
pub fn new(network_manager: NetworkManager, local_address: SocketAddr) -> Self { pub fn new(connection_manager: ConnectionManager, local_address: SocketAddr) -> Self {
Self { Self {
inner: Arc::new(Mutex::new(Self::new_inner(network_manager, local_address))), inner: Arc::new(Mutex::new(Self::new_inner(
connection_manager,
local_address,
))),
} }
} }
@ -138,7 +140,7 @@ impl RawTcpProtocolHandler {
self, self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
socket_addr: SocketAddr, socket_addr: SocketAddr,
) -> Result<bool, String> { ) -> Result<Option<NetworkConnection>, String> {
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN]; let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
let peeklen = stream let peeklen = stream
.peek(&mut peekbuf) .peek(&mut peekbuf)
@ -147,51 +149,47 @@ impl RawTcpProtocolHandler {
.map_err(logthru_net!("could not peek tcp stream"))?; .map_err(logthru_net!("could not peek tcp stream"))?;
assert_eq!(peeklen, PEEK_DETECT_LEN); assert_eq!(peeklen, PEEK_DETECT_LEN);
let conn = NetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream));
let peer_addr = PeerAddress::new( let peer_addr = PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr), SocketAddress::from_socket_addr(socket_addr),
ProtocolType::TCP, ProtocolType::TCP,
); );
let (network_manager, local_address) = { let (network_manager, local_address) = {
let inner = self.inner.lock(); let inner = self.inner.lock();
(inner.network_manager.clone(), inner.local_address) (inner.connection_manager.clone(), inner.local_address)
}; };
network_manager let conn = NetworkConnection::RawTcp(RawTcpNetworkConnection::new(
.on_new_connection( stream,
ConnectionDescriptor::new( ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
peer_addr, ));
SocketAddress::from_socket_addr(local_address),
), Ok(Some(conn))
conn,
)
.await?;
Ok(true)
} }
pub async fn connect( pub async fn connect(
network_manager: NetworkManager, local_address: Option<SocketAddr>,
local_address: SocketAddr, dial_info: DialInfo,
remote_socket_addr: SocketAddr,
) -> Result<NetworkConnection, String> { ) -> Result<NetworkConnection, String> {
// Get remote socket address to connect to
let remote_socket_addr = dial_info.to_socket_addr();
// Make a shared socket // Make a shared socket
let socket = new_shared_tcp_socket(local_address)?; let socket = match local_address {
Some(a) => new_bound_shared_tcp_socket(a)?,
None => new_unbound_shared_tcp_socket(Domain::for_address(remote_socket_addr))?,
};
// Connect to the remote address // Connect to the remote address
let remote_socket2_addr = socket2::SockAddr::from(remote_socket_addr); let remote_socket2_addr = socket2::SockAddr::from(remote_socket_addr);
socket socket
.connect(&remote_socket2_addr) .connect(&remote_socket2_addr)
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error "local_address={} remote_addr={}", local_address, remote_socket_addr))?; .map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
log_net!(
"tcp connect successful: local_address={} remote_addr={}",
local_address,
remote_socket_addr
);
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let ts = TcpStream::from(std_stream); let ts = TcpStream::from(std_stream);
// See what local address we ended up with and turn this into a stream // See what local address we ended up with and turn this into a stream
let local_address = ts let actual_local_address = ts
.local_addr() .local_addr()
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!("could not get local address from TCP stream"))?; .map_err(logthru_net!("could not get local address from TCP stream"))?;
@ -202,16 +200,13 @@ impl RawTcpProtocolHandler {
); );
// Wrap the stream in a network connection and register it // Wrap the stream in a network connection and register it
let conn = NetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps)); let conn = NetworkConnection::RawTcp(RawTcpNetworkConnection::new(
network_manager ps,
.on_new_connection( ConnectionDescriptor {
ConnectionDescriptor::new( local: Some(SocketAddress::from_socket_addr(actual_local_address)),
peer_addr, remote: dial_info.to_peer_address(),
SocketAddress::from_socket_addr(local_address), },
), ));
conn.clone(),
)
.await?;
Ok(conn) Ok(conn)
} }
@ -235,12 +230,12 @@ impl RawTcpProtocolHandler {
} }
} }
impl TcpProtocolHandler for RawTcpProtocolHandler { impl ProtocolAcceptHandler for RawTcpProtocolHandler {
fn on_accept( fn on_accept(
&self, &self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
peer_addr: SocketAddr, peer_addr: SocketAddr,
) -> SendPinBoxFuture<Result<bool, String>> { ) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, peer_addr)) Box::pin(self.clone().on_accept_async(stream, peer_addr))
} }
} }

View File

@ -1,7 +1,8 @@
use super::*; use super::*;
use crate::connection_manager::*;
use crate::intf::native::utils::async_peek_stream::*; use crate::intf::native::utils::async_peek_stream::*;
use crate::intf::*; use crate::intf::*;
use crate::network_manager::{NetworkManager, MAX_MESSAGE_SIZE}; use crate::network_manager::MAX_MESSAGE_SIZE;
use crate::*; use crate::*;
use async_std::io; use async_std::io;
use async_std::net::*; use async_std::net::*;
@ -32,6 +33,7 @@ where
T: io::Read + io::Write + Send + Unpin + 'static, T: io::Read + io::Write + Send + Unpin + 'static,
{ {
tls: bool, tls: bool,
connection_descriptor: ConnectionDescriptor,
inner: Arc<AsyncMutex<WebSocketNetworkConnectionInner<T>>>, inner: Arc<AsyncMutex<WebSocketNetworkConnectionInner<T>>>,
} }
@ -42,6 +44,7 @@ where
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
tls: self.tls, tls: self.tls,
connection_descriptor: self.connection_descriptor.clone(),
inner: self.inner.clone(), inner: self.inner.clone(),
} }
} }
@ -61,7 +64,9 @@ where
T: io::Read + io::Write + Send + Unpin + 'static, T: io::Read + io::Write + Send + Unpin + 'static,
{ {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.tls == other.tls && Arc::as_ptr(&self.inner) == Arc::as_ptr(&other.inner) self.tls == other.tls
&& self.connection_descriptor == other.connection_descriptor
&& Arc::as_ptr(&self.inner) == Arc::as_ptr(&other.inner)
} }
} }
@ -71,56 +76,56 @@ impl<T> WebsocketNetworkConnection<T>
where where
T: io::Read + io::Write + Send + Unpin + 'static, T: io::Read + io::Write + Send + Unpin + 'static,
{ {
pub fn new(tls: bool, ws_stream: WebSocketStream<T>) -> Self { pub fn new(
tls: bool,
connection_descriptor: ConnectionDescriptor,
ws_stream: WebSocketStream<T>,
) -> Self {
Self { Self {
tls, tls,
connection_descriptor,
inner: Arc::new(AsyncMutex::new(WebSocketNetworkConnectionInner { inner: Arc::new(AsyncMutex::new(WebSocketNetworkConnectionInner {
ws_stream, ws_stream,
})), })),
} }
} }
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> { pub fn connection_descriptor(&self) -> ConnectionDescriptor {
let inner = self.inner.clone(); self.connection_descriptor.clone()
Box::pin(async move {
if message.len() > MAX_MESSAGE_SIZE {
return Err("received too large WS message".to_owned());
}
let mut inner = inner.lock().await;
inner
.ws_stream
.send(Message::binary(message))
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to send websocket message"))
})
} }
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
let inner = self.inner.clone();
Box::pin(async move { pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
let mut inner = inner.lock().await; if message.len() > MAX_MESSAGE_SIZE {
return Err("received too large WS message".to_owned());
}
let mut inner = self.inner.lock().await;
inner
.ws_stream
.send(Message::binary(message))
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to send websocket message"))
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let mut inner = self.inner.lock().await;
let out = match inner.ws_stream.next().await { let out = match inner.ws_stream.next().await {
Some(Ok(Message::Binary(v))) => v, Some(Ok(Message::Binary(v))) => v,
Some(Ok(_)) => { Some(Ok(_)) => {
return Err("Unexpected WS message type".to_owned()) return Err("Unexpected WS message type".to_owned()).map_err(logthru_net!(error));
.map_err(logthru_net!(error));
}
Some(Err(e)) => {
return Err(e.to_string()).map_err(logthru_net!(error));
}
None => {
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
}
};
if out.len() > MAX_MESSAGE_SIZE {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
} else {
Ok(out)
} }
}) Some(Err(e)) => {
return Err(e.to_string()).map_err(logthru_net!(error));
}
None => {
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
}
};
if out.len() > MAX_MESSAGE_SIZE {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
} else {
Ok(out)
}
} }
} }
@ -128,7 +133,7 @@ where
/// ///
struct WebsocketProtocolHandlerInner { struct WebsocketProtocolHandlerInner {
tls: bool, tls: bool,
network_manager: NetworkManager, connection_manager: ConnectionManager,
local_address: SocketAddr, local_address: SocketAddr,
request_path: Vec<u8>, request_path: Vec<u8>,
connection_initial_timeout: u64, connection_initial_timeout: u64,
@ -137,13 +142,17 @@ struct WebsocketProtocolHandlerInner {
#[derive(Clone)] #[derive(Clone)]
pub struct WebsocketProtocolHandler pub struct WebsocketProtocolHandler
where where
Self: TcpProtocolHandler, Self: ProtocolAcceptHandler,
{ {
inner: Arc<WebsocketProtocolHandlerInner>, inner: Arc<WebsocketProtocolHandlerInner>,
} }
impl WebsocketProtocolHandler { impl WebsocketProtocolHandler {
pub fn new(network_manager: NetworkManager, tls: bool, local_address: SocketAddr) -> Self { pub fn new(
let config = network_manager.config(); connection_manager: ConnectionManager,
tls: bool,
local_address: SocketAddr,
) -> Self {
let config = connection_manager.config();
let c = config.get(); let c = config.get();
let path = if tls { let path = if tls {
format!("GET {}", c.network.protocol.ws.path.trim_end_matches('/')) format!("GET {}", c.network.protocol.ws.path.trim_end_matches('/'))
@ -158,7 +167,7 @@ impl WebsocketProtocolHandler {
let inner = WebsocketProtocolHandlerInner { let inner = WebsocketProtocolHandlerInner {
tls, tls,
network_manager, connection_manager,
local_address, local_address,
request_path: path.as_bytes().to_vec(), request_path: path.as_bytes().to_vec(),
connection_initial_timeout, connection_initial_timeout,
@ -172,7 +181,7 @@ impl WebsocketProtocolHandler {
self, self,
ps: AsyncPeekStream, ps: AsyncPeekStream,
socket_addr: SocketAddr, socket_addr: SocketAddr,
) -> Result<bool, String> { ) -> Result<Option<NetworkConnection>, String> {
let request_path_len = self.inner.request_path.len() + 2; let request_path_len = self.inner.request_path.len() + 2;
let mut peekbuf: Vec<u8> = vec![0u8; request_path_len]; let mut peekbuf: Vec<u8> = vec![0u8; request_path_len];
match io::timeout( match io::timeout(
@ -197,7 +206,7 @@ impl WebsocketProtocolHandler {
if !matches_path { if !matches_path {
log_net!("not websocket"); log_net!("not websocket");
return Ok(false); return Ok(None);
} }
log_net!("found websocket"); log_net!("found websocket");
@ -218,26 +227,19 @@ impl WebsocketProtocolHandler {
let conn = NetworkConnection::WsAccepted(WebsocketNetworkConnection::new( let conn = NetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
self.inner.tls, self.inner.tls,
ConnectionDescriptor::new(
peer_addr,
SocketAddress::from_socket_addr(self.inner.local_address),
),
ws_stream, ws_stream,
)); ));
self.inner
.network_manager Ok(Some(conn))
.clone()
.on_new_connection(
ConnectionDescriptor::new(
peer_addr,
SocketAddress::from_socket_addr(self.inner.local_address),
),
conn,
)
.await?;
Ok(true)
} }
pub async fn connect( pub async fn connect(
network_manager: NetworkManager, local_address: Option<SocketAddr>,
local_address: SocketAddr, dial_info: DialInfo,
dial_info: &DialInfo,
) -> Result<NetworkConnection, String> { ) -> Result<NetworkConnection, String> {
// Split dial info up // Split dial info up
let (tls, scheme) = match &dial_info { let (tls, scheme) = match &dial_info {
@ -256,14 +258,17 @@ impl WebsocketProtocolHandler {
let remote_socket_addr = dial_info.to_socket_addr(); let remote_socket_addr = dial_info.to_socket_addr();
// Make a shared socket // Make a shared socket
let socket = new_shared_tcp_socket(local_address)?; let socket = match local_address {
Some(a) => new_bound_shared_tcp_socket(a)?,
None => new_unbound_shared_tcp_socket(Domain::for_address(remote_socket_addr))?,
};
// Connect to the remote address // Connect to the remote address
let remote_socket2_addr = socket2::SockAddr::from(remote_socket_addr); let remote_socket2_addr = socket2::SockAddr::from(remote_socket_addr);
socket socket
.connect(&remote_socket2_addr) .connect(&remote_socket2_addr)
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error "local_address={} remote_socket_addr={}", local_address, remote_socket_addr))?; .map_err(logthru_net!(error "local_address={:?} remote_socket_addr={}", local_address, remote_socket_addr))?;
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let tcp_stream = TcpStream::from(std_stream); let tcp_stream = TcpStream::from(std_stream);
@ -273,6 +278,11 @@ impl WebsocketProtocolHandler {
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!())?; .map_err(logthru_net!())?;
// Make our connection descriptor
let connection_descriptor = ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_addr)),
remote: dial_info.to_peer_address(),
};
// Negotiate TLS if this is WSS // Negotiate TLS if this is WSS
if tls { if tls {
let connector = TlsConnector::default(); let connector = TlsConnector::default();
@ -285,59 +295,32 @@ impl WebsocketProtocolHandler {
.await .await
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error))?; .map_err(logthru_net!(error))?;
let conn = NetworkConnection::Wss(WebsocketNetworkConnection::new(tls, ws_stream));
// Make the connection descriptor peer address Ok(NetworkConnection::Wss(WebsocketNetworkConnection::new(
let peer_addr = PeerAddress::new( tls,
SocketAddress::from_socket_addr(remote_socket_addr), connection_descriptor,
ProtocolType::WSS, ws_stream,
); )))
// Register the WSS connection
network_manager
.on_new_connection(
ConnectionDescriptor::new(
peer_addr,
SocketAddress::from_socket_addr(actual_local_addr),
),
conn.clone(),
)
.await?;
Ok(conn)
} else { } else {
let (ws_stream, _response) = client_async(request, tcp_stream) let (ws_stream, _response) = client_async(request, tcp_stream)
.await .await
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error))?; .map_err(logthru_net!(error))?;
let conn = NetworkConnection::Ws(WebsocketNetworkConnection::new(tls, ws_stream)); Ok(NetworkConnection::Ws(WebsocketNetworkConnection::new(
tls,
// Make the connection descriptor peer address connection_descriptor,
let peer_addr = PeerAddress::new( ws_stream,
SocketAddress::from_socket_addr(remote_socket_addr), )))
ProtocolType::WS,
);
// Register the WS connection
network_manager
.on_new_connection(
ConnectionDescriptor::new(
peer_addr,
SocketAddress::from_socket_addr(actual_local_addr),
),
conn.clone(),
)
.await?;
Ok(conn)
} }
} }
} }
impl TcpProtocolHandler for WebsocketProtocolHandler { impl ProtocolAcceptHandler for WebsocketProtocolHandler {
fn on_accept( fn on_accept(
&self, &self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
peer_addr: SocketAddr, peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<bool, String>> { ) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, peer_addr)) Box::pin(self.clone().on_accept_async(stream, peer_addr))
} }
} }

View File

@ -78,7 +78,7 @@ impl Network {
.start_tcp_listener( .start_tcp_listener(
listen_address.clone(), listen_address.clone(),
false, false,
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))), Box::new(|c, t, a| Box::new(WebsocketProtocolHandler::new(c, t, a))),
) )
.await?; .await?;
trace!("WS: listener started"); trace!("WS: listener started");

View File

@ -1,7 +1,7 @@
use crate::xx::*; use crate::xx::*;
use async_std::io::{Read, ReadExt, Result, Write}; use async_std::io::{Read, ReadExt, Result, Write};
use core::pin::Pin;
use core::task::{Context, Poll}; use core::task::{Context, Poll};
use std::pin::Pin;
//////// ////////
/// ///

View File

@ -0,0 +1,43 @@
use super::*;
use data_encoding::BASE64URL_NOPAD;
pub async fn save_user_secret(namespace: &str, key: &str, value: &[u8]) -> Result<bool, String> {
let mut s = BASE64URL_NOPAD.encode(value);
s.push('!');
save_user_secret_string(namespace, key, s.as_str()).await
}
pub async fn load_user_secret(namespace: &str, key: &str) -> Result<Option<Vec<u8>>, String> {
let mut s = match load_user_secret_string(namespace, key).await? {
Some(s) => s,
None => {
return Ok(None);
}
};
if s.pop() != Some('!') {
return Err("User secret is not a buffer".to_owned());
}
let mut bytes = Vec::<u8>::new();
let res = BASE64URL_NOPAD.decode_len(s.len());
match res {
Ok(l) => {
bytes.resize(l, 0u8);
}
Err(_) => {
return Err("Failed to decode".to_owned());
}
}
let res = BASE64URL_NOPAD.decode_mut(s.as_bytes(), &mut bytes);
match res {
Ok(_) => Ok(Some(bytes)),
Err(_) => Err("Failed to decode".to_owned()),
}
}
pub async fn remove_user_secret(namespace: &str, key: &str) -> Result<bool, String> {
remove_user_secret_string(namespace, key).await
}

View File

@ -8,14 +8,17 @@ use crate::xx::*;
pub struct DummyNetworkConnection {} pub struct DummyNetworkConnection {}
impl DummyNetworkConnection { impl DummyNetworkConnection {
pub fn protocol_type(&self) -> ProtocolType { pub fn connection_descriptor(&self) -> ConnectionDescriptor {
ProtocolType::UDP ConnectionDescriptor::new_no_local(PeerAddress::new(
SocketAddress::default(),
ProtocolType::UDP,
))
} }
pub fn send(&self, _message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> { pub async fn send(&self, _message: Vec<u8>) -> Result<(), String> {
Box::pin(async { Ok(()) }) Ok(())
} }
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> { pub async fn recv(&self) -> Result<Vec<u8>, String> {
Box::pin(async { Ok(Vec::new()) }) Ok(Vec::new())
} }
} }
@ -27,16 +30,33 @@ pub enum NetworkConnection {
} }
impl NetworkConnection { impl NetworkConnection {
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> { pub async fn connect(
match self { local_address: Option<SocketAddr>,
Self::Dummy(d) => d.send(message), dial_info: DialInfo,
Self::WS(w) => w.send(message), ) -> Result<NetworkConnection, String> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("Should not connect to UDP dialinfo");
}
ProtocolType::TCP => {
panic!("TCP dial info is not support on WASM targets");
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
}
} }
} }
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
match self { match self {
Self::Dummy(d) => d.recv(), Self::Dummy(d) => d.send(message).await,
Self::WS(w) => w.recv(), Self::WS(w) => w.send(message).await,
}
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
match self {
Self::Dummy(d) => d.recv().await,
Self::WS(w) => w.recv().await,
} }
} }
} }

View File

@ -14,6 +14,7 @@ struct WebsocketNetworkConnectionInner {
#[derive(Clone)] #[derive(Clone)]
pub struct WebsocketNetworkConnection { pub struct WebsocketNetworkConnection {
tls: bool, tls: bool,
connection_descriptor: ConnectionDescriptor,
inner: Arc<Mutex<WebsocketNetworkConnectionInner>>, inner: Arc<Mutex<WebsocketNetworkConnectionInner>>,
} }
@ -32,52 +33,49 @@ impl PartialEq for WebsocketNetworkConnection {
impl Eq for WebsocketNetworkConnection {} impl Eq for WebsocketNetworkConnection {}
impl WebsocketNetworkConnection { impl WebsocketNetworkConnection {
pub fn new(tls: bool, ws_meta: WsMeta, ws_stream: WsStream) -> Self { pub fn new(tls: bool, connection_descriptor: ConnectionDescriptor, ws_stream: WsStream) -> Self {
let ws = ws_stream.wrapped().clone(); let ws = ws_stream.wrapped().clone();
Self { Self {
tls, tls,
connection_descriptor,
inner: Arc::new(Mutex::new(WebsocketNetworkConnectionInner { inner: Arc::new(Mutex::new(WebsocketNetworkConnectionInner {
ws_stream, ws_stream,
ws, ws,
})), })),
} }
} }
}
impl WebsocketNetworkConnection { pub fn connection_descriptor(&self) -> ConnectionDescriptor {
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> { self.connection_descriptor.clone()
let inner = self.inner.clone();
Box::pin(async move {
if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large WS message".to_owned()).map_err(logthru_net!(error));
}
inner
.lock()
.ws
.send_with_u8_array(&message)
.map_err(|_| "failed to send to websocket".to_owned())
.map_err(logthru_net!(error))
})
} }
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
let inner = self.inner.clone(); pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
Box::pin(async move { if message.len() > MAX_MESSAGE_SIZE {
let out = match inner.lock().ws_stream.next().await { return Err("sending too large WS message".to_owned()).map_err(logthru_net!(error));
Some(WsMessage::Binary(v)) => v, }
Some(_) => { self.inner
return Err("Unexpected WS message type".to_owned()) .lock()
.map_err(logthru_net!(error)); .ws
} .send_with_u8_array(&message)
None => { .map_err(|_| "failed to send to websocket".to_owned())
return Err("WS stream closed".to_owned()).map_err(logthru_net!(error)); .map_err(logthru_net!(error))
} }
}; pub async fn recv(&self) -> Result<Vec<u8>, String> {
if out.len() > MAX_MESSAGE_SIZE { let out = match self.inner.lock().ws_stream.next().await {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error)) Some(WsMessage::Binary(v)) => v,
} else { Some(_) => {
Ok(out) return Err("Unexpected WS message type".to_owned())
.map_err(logthru_net!(error));
} }
}) None => {
return Err("WS stream closed".to_owned()).map_err(logthru_net!(error));
}
};
if out.len() > MAX_MESSAGE_SIZE {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
} else {
Ok(out)
}
} }
} }
@ -88,7 +86,7 @@ pub struct WebsocketProtocolHandler {}
impl WebsocketProtocolHandler { impl WebsocketProtocolHandler {
pub async fn connect( pub async fn connect(
network_manager: NetworkManager, local_address: Option<SocketAddr>,
dial_info: &DialInfo, dial_info: &DialInfo,
) -> Result<NetworkConnection, String> { ) -> Result<NetworkConnection, String> {
let url = dial_info let url = dial_info
@ -113,18 +111,18 @@ impl WebsocketProtocolHandler {
.map_err(logthru_net!(error)) .map_err(logthru_net!(error))
} }
}; };
let peer_addr = dial_info.to_peer_address();
let (ws, wsio) = WsMeta::connect(url, None) let (_, wsio) = WsMeta::connect(url, None)
.await .await
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error))?; .map_err(logthru_net!(error))?;
let conn = NetworkConnection::WS(WebsocketNetworkConnection::new(tls, ws, wsio)); // Make our connection descriptor
network_manager let connection_descriptor = ConnectionDescriptor {
.on_new_connection(ConnectionDescriptor::new_no_local(peer_addr), conn.clone()) local: None,
.await?; remote: dial_info.to_peer_address(),
};
Ok(conn) Ok(NetworkConnection::WS(WebsocketNetworkConnection::new(tls, connection_descriptor, wsio)))
} }
} }

View File

@ -7,6 +7,7 @@ extern crate alloc;
mod attachment_manager; mod attachment_manager;
mod callback_state_machine; mod callback_state_machine;
mod connection_manager;
mod connection_table; mod connection_table;
mod dht; mod dht;
mod intf; mod intf;

View File

@ -1,8 +1,6 @@
use crate::*; use crate::*;
use connection_table::*; use connection_manager::*;
use dht::*; use dht::*;
use futures_util::future::{select, Either};
use futures_util::stream::{FuturesUnordered, StreamExt};
use intf::*; use intf::*;
use lease_manager::*; use lease_manager::*;
use receipt_manager::*; use receipt_manager::*;
@ -12,8 +10,6 @@ use xx::*;
//////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////
const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize;
pub const MAX_MESSAGE_SIZE: usize = MAX_ENVELOPE_SIZE; pub const MAX_MESSAGE_SIZE: usize = MAX_ENVELOPE_SIZE;
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
@ -77,7 +73,7 @@ impl ProtocolConfig {
#[derive(Clone)] #[derive(Clone)]
struct NetworkComponents { struct NetworkComponents {
net: Network, net: Network,
connection_table: ConnectionTable, connection_manager: ConnectionManager,
rpc_processor: RPCProcessor, rpc_processor: RPCProcessor,
lease_manager: LeaseManager, lease_manager: LeaseManager,
receipt_manager: ReceiptManager, receipt_manager: ReceiptManager,
@ -88,8 +84,6 @@ pub struct NetworkManagerInner {
routing_table: Option<RoutingTable>, routing_table: Option<RoutingTable>,
components: Option<NetworkComponents>, components: Option<NetworkComponents>,
network_class: Option<NetworkClass>, network_class: Option<NetworkClass>,
connection_processor_jh: Option<JoinHandle<()>>,
connection_add_channel_tx: Option<utils::channel::Sender<SystemPinBoxFuture<()>>>,
} }
#[derive(Clone)] #[derive(Clone)]
@ -106,8 +100,6 @@ impl NetworkManager {
routing_table: None, routing_table: None,
components: None, components: None,
network_class: None, network_class: None,
connection_processor_jh: None,
connection_add_channel_tx: None,
} }
} }
@ -161,13 +153,13 @@ impl NetworkManager {
.receipt_manager .receipt_manager
.clone() .clone()
} }
pub fn connection_table(&self) -> ConnectionTable { pub fn connection_manager(&self) -> ConnectionManager {
self.inner self.inner
.lock() .lock()
.components .components
.as_ref() .as_ref()
.unwrap() .unwrap()
.connection_table .connection_manager
.clone() .clone()
} }
@ -194,13 +186,13 @@ impl NetworkManager {
// Create network components // Create network components
let net = Network::new(self.clone()); let net = Network::new(self.clone());
let connection_table = ConnectionTable::new(); let connection_manager = ConnectionManager::new(self.clone());
let rpc_processor = RPCProcessor::new(self.clone()); let rpc_processor = RPCProcessor::new(self.clone());
let lease_manager = LeaseManager::new(self.clone()); let lease_manager = LeaseManager::new(self.clone());
let receipt_manager = ReceiptManager::new(self.clone()); let receipt_manager = ReceiptManager::new(self.clone());
self.inner.lock().components = Some(NetworkComponents { self.inner.lock().components = Some(NetworkComponents {
net: net.clone(), net: net.clone(),
connection_table: connection_table.clone(), connection_manager: connection_manager.clone(),
rpc_processor: rpc_processor.clone(), rpc_processor: rpc_processor.clone(),
lease_manager: lease_manager.clone(), lease_manager: lease_manager.clone(),
receipt_manager: receipt_manager.clone(), receipt_manager: receipt_manager.clone(),
@ -211,13 +203,7 @@ impl NetworkManager {
lease_manager.startup().await?; lease_manager.startup().await?;
receipt_manager.startup().await?; receipt_manager.startup().await?;
net.startup().await?; net.startup().await?;
connection_manager.startup().await;
// Run connection processing task
let cac = utils::channel::channel(CONNECTION_PROCESSOR_CHANNEL_SIZE); // xxx move to config
self.inner.lock().connection_add_channel_tx = Some(cac.0);
let rx = cac.1.clone();
let this = self.clone();
self.inner.lock().connection_processor_jh = Some(spawn(this.connection_processor(rx)));
trace!("NetworkManager::internal_startup end"); trace!("NetworkManager::internal_startup end");
@ -234,16 +220,11 @@ impl NetworkManager {
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
trace!("NetworkManager::shutdown begin"); trace!("NetworkManager::shutdown begin");
let components = {
let mut inner = self.inner.lock();
// Drop/cancel the connection processing task first
inner.connection_processor_jh = None;
inner.connection_add_channel_tx = None;
inner.components.clone()
};
// Shutdown network components if they started up // Shutdown network components if they started up
let components = self.inner.lock().components.clone();
if let Some(components) = components { if let Some(components) = components {
components.connection_manager.shutdown().await;
components.net.shutdown().await; components.net.shutdown().await;
components.receipt_manager.shutdown().await; components.receipt_manager.shutdown().await;
components.lease_manager.shutdown().await; components.lease_manager.shutdown().await;
@ -292,124 +273,6 @@ impl NetworkManager {
Ok(()) Ok(())
} }
// Called by low-level protocol handlers when any connection-oriented protocol connection appears
// either from incoming or outgoing connections. Registers connection in the connection table for later access
// and spawns a message processing loop for the connection
pub async fn on_new_connection(
&self,
descriptor: ConnectionDescriptor,
conn: NetworkConnection,
) -> Result<(), String> {
let tx = self
.inner
.lock()
.connection_add_channel_tx
.as_ref()
.ok_or_else(fn_string!("connection channel isn't open yet"))?
.clone();
let this = self.clone();
let receiver_loop_future = Self::process_connection(this, descriptor, conn);
tx.try_send(receiver_loop_future)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to start receiver loop"))
}
// Connection receiver loop
fn process_connection(
this: NetworkManager,
descriptor: ConnectionDescriptor,
conn: NetworkConnection,
) -> SystemPinBoxFuture<()> {
Box::pin(async move {
// Add new connections to the table
let entry = match this
.connection_table()
.add_connection(descriptor.clone(), conn.clone())
{
Ok(e) => e,
Err(err) => {
error!(target: "net", "{}", err);
return;
}
};
//
let exit_value: Result<Vec<u8>, ()> = Err(());
loop {
let res = match select(
entry.stopper.clone().instance_clone(exit_value.clone()),
conn.clone().recv(),
)
.await
{
Either::Left((_x, _b)) => break,
Either::Right((y, _a)) => y,
};
let message = match res {
Ok(v) => v,
Err(_) => break,
};
match this.on_recv_envelope(message.as_slice(), &descriptor).await {
Ok(_) => (),
Err(e) => {
error!("{}", e);
break;
}
};
}
if let Err(err) = this.connection_table().remove_connection(&descriptor) {
error!("{}", err);
}
})
}
// Process connection oriented sockets in the background
// This never terminates and must have its task cancelled once started
async fn connection_processor(self, rx: utils::channel::Receiver<SystemPinBoxFuture<()>>) {
let mut connection_futures: FuturesUnordered<SystemPinBoxFuture<()>> =
FuturesUnordered::new();
loop {
// Either process an existing connection, or receive a new one to add to our list
match select(connection_futures.next(), Box::pin(rx.recv())).await {
Either::Left((x, _)) => {
// Processed some connection to completion, or there are none left
match x {
Some(()) => {
// Processed some connection to completion
}
None => {
// No connections to process, wait for one
match rx.recv().await {
Ok(v) => {
connection_futures.push(v);
}
Err(e) => {
error!("connection processor error: {:?}", e);
// xxx: do something here??
}
};
}
}
}
Either::Right((x, _)) => {
// Got a new connection future
match x {
Ok(v) => {
connection_futures.push(v);
}
Err(e) => {
error!("connection processor error: {:?}", e);
// xxx: do something here??
}
};
}
}
}
}
// Return what network class we are in // Return what network class we are in
pub fn get_network_class(&self) -> Option<NetworkClass> { pub fn get_network_class(&self) -> Option<NetworkClass> {
if let Some(components) = &self.inner.lock().components { if let Some(components) = &self.inner.lock().components {

View File

@ -4,7 +4,7 @@ use crate::xx::*;
use crate::*; use crate::*;
pub async fn test_add_get_remove() { pub async fn test_add_get_remove() {
let table = ConnectionTable::new(); let mut table = ConnectionTable::new();
let c1 = NetworkConnection::Dummy(DummyNetworkConnection {}); let c1 = NetworkConnection::Dummy(DummyNetworkConnection {});
let c2 = NetworkConnection::Dummy(DummyNetworkConnection {}); let c2 = NetworkConnection::Dummy(DummyNetworkConnection {});

View File

@ -153,6 +153,33 @@ macro_rules! logthru {
); );
e__ e__
}); });
// debug
(debug $target:literal) => (|e__| {
debug!(
target: $target,
"[{}]",
e__,
);
e__
});
(debug $target:literal, $text:literal) => (|e__| {
debug!(
target: $target,
"[{}] {}",
e__,
$text
);
e__
});
(debug $target:literal, $fmt:literal, $($arg:expr),+) => (|e__| {
debug!(
target: $target,
concat!("[{}] ", $fmt),
e__,
$($arg),+
);
e__
});
// trace // trace
($target:literal) => (|e__| { ($target:literal) => (|e__| {
trace!( trace!(