refactor network manager

This commit is contained in:
John Smith 2022-05-31 19:54:52 -04:00
parent ad4b6328ac
commit 8148c37708
51 changed files with 500 additions and 389 deletions

9
Cargo.lock generated
View File

@ -2815,6 +2815,12 @@ dependencies = [
"stable_deref_trait", "stable_deref_trait",
] ]
[[package]]
name = "owo-colors"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "decf7381921fea4dcb2549c5667eda59b3ec297ab7e2b5fc33eac69d2e7da87b"
[[package]] [[package]]
name = "parity-scale-codec" name = "parity-scale-codec"
version = "3.0.0" version = "3.0.0"
@ -4394,6 +4400,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"android_logger", "android_logger",
"anyhow", "anyhow",
"async-io",
"async-lock", "async-lock",
"async-std", "async-std",
"async-std-resolver", "async-std-resolver",
@ -4441,8 +4448,8 @@ dependencies = [
"ndk-glue", "ndk-glue",
"nix 0.23.1", "nix 0.23.1",
"no-std-net", "no-std-net",
"num_cpus",
"once_cell", "once_cell",
"owo-colors",
"parking_lot 0.12.0", "parking_lot 0.12.0",
"rand 0.7.3", "rand 0.7.3",
"rtnetlink", "rtnetlink",

View File

@ -38,6 +38,7 @@ json = "^0"
flume = { version = "^0", features = ["async"] } flume = { version = "^0", features = ["async"] }
enumset = { version= "^1", features = ["serde"] } enumset = { version= "^1", features = ["serde"] }
backtrace = { version = "^0", optional = true } backtrace = { version = "^0", optional = true }
owo-colors = "^3"
ed25519-dalek = { version = "^1", default_features = false, features = ["alloc", "u64_backend"] } ed25519-dalek = { version = "^1", default_features = false, features = ["alloc", "u64_backend"] }
x25519-dalek = { package = "x25519-dalek-ng", version = "^1", default_features = false, features = ["u64_backend"] } x25519-dalek = { package = "x25519-dalek-ng", version = "^1", default_features = false, features = ["u64_backend"] }
@ -52,6 +53,7 @@ digest = "0.9.0"
# Linux, Windows, Mac, iOS, Android # Linux, Windows, Mac, iOS, Android
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
async-std = { version = "^1", features = ["unstable"] } async-std = { version = "^1", features = ["unstable"] }
async-io = { version = "^1" }
async-tungstenite = { version = "^0", features = ["async-std-runtime", "async-tls"] } async-tungstenite = { version = "^0", features = ["async-std-runtime", "async-tls"] }
async-std-resolver = { version = "^0" } async-std-resolver = { version = "^0" }
maplit = "^1" maplit = "^1"
@ -63,7 +65,6 @@ webpki = "^0"
webpki-roots = "^0" webpki-roots = "^0"
rustls = "^0.19" rustls = "^0.19"
rustls-pemfile = "^0.2" rustls-pemfile = "^0.2"
num_cpus = "^1"
futures-util = { version = "^0", default-features = false, features = ["async-await", "sink", "std", "io"] } futures-util = { version = "^0", default-features = false, features = ["async-await", "sink", "std", "io"] }
keyvaluedb-sqlite = { path = "../external/keyvaluedb/keyvaluedb-sqlite" } keyvaluedb-sqlite = { path = "../external/keyvaluedb/keyvaluedb-sqlite" }
data-encoding = { version = "^2" } data-encoding = { version = "^2" }

View File

@ -1,5 +1,5 @@
use crate::callback_state_machine::*; use crate::callback_state_machine::*;
use crate::dht::crypto::Crypto; use crate::dht::Crypto;
use crate::intf::*; use crate::intf::*;
use crate::network_manager::*; use crate::network_manager::*;
use crate::routing_table::*; use crate::routing_table::*;

View File

@ -1,6 +1,6 @@
use crate::api_logger::*; use crate::api_logger::*;
use crate::attachment_manager::*; use crate::attachment_manager::*;
use crate::dht::crypto::Crypto; use crate::dht::Crypto;
use crate::intf::*; use crate::intf::*;
use crate::veilid_api::*; use crate::veilid_api::*;
use crate::veilid_config::*; use crate::veilid_config::*;

View File

@ -1,8 +1,10 @@
pub mod crypto; mod crypto;
pub mod envelope; mod envelope;
pub mod key; mod key;
pub mod receipt; mod receipt;
pub mod value; mod value;
pub mod tests;
pub use crypto::*; pub use crypto::*;
pub use envelope::*; pub use envelope::*;

View File

@ -0,0 +1,5 @@
pub mod test_crypto;
pub mod test_dht_key;
pub mod test_envelope_receipt;
use super::*;

View File

@ -1,6 +1,5 @@
use super::test_veilid_config::*; use super::*;
use crate::dht::crypto::*; use crate::tests::common::test_veilid_config::*;
use crate::dht::key;
use crate::xx::*; use crate::xx::*;
use crate::*; use crate::*;

View File

@ -1,6 +1,6 @@
#![allow(clippy::bool_assert_comparison)] #![allow(clippy::bool_assert_comparison)]
use crate::dht::key; use super::*;
use crate::xx::*; use crate::xx::*;
use core::convert::TryFrom; use core::convert::TryFrom;

View File

@ -1,8 +1,5 @@
use super::test_veilid_config::*; use super::*;
use crate::dht::crypto::*; use crate::tests::common::test_veilid_config::*;
use crate::dht::envelope::*;
use crate::dht::key::*;
use crate::dht::receipt::*;
use crate::xx::*; use crate::xx::*;
use crate::*; use crate::*;

View File

@ -1,12 +1,10 @@
mod block_store; mod block_store;
mod network;
mod protected_store; mod protected_store;
mod system; mod system;
mod table_store; mod table_store;
pub mod utils; pub mod utils;
pub use block_store::*; pub use block_store::*;
pub use network::*;
pub use protected_store::*; pub use protected_store::*;
pub use system::*; pub use system::*;
pub use table_store::*; pub use table_store::*;

View File

@ -94,7 +94,12 @@ where
} }
pub fn get_concurrency() -> u32 { pub fn get_concurrency() -> u32 {
num_cpus::get() as u32 std::thread::available_parallelism()
.map(|x| x.get())
.unwrap_or_else(|e| {
warn!("unable to get concurrency defaulting to single core: {}", e);
1
}) as u32
} }
pub async fn get_outbound_relay_peer() -> Option<crate::veilid_api::PeerInfo> { pub async fn get_outbound_relay_peer() -> Option<crate::veilid_api::PeerInfo> {

View File

@ -1,12 +1,11 @@
mod block_store; mod block_store;
mod network;
mod protected_store; mod protected_store;
mod system; mod system;
mod table_store; mod table_store;
pub mod utils; pub mod utils;
pub use block_store::*; pub use block_store::*;
pub use network::*;
pub use protected_store::*; pub use protected_store::*;
pub use system::*; pub use system::*;
pub use table_store::*; pub use table_store::*;

View File

@ -7,13 +7,9 @@ extern crate alloc;
mod api_logger; mod api_logger;
mod attachment_manager; mod attachment_manager;
mod callback_state_machine; mod callback_state_machine;
mod connection_limits;
mod connection_manager;
mod connection_table;
mod core_context; mod core_context;
mod dht; mod dht;
mod intf; mod intf;
mod network_connection;
mod network_manager; mod network_manager;
mod receipt_manager; mod receipt_manager;
mod routing_table; mod routing_table;

View File

@ -130,6 +130,7 @@ impl ConnectionLimits {
let cnt = &mut *self.conn_count_by_ip4.entry(v4).or_default(); let cnt = &mut *self.conn_count_by_ip4.entry(v4).or_default();
assert!(*cnt <= self.max_connections_per_ip4); assert!(*cnt <= self.max_connections_per_ip4);
if *cnt == self.max_connections_per_ip4 { if *cnt == self.max_connections_per_ip4 {
warn!("address filter count exceeded: {:?}", v4);
return Err(AddressFilterError::CountExceeded); return Err(AddressFilterError::CountExceeded);
} }
// See if this ip block has connected too frequently // See if this ip block has connected too frequently
@ -140,6 +141,7 @@ impl ConnectionLimits {
}); });
assert!(tstamps.len() <= self.max_connection_frequency_per_min); assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.max_connection_frequency_per_min { if tstamps.len() == self.max_connection_frequency_per_min {
warn!("address filter rate exceeded: {:?}", v4);
return Err(AddressFilterError::RateExceeded); return Err(AddressFilterError::RateExceeded);
} }
@ -152,12 +154,14 @@ impl ConnectionLimits {
let cnt = &mut *self.conn_count_by_ip6_prefix.entry(v6).or_default(); let cnt = &mut *self.conn_count_by_ip6_prefix.entry(v6).or_default();
assert!(*cnt <= self.max_connections_per_ip6_prefix); assert!(*cnt <= self.max_connections_per_ip6_prefix);
if *cnt == self.max_connections_per_ip6_prefix { if *cnt == self.max_connections_per_ip6_prefix {
warn!("address filter count exceeded: {:?}", v6);
return Err(AddressFilterError::CountExceeded); return Err(AddressFilterError::CountExceeded);
} }
// See if this ip block has connected too frequently // See if this ip block has connected too frequently
let tstamps = &mut self.conn_timestamps_by_ip6_prefix.entry(v6).or_default(); let tstamps = &mut self.conn_timestamps_by_ip6_prefix.entry(v6).or_default();
assert!(tstamps.len() <= self.max_connection_frequency_per_min); assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.max_connection_frequency_per_min { if tstamps.len() == self.max_connection_frequency_per_min {
warn!("address filter rate exceeded: {:?}", v6);
return Err(AddressFilterError::RateExceeded); return Err(AddressFilterError::RateExceeded);
} }

View File

@ -1,29 +1,16 @@
use crate::connection_table::*; use super::*;
use crate::intf::*;
use crate::network_connection::*;
use crate::network_manager::*;
use crate::xx::*; use crate::xx::*;
use crate::*; use connection_table::*;
use futures_util::stream::{FuturesUnordered, StreamExt}; use network_connection::*;
use futures_util::{select, FutureExt};
const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize; const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize;
/////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////
// Connection manager // Connection manager
#[derive(Debug)]
struct ConnectionManagerInner { struct ConnectionManagerInner {
connection_table: ConnectionTable, connection_table: ConnectionTable,
connection_processor_jh: Option<JoinHandle<()>>,
connection_add_channel_tx: Option<flume::Sender<SystemPinBoxFuture<()>>>,
}
impl core::fmt::Debug for ConnectionManagerInner {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("ConnectionManagerInner")
.field("connection_table", &self.connection_table)
.finish()
}
} }
struct ConnectionManagerArc { struct ConnectionManagerArc {
@ -47,8 +34,6 @@ impl ConnectionManager {
fn new_inner(config: VeilidConfig) -> ConnectionManagerInner { fn new_inner(config: VeilidConfig) -> ConnectionManagerInner {
ConnectionManagerInner { ConnectionManagerInner {
connection_table: ConnectionTable::new(config), connection_table: ConnectionTable::new(config),
connection_processor_jh: None,
connection_add_channel_tx: None,
} }
} }
fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc { fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc {
@ -70,15 +55,12 @@ impl ConnectionManager {
pub async fn startup(&self) { pub async fn startup(&self) {
trace!("startup connection manager"); trace!("startup connection manager");
let mut inner = self.arc.inner.lock().await; //let mut inner = self.arc.inner.lock().await;
let cac = flume::bounded(CONNECTION_PROCESSOR_CHANNEL_SIZE);
inner.connection_add_channel_tx = Some(cac.0);
let rx = cac.1.clone();
let this = self.clone();
inner.connection_processor_jh = Some(spawn(this.connection_processor(rx)));
} }
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
// xxx close all connections in the connection table
*self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config()); *self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config());
} }
@ -127,6 +109,12 @@ impl ConnectionManager {
local_addr: Option<SocketAddr>, local_addr: Option<SocketAddr>,
dial_info: DialInfo, dial_info: DialInfo,
) -> Result<NetworkConnection, String> { ) -> Result<NetworkConnection, String> {
log_net!(
"== get_or_create_connection local_addr={:?} dial_info={:?}",
local_addr.green(),
dial_info.green()
);
let peer_address = dial_info.to_peer_address(); let peer_address = dial_info.to_peer_address();
let descriptor = match local_addr { let descriptor = match local_addr {
Some(la) => { Some(la) => {
@ -143,10 +131,46 @@ impl ConnectionManager {
.connection_table .connection_table
.get_last_connection_by_remote(descriptor.remote) .get_last_connection_by_remote(descriptor.remote)
{ {
log_net!(
"== Returning existing connection local_addr={:?} peer_address={:?}",
local_addr.green(),
peer_address.green()
);
return Ok(conn); return Ok(conn);
} }
// If not, attempt new connection // Drop any other protocols connections that have the same local addr
// otherwise this connection won't succeed due to binding
if let Some(local_addr) = local_addr {
if local_addr.port() != 0 {
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
let pa = PeerAddress::new(descriptor.remote.socket_address, pt);
for conn in inner.connection_table.get_connections_by_remote(pa) {
let desc = conn.connection_descriptor();
let mut kill = false;
if let Some(conn_local) = desc.local {
if (local_addr.ip().is_unspecified()
|| (local_addr.ip() == conn_local.to_ip_addr()))
&& conn_local.port() == local_addr.port()
{
kill = true;
}
}
if kill {
log_net!(debug
">< Terminating connection local_addr={:?} peer_address={:?}",
local_addr.green(),
pa.green()
);
conn.close().await?;
}
}
}
}
}
// Attempt new connection
let conn = NetworkConnection::connect(local_addr, dial_info).await?; let conn = NetworkConnection::connect(local_addr, dial_info).await?;
self.on_new_connection_internal(&mut *inner, conn.clone())?; self.on_new_connection_internal(&mut *inner, conn.clone())?;
@ -159,7 +183,7 @@ impl ConnectionManager {
this: ConnectionManager, this: ConnectionManager,
conn: NetworkConnection, conn: NetworkConnection,
) -> SystemPinBoxFuture<()> { ) -> SystemPinBoxFuture<()> {
log_net!("Starting process_connection loop for {:?}", conn); log_net!("Starting process_connection loop for {:?}", conn.green());
let network_manager = this.network_manager(); let network_manager = this.network_manager();
Box::pin(async move { Box::pin(async move {
// //
@ -185,7 +209,7 @@ impl ConnectionManager {
} }
_ = intf::sleep(inactivity_timeout).fuse()=> { _ = intf::sleep(inactivity_timeout).fuse()=> {
// timeout // timeout
log_net!("connection timeout on {:?}", descriptor); log_net!("connection timeout on {:?}", descriptor.green());
break; break;
} }
}; };
@ -198,6 +222,12 @@ impl ConnectionManager {
} }
} }
log_net!(
"== Connection loop finished local_addr={:?} remote={:?}",
descriptor.local.green(),
descriptor.remote.green()
);
if let Err(e) = this if let Err(e) = this
.arc .arc
.inner .inner
@ -210,49 +240,4 @@ impl ConnectionManager {
} }
}) })
} }
// Process connection oriented sockets in the background
// This never terminates and must have its task cancelled once started
// Task cancellation is performed by shutdown() by dropping the join handle
async fn connection_processor(self, rx: flume::Receiver<SystemPinBoxFuture<()>>) {
let mut connection_futures: FuturesUnordered<SystemPinBoxFuture<()>> =
FuturesUnordered::new();
loop {
// Either process an existing connection, or receive a new one to add to our list
select! {
x = connection_futures.next().fuse() => {
// Processed some connection to completion, or there are none left
match x {
Some(()) => {
// Processed some connection to completion
}
None => {
// No connections to process, wait for one
match rx.recv_async().await {
Ok(v) => {
connection_futures.push(v);
}
Err(e) => {
log_net!(error "connection processor error: {:?}", e);
// xxx: do something here?? should the network be restarted if this happens?
}
};
}
}
}
x = rx.recv_async().fuse() => {
// Got a new connection future
match x {
Ok(v) => {
connection_futures.push(v);
}
Err(e) => {
log_net!(error "connection processor error: {:?}", e);
// xxx: do something here?? should the network be restarted if this happens?
}
};
}
}
}
}
} }

View File

@ -1,5 +1,5 @@
use crate::connection_limits::*; use super::connection_limits::*;
use crate::network_connection::*; use super::network_connection::*;
use crate::xx::*; use crate::xx::*;
use crate::*; use crate::*;
use alloc::collections::btree_map::Entry; use alloc::collections::btree_map::Entry;
@ -67,6 +67,7 @@ impl ConnectionTable {
// then drop the least recently used connection // then drop the least recently used connection
if self.conn_by_descriptor[index].len() > self.max_connections[index] { if self.conn_by_descriptor[index].len() > self.max_connections[index] {
if let Some((lruk, _)) = self.conn_by_descriptor[index].remove_lru() { if let Some((lruk, _)) = self.conn_by_descriptor[index].remove_lru() {
warn!("XX: connection lru out: {:?}", lruk);
self.remove_connection_records(lruk); self.remove_connection_records(lruk);
} }
} }
@ -74,7 +75,7 @@ impl ConnectionTable {
// add connection records // add connection records
let conns = self.conns_by_remote.entry(descriptor.remote).or_default(); let conns = self.conns_by_remote.entry(descriptor.remote).or_default();
//warn!("add_connection: {:?}", conn); warn!("add_connection: {:?}", conn);
conns.push(conn); conns.push(conn);
Ok(()) Ok(())
@ -86,7 +87,7 @@ impl ConnectionTable {
) -> Option<NetworkConnection> { ) -> Option<NetworkConnection> {
let index = protocol_to_index(descriptor.protocol_type()); let index = protocol_to_index(descriptor.protocol_type());
let out = self.conn_by_descriptor[index].get(&descriptor).cloned(); let out = self.conn_by_descriptor[index].get(&descriptor).cloned();
//warn!("get_connection: {:?} -> {:?}", descriptor, out); warn!("get_connection: {:?} -> {:?}", descriptor, out);
out out
} }
@ -98,7 +99,7 @@ impl ConnectionTable {
.conns_by_remote .conns_by_remote
.get(&remote) .get(&remote)
.map(|v| v[(v.len() - 1)].clone()); .map(|v| v[(v.len() - 1)].clone());
//warn!("get_last_connection_by_remote: {:?} -> {:?}", remote, out); warn!("get_last_connection_by_remote: {:?} -> {:?}", remote, out);
if let Some(connection) = &out { if let Some(connection) = &out {
// lru bump // lru bump
let index = protocol_to_index(connection.connection_descriptor().protocol_type()); let index = protocol_to_index(connection.connection_descriptor().protocol_type());
@ -107,6 +108,16 @@ impl ConnectionTable {
out out
} }
pub fn get_connections_by_remote(&mut self, remote: PeerAddress) -> Vec<NetworkConnection> {
let out = self
.conns_by_remote
.get(&remote)
.cloned()
.unwrap_or_default();
warn!("get_connections_by_remote: {:?} -> {:?}", remote, out);
out
}
pub fn connection_count(&self) -> usize { pub fn connection_count(&self) -> usize {
self.conn_by_descriptor.iter().fold(0, |b, c| b + c.len()) self.conn_by_descriptor.iter().fold(0, |b, c| b + c.len())
} }
@ -144,7 +155,7 @@ impl ConnectionTable {
&mut self, &mut self,
descriptor: ConnectionDescriptor, descriptor: ConnectionDescriptor,
) -> Result<NetworkConnection, String> { ) -> Result<NetworkConnection, String> {
//warn!("remove_connection: {:?}", descriptor); warn!("remove_connection: {:?}", descriptor);
let index = protocol_to_index(descriptor.protocol_type()); let index = protocol_to_index(descriptor.protocol_type());
let out = self.conn_by_descriptor[index] let out = self.conn_by_descriptor[index]
.remove(&descriptor) .remove(&descriptor)

View File

@ -1,11 +1,34 @@
use crate::*; use crate::*;
#[cfg(not(target_arch = "wasm32"))]
mod native;
#[cfg(target_arch = "wasm32")]
mod wasm;
mod connection_limits;
mod connection_manager;
mod connection_table;
mod network_connection;
pub mod tests;
////////////////////////////////////////////////////////////////////////////////////////
pub use network_connection::*;
////////////////////////////////////////////////////////////////////////////////////////
use connection_manager::*; use connection_manager::*;
use dht::*; use dht::*;
use hashlink::LruCache; use hashlink::LruCache;
use intf::*; use intf::*;
#[cfg(not(target_arch = "wasm32"))]
use native::*;
use receipt_manager::*; use receipt_manager::*;
use routing_table::*; use routing_table::*;
use rpc_processor::*; use rpc_processor::*;
#[cfg(target_arch = "wasm32")]
use wasm::*;
use xx::*; use xx::*;
//////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////
@ -93,9 +116,9 @@ struct NetworkManagerInner {
components: Option<NetworkComponents>, components: Option<NetworkComponents>,
update_callback: Option<UpdateCallback>, update_callback: Option<UpdateCallback>,
stats: NetworkManagerStats, stats: NetworkManagerStats,
client_whitelist: LruCache<key::DHTKey, ClientWhitelistEntry>, client_whitelist: LruCache<DHTKey, ClientWhitelistEntry>,
relay_node: Option<NodeRef>, relay_node: Option<NodeRef>,
public_address_check_cache: LruCache<key::DHTKey, SocketAddress>, public_address_check_cache: LruCache<DHTKey, SocketAddress>,
} }
struct NetworkManagerUnlockedInner { struct NetworkManagerUnlockedInner {
@ -300,7 +323,7 @@ impl NetworkManager {
trace!("NetworkManager::shutdown end"); trace!("NetworkManager::shutdown end");
} }
pub fn update_client_whitelist(&self, client: key::DHTKey) { pub fn update_client_whitelist(&self, client: DHTKey) {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
match inner.client_whitelist.entry(client) { match inner.client_whitelist.entry(client) {
hashlink::lru_cache::Entry::Occupied(mut entry) => { hashlink::lru_cache::Entry::Occupied(mut entry) => {
@ -314,7 +337,7 @@ impl NetworkManager {
} }
} }
pub fn check_client_whitelist(&self, client: key::DHTKey) -> bool { pub fn check_client_whitelist(&self, client: DHTKey) -> bool {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
match inner.client_whitelist.entry(client) { match inner.client_whitelist.entry(client) {
@ -565,7 +588,7 @@ impl NetworkManager {
// Builds an envelope for sending over the network // Builds an envelope for sending over the network
fn build_envelope<B: AsRef<[u8]>>( fn build_envelope<B: AsRef<[u8]>>(
&self, &self,
dest_node_id: key::DHTKey, dest_node_id: DHTKey,
version: u8, version: u8,
body: B, body: B,
) -> Result<Vec<u8>, String> { ) -> Result<Vec<u8>, String> {

View File

@ -4,11 +4,10 @@ mod network_udp;
mod protocol; mod protocol;
mod start_protocols; mod start_protocols;
use crate::connection_manager::*;
use crate::intf::*; use crate::intf::*;
use crate::network_manager::*; use crate::network_manager::*;
use crate::routing_table::*; use crate::routing_table::*;
use crate::*; use connection_manager::*;
use network_tcp::*; use network_tcp::*;
use protocol::tcp::RawTcpProtocolHandler; use protocol::tcp::RawTcpProtocolHandler;
use protocol::udp::RawUdpProtocolHandler; use protocol::udp::RawUdpProtocolHandler;

View File

@ -1,9 +1,4 @@
use super::*; use super::*;
use crate::intf::*;
use crate::routing_table::*;
use crate::*;
use futures_util::stream::FuturesUnordered; use futures_util::stream::FuturesUnordered;
use futures_util::FutureExt; use futures_util::FutureExt;
@ -215,7 +210,7 @@ impl DiscoveryContext {
{ {
None => { None => {
// If we can't get an external address, exit but don't throw an error so we can try again later // If we can't get an external address, exit but don't throw an error so we can try again later
log_net!(debug "couldn't get external address 1"); log_net!(debug "couldn't get external address 1 for {:?} {:?}", protocol_type, address_type);
return false; return false;
} }
Some(v) => v, Some(v) => v,
@ -463,6 +458,7 @@ impl Network {
let protocol_config = self.inner.lock().protocol_config.unwrap_or_default(); let protocol_config = self.inner.lock().protocol_config.unwrap_or_default();
let mut unord = FuturesUnordered::new(); let mut unord = FuturesUnordered::new();
// Do UDPv4+v6 at the same time as everything else
if protocol_config.inbound.contains(ProtocolType::UDP) { if protocol_config.inbound.contains(ProtocolType::UDP) {
// UDPv4 // UDPv4
unord.push( unord.push(
@ -475,101 +471,102 @@ impl Network {
log_net!(debug "Failed UDPv4 dialinfo discovery: {}", e); log_net!(debug "Failed UDPv4 dialinfo discovery: {}", e);
return None; return None;
} }
Some(udpv4_context) Some(vec![udpv4_context])
} }
.boxed(), .boxed(),
); );
// // UDPv6 // UDPv6
// unord.push( unord.push(
// async { async {
// let udpv6_context = DiscoveryContext::new(self.routing_table(), self.clone()); let udpv6_context = DiscoveryContext::new(self.routing_table(), self.clone());
// if let Err(e) = self if let Err(e) = self
// .update_ipv6_protocol_dialinfo(&udpv6_context, ProtocolType::UDP) .update_ipv6_protocol_dialinfo(&udpv6_context, ProtocolType::UDP)
// .await .await
// { {
// log_net!(debug "Failed UDPv6 dialinfo discovery: {}", e); log_net!(debug "Failed UDPv6 dialinfo discovery: {}", e);
// return None; return None;
// } }
// Some(udpv6_context) Some(vec![udpv6_context])
// } }
// .boxed(), .boxed(),
// ); );
} }
// if protocol_config.inbound.contains(ProtocolType::TCP) { // Do TCPv4 + WSv4 in series because they may use the same connection 5-tuple
// // TCPv4 unord.push(
// unord.push( async {
// async { // TCPv4
// let tcpv4_context = DiscoveryContext::new(self.routing_table(), self.clone()); let mut out = Vec::<DiscoveryContext>::new();
// if let Err(e) = self if protocol_config.inbound.contains(ProtocolType::TCP) {
// .update_ipv4_protocol_dialinfo(&tcpv4_context, ProtocolType::TCP) let tcpv4_context = DiscoveryContext::new(self.routing_table(), self.clone());
// .await if let Err(e) = self
// { .update_ipv4_protocol_dialinfo(&tcpv4_context, ProtocolType::TCP)
// log_net!(debug "Failed TCPv4 dialinfo discovery: {}", e); .await
// return None; {
// } log_net!(debug "Failed TCPv4 dialinfo discovery: {}", e);
// Some(tcpv4_context) return None;
// } }
// .boxed(), out.push(tcpv4_context);
// ); }
// // TCPv6 // WSv4
// unord.push( if protocol_config.inbound.contains(ProtocolType::WS) {
// async { let wsv4_context = DiscoveryContext::new(self.routing_table(), self.clone());
// let tcpv6_context = DiscoveryContext::new(self.routing_table(), self.clone()); if let Err(e) = self
// if let Err(e) = self .update_ipv4_protocol_dialinfo(&wsv4_context, ProtocolType::WS)
// .update_ipv6_protocol_dialinfo(&tcpv6_context, ProtocolType::TCP) .await
// .await {
// { log_net!(debug "Failed WSv4 dialinfo discovery: {}", e);
// log_net!(debug "Failed TCPv6 dialinfo discovery: {}", e); return None;
// return None; }
// } out.push(wsv4_context);
// Some(tcpv6_context) }
// } Some(out)
// .boxed(), }
// ); .boxed(),
// } );
// if protocol_config.inbound.contains(ProtocolType::WS) { // Do TCPv6 + WSv6 in series because they may use the same connection 5-tuple
// // WS4 unord.push(
// unord.push( async {
// async { // TCPv6
// let wsv4_context = DiscoveryContext::new(self.routing_table(), self.clone()); let mut out = Vec::<DiscoveryContext>::new();
// if let Err(e) = self if protocol_config.inbound.contains(ProtocolType::TCP) {
// .update_ipv4_protocol_dialinfo(&wsv4_context, ProtocolType::WS) let tcpv6_context = DiscoveryContext::new(self.routing_table(), self.clone());
// .await if let Err(e) = self
// { .update_ipv6_protocol_dialinfo(&tcpv6_context, ProtocolType::TCP)
// log_net!(debug "Failed WSv4 dialinfo discovery: {}", e); .await
// return None; {
// } log_net!(debug "Failed TCPv6 dialinfo discovery: {}", e);
// Some(wsv4_context) return None;
// } }
// .boxed(), out.push(tcpv6_context);
// ); }
// // WSv6 // WSv6
// unord.push( if protocol_config.inbound.contains(ProtocolType::WS) {
// async { let wsv6_context = DiscoveryContext::new(self.routing_table(), self.clone());
// let wsv6_context = DiscoveryContext::new(self.routing_table(), self.clone()); if let Err(e) = self
// if let Err(e) = self .update_ipv6_protocol_dialinfo(&wsv6_context, ProtocolType::WS)
// .update_ipv6_protocol_dialinfo(&wsv6_context, ProtocolType::TCP) .await
// .await {
// { log_net!(debug "Failed WSv6 dialinfo discovery: {}", e);
// log_net!(debug "Failed WSv6 dialinfo discovery: {}", e); return None;
// return None; }
// } out.push(wsv6_context);
// Some(wsv6_context) }
// } Some(out)
// .boxed(), }
// ); .boxed(),
// } );
// Wait for all discovery futures to complete and collect contexts // Wait for all discovery futures to complete and collect contexts
let mut contexts = Vec::<DiscoveryContext>::new(); let mut contexts = Vec::<DiscoveryContext>::new();
let mut network_class = Option::<NetworkClass>::None; let mut network_class = Option::<NetworkClass>::None;
while let Some(ctx) = unord.next().await { while let Some(ctxvec) = unord.next().await {
if let Some(ctx) = ctx { if let Some(ctxvec) = ctxvec {
for ctx in ctxvec {
if let Some(nc) = ctx.inner.lock().detected_network_class { if let Some(nc) = ctx.inner.lock().detected_network_class {
if let Some(last_nc) = network_class { if let Some(last_nc) = network_class {
if nc < last_nc { if nc < last_nc {
@ -583,6 +580,7 @@ impl Network {
contexts.push(ctx); contexts.push(ctx);
} }
} }
}
// Get best network class // Get best network class
if network_class.is_some() { if network_class.is_some() {

View File

@ -1,8 +1,7 @@
use super::sockets::*;
use super::*; use super::*;
use crate::intf::*; use crate::intf::*;
use crate::network_connection::*;
use async_tls::TlsAcceptor; use async_tls::TlsAcceptor;
use sockets::*;
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
@ -43,6 +42,7 @@ impl Network {
&self, &self,
tls_acceptor: &TlsAcceptor, tls_acceptor: &TlsAcceptor,
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream,
addr: SocketAddr, addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>], protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
tls_connection_initial_timeout: u64, tls_connection_initial_timeout: u64,
@ -67,18 +67,20 @@ impl Network {
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!())?; .map_err(logthru_net!())?;
self.try_handlers(ps, addr, protocol_handlers).await self.try_handlers(ps, tcp_stream, addr, protocol_handlers)
.await
} }
async fn try_handlers( async fn try_handlers(
&self, &self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream,
addr: SocketAddr, addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>], protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
) -> Result<Option<NetworkConnection>, String> { ) -> Result<Option<NetworkConnection>, String> {
for ah in protocol_handlers.iter() { for ah in protocol_handlers.iter() {
if let Some(nc) = ah if let Some(nc) = ah
.on_accept(stream.clone(), addr) .on_accept(stream.clone(), tcp_stream.clone(), addr)
.await .await
.map_err(logthru_net!())? .map_err(logthru_net!())?
{ {
@ -148,7 +150,7 @@ impl Network {
log_net!("TCP connection from: {}", addr); log_net!("TCP connection from: {}", addr);
// Create a stream we can peek on // Create a stream we can peek on
let ps = AsyncPeekStream::new(tcp_stream); let ps = AsyncPeekStream::new(tcp_stream.clone());
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
let mut first_packet = [0u8; PEEK_DETECT_LEN]; let mut first_packet = [0u8; PEEK_DETECT_LEN];
@ -176,13 +178,15 @@ impl Network {
this.try_tls_handlers( this.try_tls_handlers(
ls.tls_acceptor.as_ref().unwrap(), ls.tls_acceptor.as_ref().unwrap(),
ps, ps,
tcp_stream,
addr, addr,
&ls.tls_protocol_handlers, &ls.tls_protocol_handlers,
tls_connection_initial_timeout, tls_connection_initial_timeout,
) )
.await .await
} else { } else {
this.try_handlers(ps, addr, &ls.protocol_handlers).await this.try_handlers(ps, tcp_stream, addr, &ls.protocol_handlers)
.await
}; };
let conn = match conn { let conn = match conn {

View File

@ -1,6 +1,5 @@
use super::sockets::*;
use super::*; use super::*;
use futures_util::stream; use sockets::*;
impl Network { impl Network {
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> { pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
@ -42,7 +41,7 @@ impl Network {
} }
// Spawn a local async task for each socket // Spawn a local async task for each socket
let mut protocol_handlers_unordered = stream::FuturesUnordered::new(); let mut protocol_handlers_unordered = FuturesUnordered::new();
let network_manager = this.network_manager(); let network_manager = this.network_manager();
for ph in protocol_handlers { for ph in protocol_handlers {

View File

@ -4,9 +4,8 @@ pub mod udp;
pub mod wrtc; pub mod wrtc;
pub mod ws; pub mod ws;
use crate::network_connection::*; use super::*;
use crate::xx::*; use crate::xx::*;
use crate::*;
#[derive(Debug)] #[derive(Debug)]
pub enum ProtocolNetworkConnection { pub enum ProtocolNetworkConnection {

View File

@ -1,6 +1,8 @@
use crate::xx::*; use crate::xx::*;
use crate::*; use crate::*;
use socket2::{Domain, Protocol, Socket, Type}; use async_io::Async;
use async_std::net::TcpStream;
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
cfg_if! { cfg_if! {
if #[cfg(windows)] { if #[cfg(windows)] {
@ -44,7 +46,7 @@ pub fn new_unbound_shared_udp_socket(domain: Domain) -> Result<Socket, String> {
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> Result<Socket, String> { pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
let domain = Domain::for_address(local_address); let domain = Domain::for_address(local_address);
let socket = new_unbound_shared_udp_socket(domain)?; let socket = new_unbound_shared_udp_socket(domain)?;
let socket2_addr = socket2::SockAddr::from(local_address); let socket2_addr = SockAddr::from(local_address);
socket.bind(&socket2_addr).map_err(|e| { socket.bind(&socket2_addr).map_err(|e| {
format!( format!(
"failed to bind UDP socket to '{}' in domain '{:?}': {} ", "failed to bind UDP socket to '{}' in domain '{:?}': {} ",
@ -68,7 +70,7 @@ pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, S
} }
// Bind the socket -first- before turning on 'reuse address' this way it will // Bind the socket -first- before turning on 'reuse address' this way it will
// fail if the port is already taken // fail if the port is already taken
let socket2_addr = socket2::SockAddr::from(local_address); let socket2_addr = SockAddr::from(local_address);
// On windows, do SO_EXCLUSIVEADDRUSE before the bind to ensure the port is fully available // On windows, do SO_EXCLUSIVEADDRUSE before the bind to ensure the port is fully available
cfg_if! { cfg_if! {
@ -128,7 +130,7 @@ pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> Result<Socket,
let socket = new_unbound_shared_tcp_socket(domain)?; let socket = new_unbound_shared_tcp_socket(domain)?;
let socket2_addr = socket2::SockAddr::from(local_address); let socket2_addr = SockAddr::from(local_address);
socket socket
.bind(&socket2_addr) .bind(&socket2_addr)
.map_err(|e| format!("failed to bind TCP socket: {}", e))?; .map_err(|e| format!("failed to bind TCP socket: {}", e))?;
@ -165,7 +167,7 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
// Bind the socket -first- before turning on 'reuse address' this way it will // Bind the socket -first- before turning on 'reuse address' this way it will
// fail if the port is already taken // fail if the port is already taken
let socket2_addr = socket2::SockAddr::from(local_address); let socket2_addr = SockAddr::from(local_address);
socket socket
.bind(&socket2_addr) .bind(&socket2_addr)
.map_err(|e| format!("failed to bind TCP socket: {}", e))?; .map_err(|e| format!("failed to bind TCP socket: {}", e))?;
@ -184,3 +186,37 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
Ok(socket) Ok(socket)
} }
// Non-blocking connect is tricky when you want to start with a prepared socket
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> std::io::Result<TcpStream> {
// Set for non blocking connect
socket.set_nonblocking(true)?;
// Make socket2 SockAddr
let socket2_addr = socket2::SockAddr::from(addr);
// Connect to the remote address
match socket.connect(&socket2_addr) {
Ok(()) => Ok(()),
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
Err(e) => Err(e),
}?;
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
// The stream becomes writable when connected
intf::timeout(2000, async_stream.writable())
.await
.map_err(|e| std::io::Error::new(std::io::ErrorKind::TimedOut, e))??;
// Check low level error
let async_stream = match async_stream.get_ref().take_error()? {
None => Ok(async_stream),
Some(err) => Err(err),
}?;
// Convert back to inner and then return async version
Ok(TcpStream::from(async_stream.into_inner()?))
}

View File

@ -1,14 +1,10 @@
use super::sockets::*;
use super::*; use super::*;
use crate::intf::*; use futures_util::{AsyncReadExt, AsyncWriteExt};
use crate::network_manager::MAX_MESSAGE_SIZE; use sockets::*;
use crate::*;
use async_std::net::TcpStream;
use core::fmt;
use futures_util::io::{AsyncReadExt, AsyncWriteExt};
pub struct RawTcpNetworkConnection { pub struct RawTcpNetworkConnection {
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream,
} }
impl fmt::Debug for RawTcpNetworkConnection { impl fmt::Debug for RawTcpNetworkConnection {
@ -18,16 +14,22 @@ impl fmt::Debug for RawTcpNetworkConnection {
} }
impl RawTcpNetworkConnection { impl RawTcpNetworkConnection {
pub fn new(stream: AsyncPeekStream) -> Self { pub fn new(stream: AsyncPeekStream, tcp_stream: TcpStream) -> Self {
Self { stream } Self { stream, tcp_stream }
} }
pub async fn close(&self) -> Result<(), String> { pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream
self.stream self.stream
.clone() .clone()
.close() .close()
.await .await
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!())?;
// Then forcibly close the socket
self.tcp_stream
.shutdown(Shutdown::Both)
.map_err(map_to_string)
.map_err(logthru_net!()) .map_err(logthru_net!())
} }
@ -40,7 +42,6 @@ impl RawTcpNetworkConnection {
let header = [b'V', b'L', len as u8, (len >> 8) as u8]; let header = [b'V', b'L', len as u8, (len >> 8) as u8];
let mut stream = self.stream.clone(); let mut stream = self.stream.clone();
stream stream
.write_all(&header) .write_all(&header)
.await .await
@ -105,6 +106,7 @@ impl RawTcpProtocolHandler {
async fn on_accept_async( async fn on_accept_async(
self, self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream,
socket_addr: SocketAddr, socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> { ) -> Result<Option<NetworkConnection>, String> {
log_net!("TCP: on_accept_async: enter"); log_net!("TCP: on_accept_async: enter");
@ -123,7 +125,7 @@ impl RawTcpProtocolHandler {
let local_address = self.inner.lock().local_address; let local_address = self.inner.lock().local_address;
let conn = NetworkConnection::from_protocol( let conn = NetworkConnection::from_protocol(
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)), ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream)), ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream, tcp_stream)),
); );
log_net!(debug "TCP: on_accept_async from: {}", socket_addr); log_net!(debug "TCP: on_accept_async from: {}", socket_addr);
@ -146,22 +148,17 @@ impl RawTcpProtocolHandler {
} }
}; };
// Connect to the remote address // Non-blocking connect to remote address
let remote_socket2_addr = socket2::SockAddr::from(remote_socket_addr); let ts = nonblocking_connect(socket, remote_socket_addr).await
socket
.connect(&remote_socket2_addr)
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?; .map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
let std_stream: std::net::TcpStream = socket.into();
let ts = TcpStream::from(std_stream);
// See what local address we ended up with and turn this into a stream // See what local address we ended up with and turn this into a stream
let actual_local_address = ts let actual_local_address = ts
.local_addr() .local_addr()
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!("could not get local address from TCP stream"))?; .map_err(logthru_net!("could not get local address from TCP stream"))?;
let ps = AsyncPeekStream::new(ts); let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it // Wrap the stream in a network connection and return it
let conn = NetworkConnection::from_protocol( let conn = NetworkConnection::from_protocol(
@ -169,7 +166,7 @@ impl RawTcpProtocolHandler {
local: Some(SocketAddress::from_socket_addr(actual_local_address)), local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: dial_info.to_peer_address(), remote: dial_info.to_peer_address(),
}, },
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps)), ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
); );
Ok(conn) Ok(conn)
} }
@ -187,10 +184,34 @@ impl RawTcpProtocolHandler {
socket_addr socket_addr
); );
let mut stream = TcpStream::connect(socket_addr) // Make a shared socket
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
// Non-blocking connect to remote address
let ts = nonblocking_connect(socket, socket_addr)
.await .await
.map_err(|e| format!("failed to connect TCP for unbound message: {}", e))?; .map_err(map_to_string)
stream.write_all(&data).await.map_err(|e| format!("{}", e)) .map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
// See what local address we ended up with and turn this into a stream
let actual_local_address = ts
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!("could not get local address from TCP stream"))?;
let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it
let conn = NetworkConnection::from_protocol(
ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
ProtocolType::TCP,
),
},
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
);
conn.send(data).await
} }
} }
@ -198,8 +219,9 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
fn on_accept( fn on_accept(
&self, &self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr, peer_addr: SocketAddr,
) -> SystemPinBoxFuture<core::result::Result<Option<NetworkConnection>, String>> { ) -> SystemPinBoxFuture<core::result::Result<Option<NetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, peer_addr)) Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
} }
} }

View File

@ -1,7 +1,4 @@
use crate::intf::*; use super::*;
use crate::network_manager::MAX_MESSAGE_SIZE;
use crate::*;
use async_std::net::*;
#[derive(Clone)] #[derive(Clone)]
pub struct RawUdpProtocolHandler { pub struct RawUdpProtocolHandler {

View File

@ -1,29 +1,22 @@
use super::sockets::*;
use super::*; use super::*;
use crate::intf::*;
use crate::network_manager::MAX_MESSAGE_SIZE;
use crate::*;
use alloc::sync::Arc;
use async_std::io; use async_std::io;
use async_std::net::*;
use async_tls::TlsConnector; use async_tls::TlsConnector;
use async_tungstenite::tungstenite::protocol::Message; use async_tungstenite::tungstenite::protocol::Message;
use async_tungstenite::{accept_async, client_async, WebSocketStream}; use async_tungstenite::{accept_async, client_async, WebSocketStream};
use core::fmt; use futures_util::SinkExt;
use core::time::Duration; use sockets::*;
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
pub type WebSocketNetworkConnectionAccepted = WebsocketNetworkConnection<AsyncPeekStream>; pub type WebSocketNetworkConnectionAccepted = WebsocketNetworkConnection<AsyncPeekStream>;
pub type WebsocketNetworkConnectionWSS = pub type WebsocketNetworkConnectionWSS =
WebsocketNetworkConnection<async_tls::client::TlsStream<async_std::net::TcpStream>>; WebsocketNetworkConnection<async_tls::client::TlsStream<TcpStream>>;
pub type WebsocketNetworkConnectionWS = WebsocketNetworkConnection<async_std::net::TcpStream>; pub type WebsocketNetworkConnectionWS = WebsocketNetworkConnection<TcpStream>;
pub struct WebsocketNetworkConnection<T> pub struct WebsocketNetworkConnection<T>
where where
T: io::Read + io::Write + Send + Unpin + 'static, T: io::Read + io::Write + Send + Unpin + 'static,
{ {
ws_stream: CloneStream<WebSocketStream<T>>, stream: CloneStream<WebSocketStream<T>>,
tcp_stream: TcpStream,
} }
impl<T> fmt::Debug for WebsocketNetworkConnection<T> impl<T> fmt::Debug for WebsocketNetworkConnection<T>
@ -39,21 +32,33 @@ impl<T> WebsocketNetworkConnection<T>
where where
T: io::Read + io::Write + Send + Unpin + 'static, T: io::Read + io::Write + Send + Unpin + 'static,
{ {
pub fn new(ws_stream: WebSocketStream<T>) -> Self { pub fn new(stream: WebSocketStream<T>, tcp_stream: TcpStream) -> Self {
Self { Self {
ws_stream: CloneStream::new(ws_stream), stream: CloneStream::new(stream),
tcp_stream,
} }
} }
pub async fn close(&self) -> Result<(), String> { pub async fn close(&self) -> Result<(), String> {
self.ws_stream.clone().close().await.map_err(map_to_string) // Make an attempt to flush the stream
self.stream
.clone()
.close()
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
// Then forcibly close the socket
self.tcp_stream
.shutdown(Shutdown::Both)
.map_err(map_to_string)
.map_err(logthru_net!())
} }
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> { pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
if message.len() > MAX_MESSAGE_SIZE { if message.len() > MAX_MESSAGE_SIZE {
return Err("received too large WS message".to_owned()); return Err("received too large WS message".to_owned());
} }
self.ws_stream self.stream
.clone() .clone()
.send(Message::binary(message)) .send(Message::binary(message))
.await .await
@ -62,7 +67,7 @@ where
} }
pub async fn recv(&self) -> Result<Vec<u8>, String> { pub async fn recv(&self) -> Result<Vec<u8>, String> {
let out = match self.ws_stream.clone().next().await { let out = match self.stream.clone().next().await {
Some(Ok(Message::Binary(v))) => v, Some(Ok(Message::Binary(v))) => v,
Some(Ok(_)) => { Some(Ok(_)) => {
return Err("Unexpected WS message type".to_owned()).map_err(logthru_net!(error)); return Err("Unexpected WS message type".to_owned()).map_err(logthru_net!(error));
@ -125,6 +130,7 @@ impl WebsocketProtocolHandler {
pub async fn on_accept_async( pub async fn on_accept_async(
self, self,
ps: AsyncPeekStream, ps: AsyncPeekStream,
tcp_stream: TcpStream,
socket_addr: SocketAddr, socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> { ) -> Result<Option<NetworkConnection>, String> {
log_net!("WS: on_accept_async: enter"); log_net!("WS: on_accept_async: enter");
@ -178,7 +184,9 @@ impl WebsocketProtocolHandler {
peer_addr, peer_addr,
SocketAddress::from_socket_addr(self.arc.local_address), SocketAddress::from_socket_addr(self.arc.local_address),
), ),
ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(ws_stream)), ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
); );
log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr); log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr);
@ -214,14 +222,10 @@ impl WebsocketProtocolHandler {
} }
}; };
// Connect to the remote address // Non-blocking connect to remote address
let remote_socket2_addr = socket2::SockAddr::from(remote_socket_addr); let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await
socket
.connect(&remote_socket2_addr)
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error "local_address={:?} remote_socket_addr={}", local_address, remote_socket_addr))?; .map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
let std_stream: std::net::TcpStream = socket.into();
let tcp_stream = TcpStream::from(std_stream);
// See what local address we ended up with // See what local address we ended up with
let actual_local_addr = tcp_stream let actual_local_addr = tcp_stream
@ -238,7 +242,7 @@ impl WebsocketProtocolHandler {
if tls { if tls {
let connector = TlsConnector::default(); let connector = TlsConnector::default();
let tls_stream = connector let tls_stream = connector
.connect(domain.to_string(), tcp_stream) .connect(domain.to_string(), tcp_stream.clone())
.await .await
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error))?; .map_err(logthru_net!(error))?;
@ -249,16 +253,20 @@ impl WebsocketProtocolHandler {
Ok(NetworkConnection::from_protocol( Ok(NetworkConnection::from_protocol(
descriptor, descriptor,
ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new(ws_stream)), ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
)) ))
} else { } else {
let (ws_stream, _response) = client_async(request, tcp_stream) let (ws_stream, _response) = client_async(request, tcp_stream.clone())
.await .await
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error))?; .map_err(logthru_net!(error))?;
Ok(NetworkConnection::from_protocol( Ok(NetworkConnection::from_protocol(
descriptor, descriptor,
ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(ws_stream)), ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
)) ))
} }
} }
@ -285,8 +293,9 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
fn on_accept( fn on_accept(
&self, &self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr, peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> { ) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, peer_addr)) Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
} }
} }

View File

@ -1,6 +1,11 @@
use crate::intf::*; use super::*;
use crate::xx::*; use crate::xx::*;
use crate::*;
cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] {
// No accept support for WASM
} else {
use async_std::net::*;
/////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////
// Accept // Accept
@ -9,6 +14,7 @@ pub trait ProtocolAcceptHandler: ProtocolAcceptHandlerClone + Send + Sync {
fn on_accept( fn on_accept(
&self, &self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr, peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>>; ) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>>;
} }
@ -33,7 +39,8 @@ impl Clone for Box<dyn ProtocolAcceptHandler> {
pub type NewProtocolAcceptHandler = pub type NewProtocolAcceptHandler =
dyn Fn(VeilidConfig, bool, SocketAddr) -> Box<dyn ProtocolAcceptHandler> + Send; dyn Fn(VeilidConfig, bool, SocketAddr) -> Box<dyn ProtocolAcceptHandler> + Send;
}
}
/////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////
// Dummy protocol network connection for testing // Dummy protocol network connection for testing

View File

@ -0,0 +1,2 @@
pub mod test_connection_table;
use super::*;

View File

@ -1,6 +1,6 @@
use super::test_veilid_config::*; use super::connection_table::*;
use crate::connection_table::*; use super::network_connection::*;
use crate::network_connection::*; use crate::tests::common::test_veilid_config::*;
use crate::xx::*; use crate::xx::*;
use crate::*; use crate::*;

View File

@ -1,6 +1,6 @@
use crate::*; use crate::*;
use core::fmt; use core::fmt;
use dht::receipt::*; use dht::*;
use futures_util::stream::{FuturesUnordered, StreamExt}; use futures_util::stream::{FuturesUnordered, StreamExt};
use network_manager::*; use network_manager::*;
use routing_table::*; use routing_table::*;

View File

@ -1040,7 +1040,10 @@ impl RoutingTable {
// Run all bootstrap operations concurrently // Run all bootstrap operations concurrently
let mut unord = FuturesUnordered::new(); let mut unord = FuturesUnordered::new();
for (k, v) in bsmap { for (k, mut v) in bsmap {
// Sort dial info so we get the preferred order correct
v.dial_info_details.sort();
log_rtab!("--- bootstrapping {} with {:?}", k.encode(), &v); log_rtab!("--- bootstrapping {} with {:?}", k.encode(), &v);
// Make invalid signed node info (no signature) // Make invalid signed node info (no signature)

View File

@ -35,3 +35,5 @@ pub use signal_info::*;
pub use signature::*; pub use signature::*;
pub use signed_node_info::*; pub use signed_node_info::*;
pub use socket_address::*; pub use socket_address::*;
use super::*;

View File

@ -1,7 +1,4 @@
use crate::xx::*; use super::*;
use crate::*;
use core::convert::TryInto;
use rpc_processor::*;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -3,7 +3,7 @@ use crate::*;
use core::convert::TryInto; use core::convert::TryInto;
use rpc_processor::*; use rpc_processor::*;
pub fn decode_public_key(public_key: &veilid_capnp::curve25519_public_key::Reader) -> key::DHTKey { pub fn decode_public_key(public_key: &veilid_capnp::curve25519_public_key::Reader) -> DHTKey {
let u0 = public_key.get_u0().to_be_bytes(); let u0 = public_key.get_u0().to_be_bytes();
let u1 = public_key.get_u1().to_be_bytes(); let u1 = public_key.get_u1().to_be_bytes();
let u2 = public_key.get_u2().to_be_bytes(); let u2 = public_key.get_u2().to_be_bytes();
@ -15,11 +15,11 @@ pub fn decode_public_key(public_key: &veilid_capnp::curve25519_public_key::Reade
x[16..24].copy_from_slice(&u2); x[16..24].copy_from_slice(&u2);
x[24..32].copy_from_slice(&u3); x[24..32].copy_from_slice(&u3);
key::DHTKey::new(x) DHTKey::new(x)
} }
pub fn encode_public_key( pub fn encode_public_key(
key: &key::DHTKey, key: &DHTKey,
builder: &mut veilid_capnp::curve25519_public_key::Builder, builder: &mut veilid_capnp::curve25519_public_key::Builder,
) -> Result<(), RPCError> { ) -> Result<(), RPCError> {
if !key.valid { if !key.valid {

View File

@ -8,14 +8,12 @@ pub use private_route::*;
use crate::dht::*; use crate::dht::*;
use crate::intf::*; use crate::intf::*;
use crate::xx::*; use crate::xx::*;
use crate::*;
use capnp::message::ReaderSegments; use capnp::message::ReaderSegments;
use coders::*; use coders::*;
use core::convert::{TryFrom, TryInto};
use core::fmt;
use network_manager::*; use network_manager::*;
use receipt_manager::*; use receipt_manager::*;
use routing_table::*; use routing_table::*;
use super::*;
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
@ -79,7 +77,7 @@ impl RespondTo {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct RPCMessageHeader { struct RPCMessageHeader {
timestamp: u64, // time the message was received, not sent timestamp: u64, // time the message was received, not sent
envelope: envelope::Envelope, envelope: Envelope,
body_len: u64, body_len: u64,
peer_noderef: NodeRef, // ensures node doesn't get evicted from routing table until we're done with it peer_noderef: NodeRef, // ensures node doesn't get evicted from routing table until we're done with it
} }
@ -163,8 +161,8 @@ pub struct FindNodeAnswer {
pub struct RPCProcessorInner { pub struct RPCProcessorInner {
network_manager: NetworkManager, network_manager: NetworkManager,
routing_table: RoutingTable, routing_table: RoutingTable,
node_id: key::DHTKey, node_id: DHTKey,
node_id_secret: key::DHTKeySecret, node_id_secret: DHTKeySecret,
send_channel: Option<flume::Sender<RPCMessage>>, send_channel: Option<flume::Sender<RPCMessage>>,
timeout: u64, timeout: u64,
max_route_hop_count: usize, max_route_hop_count: usize,
@ -185,8 +183,8 @@ impl RPCProcessor {
RPCProcessorInner { RPCProcessorInner {
network_manager: network_manager.clone(), network_manager: network_manager.clone(),
routing_table: network_manager.routing_table(), routing_table: network_manager.routing_table(),
node_id: key::DHTKey::default(), node_id: DHTKey::default(),
node_id_secret: key::DHTKeySecret::default(), node_id_secret: DHTKeySecret::default(),
send_channel: None, send_channel: None,
timeout: 10000000, timeout: 10000000,
max_route_hop_count: 7, max_route_hop_count: 7,
@ -215,11 +213,11 @@ impl RPCProcessor {
self.inner.lock().routing_table.clone() self.inner.lock().routing_table.clone()
} }
pub fn node_id(&self) -> key::DHTKey { pub fn node_id(&self) -> DHTKey {
self.inner.lock().node_id self.inner.lock().node_id
} }
pub fn node_id_secret(&self) -> key::DHTKeySecret { pub fn node_id_secret(&self) -> DHTKeySecret {
self.inner.lock().node_id_secret self.inner.lock().node_id_secret
} }
@ -258,7 +256,7 @@ impl RPCProcessor {
// Search the DHT for a single node closest to a key and add it to the routing table and return the node reference // Search the DHT for a single node closest to a key and add it to the routing table and return the node reference
pub async fn search_dht_single_key( pub async fn search_dht_single_key(
&self, &self,
node_id: key::DHTKey, node_id: DHTKey,
_count: u32, _count: u32,
_fanout: u32, _fanout: u32,
_timeout: Option<u64>, _timeout: Option<u64>,
@ -273,7 +271,7 @@ impl RPCProcessor {
// Search the DHT for the 'count' closest nodes to a key, adding them all to the routing table if they are not there and returning their node references // Search the DHT for the 'count' closest nodes to a key, adding them all to the routing table if they are not there and returning their node references
pub async fn search_dht_multi_key( pub async fn search_dht_multi_key(
&self, &self,
_node_id: key::DHTKey, _node_id: DHTKey,
_count: u32, _count: u32,
_fanout: u32, _fanout: u32,
_timeout: Option<u64>, _timeout: Option<u64>,
@ -286,7 +284,7 @@ impl RPCProcessor {
// Note: This routine can possible be recursive, hence the SystemPinBoxFuture async form // Note: This routine can possible be recursive, hence the SystemPinBoxFuture async form
pub fn resolve_node( pub fn resolve_node(
&self, &self,
node_id: key::DHTKey, node_id: DHTKey,
) -> SystemPinBoxFuture<Result<NodeRef, RPCError>> { ) -> SystemPinBoxFuture<Result<NodeRef, RPCError>> {
let this = self.clone(); let this = self.clone();
Box::pin(async move { Box::pin(async move {
@ -557,7 +555,6 @@ impl RPCProcessor {
.network_manager() .network_manager()
.send_envelope(node_ref.clone(), Some(out_node_id), out) .send_envelope(node_ref.clone(), Some(out_node_id), out)
.await .await
.map_err(logthru_rpc!(error))
.map_err(RPCError::Internal) .map_err(RPCError::Internal)
{ {
Ok(v) => v, Ok(v) => v,
@ -1406,7 +1403,7 @@ impl RPCProcessor {
pub fn enqueue_message( pub fn enqueue_message(
&self, &self,
envelope: envelope::Envelope, envelope: Envelope,
body: Vec<u8>, body: Vec<u8>,
peer_noderef: NodeRef, peer_noderef: NodeRef,
) -> Result<(), String> { ) -> Result<(), String> {
@ -1626,7 +1623,7 @@ impl RPCProcessor {
pub async fn rpc_call_find_node( pub async fn rpc_call_find_node(
self, self,
dest: Destination, dest: Destination,
key: key::DHTKey, key: DHTKey,
safety_route: Option<&SafetyRouteSpec>, safety_route: Option<&SafetyRouteSpec>,
respond_to: RespondTo, respond_to: RespondTo,
) -> Result<FindNodeAnswer, RPCError> { ) -> Result<FindNodeAnswer, RPCError> {

View File

@ -4,7 +4,7 @@ impl RPCProcessor {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
pub(super) fn new_stub_private_route<'a, T>( pub(super) fn new_stub_private_route<'a, T>(
&self, &self,
dest_node_id: key::DHTKey, dest_node_id: DHTKey,
builder: &'a mut ::capnp::message::Builder<T>, builder: &'a mut ::capnp::message::Builder<T>,
) -> Result<veilid_capnp::private_route::Reader<'a>, RPCError> ) -> Result<veilid_capnp::private_route::Reader<'a>, RPCError>
where where

View File

@ -1,7 +1,3 @@
pub mod test_connection_table;
pub mod test_crypto;
pub mod test_dht_key;
pub mod test_envelope_receipt;
pub mod test_host_interface; pub mod test_host_interface;
pub mod test_protected_store; pub mod test_protected_store;
pub mod test_table_store; pub mod test_table_store;

View File

@ -1,5 +1,4 @@
use super::test_veilid_config::*; use super::test_veilid_config::*;
use crate::dht::key;
use crate::intf::*; use crate::intf::*;
use crate::xx::*; use crate::xx::*;
use crate::*; use crate::*;
@ -131,13 +130,13 @@ pub async fn test_cbor(ts: TableStore) {
let _ = ts.delete("test"); let _ = ts.delete("test");
let db = ts.open("test", 3).await.expect("should have opened"); let db = ts.open("test", 3).await.expect("should have opened");
let (dht_key, _) = key::generate_secret(); let (dht_key, _) = generate_secret();
assert!(db.store_cbor(0, b"asdf", &dht_key).await.is_ok()); assert!(db.store_cbor(0, b"asdf", &dht_key).await.is_ok());
assert_eq!(db.load_cbor::<key::DHTKey>(0, b"qwer").await, Ok(None)); assert_eq!(db.load_cbor::<DHTKey>(0, b"qwer").await, Ok(None));
let d = match db.load_cbor::<key::DHTKey>(0, b"asdf").await { let d = match db.load_cbor::<DHTKey>(0, b"asdf").await {
Ok(x) => x, Ok(x) => x,
Err(e) => { Err(e) => {
panic!("couldn't decode cbor: {}", e); panic!("couldn't decode cbor: {}", e);
@ -151,7 +150,7 @@ pub async fn test_cbor(ts: TableStore) {
); );
assert!( assert!(
db.load_cbor::<key::DHTKey>(1, b"foo").await.is_err(), db.load_cbor::<DHTKey>(1, b"foo").await.is_err(),
"should fail to load cbor" "should fail to load cbor"
); );
} }

View File

@ -194,8 +194,8 @@ fn config_callback(key: String) -> ConfigCallbackReturn {
"network.client_whitelist_timeout_ms" => Ok(Box::new(300_000u32)), "network.client_whitelist_timeout_ms" => Ok(Box::new(300_000u32)),
"network.reverse_connection_receipt_time_ms" => Ok(Box::new(5_000u32)), "network.reverse_connection_receipt_time_ms" => Ok(Box::new(5_000u32)),
"network.hole_punch_receipt_time_ms" => Ok(Box::new(5_000u32)), "network.hole_punch_receipt_time_ms" => Ok(Box::new(5_000u32)),
"network.node_id" => Ok(Box::new(dht::key::DHTKey::default())), "network.node_id" => Ok(Box::new(DHTKey::default())),
"network.node_id_secret" => Ok(Box::new(dht::key::DHTKeySecret::default())), "network.node_id_secret" => Ok(Box::new(DHTKeySecret::default())),
"network.bootstrap" => Ok(Box::new(Vec::<String>::new())), "network.bootstrap" => Ok(Box::new(Vec::<String>::new())),
"network.bootstrap_nodes" => Ok(Box::new(Vec::<String>::new())), "network.bootstrap_nodes" => Ok(Box::new(Vec::<String>::new())),
"network.routing_table.limit_over_attached" => Ok(Box::new(64u32)), "network.routing_table.limit_over_attached" => Ok(Box::new(64u32)),

View File

@ -3,6 +3,8 @@
mod test_async_peek_stream; mod test_async_peek_stream;
use crate::dht::tests::*;
use crate::network_manager::tests::*;
use crate::tests::common::*; use crate::tests::common::*;
use crate::xx::*; use crate::xx::*;

View File

@ -1,13 +1,12 @@
use crate::xx::*; use super::*;
use core::time::Duration; use async_std::net::{TcpListener, TcpStream};
use async_std::prelude::FutureExt;
use async_std::task;
use futures_util::{AsyncReadExt, AsyncWriteExt};
use std::io;
static MESSAGE: &[u8; 62] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; static MESSAGE: &[u8; 62] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
use async_std::io;
use async_std::net::{TcpListener, TcpStream};
use async_std::prelude::*;
use async_std::task;
async fn make_tcp_loopback() -> Result<(TcpStream, TcpStream), io::Error> { async fn make_tcp_loopback() -> Result<(TcpStream, TcpStream), io::Error> {
let listener = TcpListener::bind("127.0.0.1:0").await?; let listener = TcpListener::bind("127.0.0.1:0").await?;
let local_addr = listener.local_addr()?; let local_addr = listener.local_addr()?;

View File

@ -12,9 +12,8 @@ pub use crate::xx::{
pub use alloc::string::ToString; pub use alloc::string::ToString;
pub use attachment_manager::AttachmentManager; pub use attachment_manager::AttachmentManager;
pub use core::str::FromStr; pub use core::str::FromStr;
pub use dht::crypto::Crypto; pub use dht::Crypto;
pub use dht::key::{generate_secret, sign, verify, DHTKey, DHTKeySecret, DHTSignature}; pub use dht::{generate_secret, sign, verify, DHTKey, DHTKeySecret, DHTSignature};
pub use dht::receipt::ReceiptNonce;
pub use intf::BlockStore; pub use intf::BlockStore;
pub use intf::ProtectedStore; pub use intf::ProtectedStore;
pub use intf::TableStore; pub use intf::TableStore;

View File

@ -1,9 +1,10 @@
use crate::dht::key; use crate::dht::*;
use crate::intf; use crate::intf;
use crate::xx::*; use crate::xx::*;
use serde::*; use serde::*;
////////////////////////////////////////////////////////////////////////////////////////////////
cfg_if! { cfg_if! {
if #[cfg(target_arch = "wasm32")] { if #[cfg(target_arch = "wasm32")] {
pub type ConfigCallbackReturn = Result<Box<dyn core::any::Any>, String>; pub type ConfigCallbackReturn = Result<Box<dyn core::any::Any>, String>;
@ -136,8 +137,8 @@ pub struct VeilidConfigNetwork {
pub client_whitelist_timeout_ms: u32, pub client_whitelist_timeout_ms: u32,
pub reverse_connection_receipt_time_ms: u32, pub reverse_connection_receipt_time_ms: u32,
pub hole_punch_receipt_time_ms: u32, pub hole_punch_receipt_time_ms: u32,
pub node_id: key::DHTKey, pub node_id: DHTKey,
pub node_id_secret: key::DHTKeySecret, pub node_id_secret: DHTKeySecret,
pub bootstrap: Vec<String>, pub bootstrap: Vec<String>,
pub bootstrap_nodes: Vec<String>, pub bootstrap_nodes: Vec<String>,
pub routing_table: VeilidConfigRoutingTable, pub routing_table: VeilidConfigRoutingTable,
@ -528,7 +529,7 @@ impl VeilidConfig {
debug!("pulling node id from storage"); debug!("pulling node id from storage");
if let Some(s) = protected_store.load_user_secret_string("node_id").await? { if let Some(s) = protected_store.load_user_secret_string("node_id").await? {
debug!("node id found in storage"); debug!("node id found in storage");
node_id = key::DHTKey::try_decode(s.as_str())? node_id = DHTKey::try_decode(s.as_str())?
} else { } else {
debug!("node id not found in storage"); debug!("node id not found in storage");
} }
@ -542,7 +543,7 @@ impl VeilidConfig {
.await? .await?
{ {
debug!("node id secret found in storage"); debug!("node id secret found in storage");
node_id_secret = key::DHTKeySecret::try_decode(s.as_str())? node_id_secret = DHTKeySecret::try_decode(s.as_str())?
} else { } else {
debug!("node id secret not found in storage"); debug!("node id secret not found in storage");
} }
@ -551,7 +552,7 @@ impl VeilidConfig {
// If we have a node id from storage, check it // If we have a node id from storage, check it
if node_id.valid && node_id_secret.valid { if node_id.valid && node_id_secret.valid {
// Validate node id // Validate node id
if !key::validate_key(&node_id, &node_id_secret) { if !validate_key(&node_id, &node_id_secret) {
return Err("node id secret and node id key don't match".to_owned()); return Err("node id secret and node id key don't match".to_owned());
} }
} }
@ -559,7 +560,7 @@ impl VeilidConfig {
// If we still don't have a valid node id, generate one // If we still don't have a valid node id, generate one
if !node_id.valid || !node_id_secret.valid { if !node_id.valid || !node_id_secret.valid {
debug!("generating new node id"); debug!("generating new node id");
let (i, s) = key::generate_secret(); let (i, s) = generate_secret();
node_id = i; node_id = i;
node_id_secret = s; node_id_secret = s;
} }

View File

@ -1,19 +1,16 @@
use crate::xx::*; use super::*;
use core::pin::Pin; use futures_util::AsyncReadExt;
use core::task::{Context, Poll}; use std::io;
use futures_util::io::AsyncRead as Read; use task::{Context, Poll};
use futures_util::io::AsyncReadExt;
use futures_util::io::AsyncWrite as Write;
use std::io::Result;
//////// ////////
/// ///
trait SendStream: Read + Write + Send + Unpin { trait SendStream: AsyncRead + AsyncWrite + Send + Unpin {
fn clone_stream(&self) -> Box<dyn SendStream>; fn clone_stream(&self) -> Box<dyn SendStream>;
} }
impl<S> SendStream for S impl<S> SendStream for S
where where
S: Read + Write + Send + Clone + Unpin + 'static, S: AsyncRead + AsyncWrite + Send + Clone + Unpin + 'static,
{ {
fn clone_stream(&self) -> Box<dyn SendStream> { fn clone_stream(&self) -> Box<dyn SendStream> {
Box::new(self.clone()) Box::new(self.clone())
@ -121,7 +118,7 @@ struct AsyncPeekStreamInner {
#[derive(Clone)] #[derive(Clone)]
pub struct AsyncPeekStream pub struct AsyncPeekStream
where where
Self: Read + Write + Send + Unpin, Self: AsyncRead + AsyncWrite + Send + Unpin,
{ {
inner: Arc<Mutex<AsyncPeekStreamInner>>, inner: Arc<Mutex<AsyncPeekStreamInner>>,
} }
@ -129,7 +126,7 @@ where
impl AsyncPeekStream { impl AsyncPeekStream {
pub fn new<S>(stream: S) -> Self pub fn new<S>(stream: S) -> Self
where where
S: Read + Write + Send + Clone + Unpin + 'static, S: AsyncRead + AsyncWrite + Send + Clone + Unpin + 'static,
{ {
Self { Self {
inner: Arc::new(Mutex::new(AsyncPeekStreamInner { inner: Arc::new(Mutex::new(AsyncPeekStreamInner {
@ -155,16 +152,16 @@ impl AsyncPeekStream {
} }
} }
impl Read for AsyncPeekStream { impl AsyncRead for AsyncPeekStream {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut [u8], buf: &mut [u8],
) -> Poll<Result<usize>> { ) -> Poll<io::Result<usize>> {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
// //
let buflen = buf.len(); let buflen = buf.len();
let bufcopylen = cmp::min(buflen, inner.peekbuf_len); let bufcopylen = core::cmp::min(buflen, inner.peekbuf_len);
let bufreadlen = if buflen > inner.peekbuf_len { let bufreadlen = if buflen > inner.peekbuf_len {
buflen - inner.peekbuf_len buflen - inner.peekbuf_len
} else { } else {
@ -204,16 +201,20 @@ impl Read for AsyncPeekStream {
} }
} }
impl Write for AsyncPeekStream { impl AsyncWrite for AsyncPeekStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> { fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
Pin::new(&mut inner.stream).poll_write(cx, buf) Pin::new(&mut inner.stream).poll_write(cx, buf)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
Pin::new(&mut inner.stream).poll_flush(cx) Pin::new(&mut inner.stream).poll_flush(cx)
} }
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
Pin::new(&mut inner.stream).poll_close(cx) Pin::new(&mut inner.stream).poll_close(cx)
} }

View File

@ -15,8 +15,13 @@ mod tick_task;
mod tools; mod tools;
pub use cfg_if::*; pub use cfg_if::*;
pub use futures_util::future::{select, Either};
pub use futures_util::select;
pub use futures_util::stream::FuturesUnordered;
pub use futures_util::{AsyncRead, AsyncWrite};
pub use log::*; pub use log::*;
pub use log_thru::*; pub use log_thru::*;
pub use owo_colors::OwoColorize;
pub use parking_lot::*; pub use parking_lot::*;
pub use split_url::*; pub use split_url::*;
pub use static_assertions::*; pub use static_assertions::*;
@ -41,11 +46,14 @@ cfg_if! {
pub use alloc::borrow::{Cow, ToOwned}; pub use alloc::borrow::{Cow, ToOwned};
pub use wasm_bindgen::prelude::*; pub use wasm_bindgen::prelude::*;
pub use core::cmp; pub use core::cmp;
pub use core::convert::{TryFrom, TryInto};
pub use core::mem; pub use core::mem;
pub use core::fmt;
pub use alloc::rc::Rc; pub use alloc::rc::Rc;
pub use core::cell::RefCell; pub use core::cell::RefCell;
pub use core::task; pub use core::task;
pub use core::future::Future; pub use core::future::Future;
pub use core::time::Duration;
pub use core::pin::Pin; pub use core::pin::Pin;
pub use core::sync::atomic::{Ordering, AtomicBool}; pub use core::sync::atomic::{Ordering, AtomicBool};
pub use alloc::sync::{Arc, Weak}; pub use alloc::sync::{Arc, Weak};
@ -67,15 +75,18 @@ cfg_if! {
pub use std::boxed::Box; pub use std::boxed::Box;
pub use std::borrow::{Cow, ToOwned}; pub use std::borrow::{Cow, ToOwned};
pub use std::cmp; pub use std::cmp;
pub use std::convert::{TryFrom, TryInto};
pub use std::mem; pub use std::mem;
pub use std::fmt;
pub use std::sync::atomic::{Ordering, AtomicBool}; pub use std::sync::atomic::{Ordering, AtomicBool};
pub use std::sync::{Arc, Weak}; pub use std::sync::{Arc, Weak};
pub use std::rc::Rc; pub use std::rc::Rc;
pub use std::cell::RefCell; pub use std::cell::RefCell;
pub use std::task; pub use std::task;
pub use std::future::Future;
pub use std::time::Duration;
pub use std::pin::Pin;
pub use std::ops::{FnOnce, FnMut, Fn}; pub use std::ops::{FnOnce, FnMut, Fn};
pub use async_std::future::Future;
pub use async_std::pin::Pin;
pub use async_std::sync::Mutex as AsyncMutex; pub use async_std::sync::Mutex as AsyncMutex;
pub use async_std::sync::MutexGuard as AsyncMutexGuard; pub use async_std::sync::MutexGuard as AsyncMutexGuard;
pub use std::net::{ SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs, IpAddr, Ipv4Addr, Ipv6Addr }; pub use std::net::{ SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs, IpAddr, Ipv4Addr, Ipv6Addr };