diff --git a/Cargo.lock b/Cargo.lock index 889d2be7..0e23de54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6299,7 +6299,6 @@ name = "veilid-core" version = "0.4.1" dependencies = [ "argon2", - "async-io 1.13.0", "async-std", "async-std-resolver", "async-tls", @@ -6366,7 +6365,6 @@ dependencies = [ "sha2 0.10.8", "shell-words", "simplelog", - "socket2 0.5.7", "static_assertions", "stop-token", "sysinfo", @@ -6518,6 +6516,7 @@ name = "veilid-tools" version = "0.4.1" dependencies = [ "android_logger 0.13.3", + "async-io 1.13.0", "async-lock 3.4.0", "async-std", "async_executors", @@ -6556,6 +6555,7 @@ dependencies = [ "serde", "serial_test 2.0.0", "simplelog", + "socket2 0.5.7", "static_assertions", "stop-token", "thiserror", diff --git a/veilid-cli/src/client_api_connection.rs b/veilid-cli/src/client_api_connection.rs index 18a3f277..9124764a 100644 --- a/veilid-cli/src/client_api_connection.rs +++ b/veilid-cli/src/client_api_connection.rs @@ -223,13 +223,12 @@ impl ClientApiConnection { trace!("ClientApiConnection::handle_tcp_connection"); // Connect the TCP socket - let stream = TcpStream::connect(connect_addr) + let stream = connect_async_tcp_stream(None, connect_addr, 10_000) .await + .map_err(map_to_string)? + .into_timeout_error() .map_err(map_to_string)?; - // If it succeed, disable nagle algorithm - stream.set_nodelay(true).map_err(map_to_string)?; - // State we connected let comproc = self.inner.lock().comproc.clone(); comproc.set_connection_state(ConnectionState::ConnectedTCP( @@ -239,16 +238,8 @@ impl ClientApiConnection { // Split into reader and writer halves // with line buffering on the reader - cfg_if! { - if #[cfg(feature="rt-async-std")] { - use futures::AsyncReadExt; - let (reader, writer) = stream.split(); - let reader = BufReader::new(reader); - } else { - let (reader, writer) = stream.into_split(); - let reader = BufReader::new(reader); - } - } + let (reader, writer) = split_async_tcp_stream(stream); + let reader = BufReader::new(reader); self.clone().run_json_api_processor(reader, writer).await } diff --git a/veilid-core/Cargo.toml b/veilid-core/Cargo.toml index 4106d079..397df411 100644 --- a/veilid-core/Cargo.toml +++ b/veilid-core/Cargo.toml @@ -56,6 +56,7 @@ veilid_core_ios_tests = ["dep:tracing-oslog"] debug-locks = ["veilid-tools/debug-locks"] unstable-blockstore = [] unstable-tunnels = [] +virtual-network = [] # GeoIP geolocation = ["maxminddb", "reqwest"] @@ -164,7 +165,6 @@ sysinfo = { version = "^0.30.13", default-features = false } tokio = { version = "1.38.1", features = ["full"], optional = true } tokio-util = { version = "0.7.11", features = ["compat"], optional = true } tokio-stream = { version = "0.1.15", features = ["net"], optional = true } -async-io = { version = "1.13.0" } futures-util = { version = "0.3.30", default-features = false, features = [ "async-await", "sink", @@ -184,7 +184,6 @@ webpki = "0.22.4" webpki-roots = "0.25.4" rustls = "0.21.12" rustls-pemfile = "1.0.4" -socket2 = { version = "0.5.7", features = ["all"] } # Dependencies for WASM builds only [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/veilid-core/src/network_manager/native/network_tcp.rs b/veilid-core/src/network_manager/native/network_tcp.rs index e3820146..390a0f58 100644 --- a/veilid-core/src/network_manager/native/network_tcp.rs +++ b/veilid-core/src/network_manager/native/network_tcp.rs @@ -1,6 +1,5 @@ use super::*; use async_tls::TlsAcceptor; -use sockets::*; use stop_token::future::FutureExt; ///////////////////////////////////////////////////////////////// @@ -135,39 +134,12 @@ impl Network { } }; - #[cfg(all(feature = "rt-async-std", unix))] + if let Err(e) = set_tcp_stream_linger(&tcp_stream, Some(core::time::Duration::from_secs(0))) { - // async-std does not directly support linger on TcpStream yet - use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd}; - if let Err(e) = unsafe { - let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd()); - let res = s.set_linger(Some(core::time::Duration::from_secs(0))); - s.into_raw_fd(); - res - } { - log_net!(debug "Couldn't set TCP linger: {}", e); - return; - } - } - #[cfg(all(feature = "rt-async-std", windows))] - { - // async-std does not directly support linger on TcpStream yet - use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket}; - if let Err(e) = unsafe { - let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket()); - let res = s.set_linger(Some(core::time::Duration::from_secs(0))); - s.into_raw_socket(); - res - } { - log_net!(debug "Couldn't set TCP linger: {}", e); - return; - } - } - #[cfg(not(feature = "rt-async-std"))] - if let Err(e) = tcp_stream.set_linger(Some(core::time::Duration::from_secs(0))) { log_net!(debug "Couldn't set TCP linger: {}", e); return; } + if let Err(e) = tcp_stream.set_nodelay(true) { log_net!(debug "Couldn't set TCP nodelay: {}", e); return; @@ -257,41 +229,11 @@ impl Network { ) }; - // Create a socket and bind it - let Some(socket) = new_bound_default_tcp_socket(addr) - .wrap_err("failed to create default socket listener")? - else { - return Ok(false); - }; - - // Drop the socket - drop(socket); - // Create a shared socket and bind it once we have determined the port is free - let Some(socket) = new_bound_shared_tcp_socket(addr) - .wrap_err("failed to create shared socket listener")? - else { + let Some(listener) = bind_async_tcp_listener(addr)? else { return Ok(false); }; - // Listen on the socket - if socket.listen(128).is_err() { - return Ok(false); - } - - // Make an async tcplistener from the socket2 socket - let std_listener: std::net::TcpListener = socket.into(); - cfg_if! { - if #[cfg(feature="rt-async-std")] { - let listener = TcpListener::from(std_listener); - } else if #[cfg(feature="rt-tokio")] { - std_listener.set_nonblocking(true).expect("failed to set nonblocking"); - let listener = TcpListener::from_std(std_listener).wrap_err("failed to create tokio tcp listener")?; - } else { - compile_error!("needs executor implementation"); - } - } - log_net!(debug "spawn_socket_listener: binding successful to {}", addr); // Create protocol handler records @@ -311,15 +253,7 @@ impl Network { // moves listener object in and get incoming iterator // when this task exists, the listener will close the socket - cfg_if! { - if #[cfg(feature="rt-async-std")] { - let incoming_stream = listener.incoming(); - } else if #[cfg(feature="rt-tokio")] { - let incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener); - } else { - compile_error!("needs executor implementation"); - } - } + let incoming_stream = async_tcp_listener_incoming(listener); let _ = incoming_stream .for_each_concurrent(None, |tcp_stream| { diff --git a/veilid-core/src/network_manager/native/network_udp.rs b/veilid-core/src/network_manager/native/network_udp.rs index 575f04e1..dbfaf08a 100644 --- a/veilid-core/src/network_manager/native/network_udp.rs +++ b/veilid-core/src/network_manager/native/network_udp.rs @@ -1,5 +1,4 @@ use super::*; -use sockets::*; use stop_token::future::FutureExt; impl Network { @@ -114,23 +113,10 @@ impl Network { async fn create_udp_protocol_handler(&self, addr: SocketAddr) -> EyreResult { log_net!(debug "create_udp_protocol_handler on {:?}", &addr); - // Create a reusable socket - let Some(socket) = new_bound_default_udp_socket(addr)? else { + // Create a single-address-family UDP socket with default options bound to an address + let Some(udp_socket) = bind_async_udp_socket(addr)? else { return Ok(false); }; - - // Make an async UdpSocket from the socket2 socket - let std_udp_socket: std::net::UdpSocket = socket.into(); - cfg_if! { - if #[cfg(feature="rt-async-std")] { - let udp_socket = UdpSocket::from(std_udp_socket); - } else if #[cfg(feature="rt-tokio")] { - std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking"); - let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make inbound tokio udpsocket")?; - } else { - compile_error!("needs executor implementation"); - } - } let socket_arc = Arc::new(udp_socket); // Create protocol handler diff --git a/veilid-core/src/network_manager/native/protocol/mod.rs b/veilid-core/src/network_manager/native/protocol/mod.rs index eaef2aa1..63107655 100644 --- a/veilid-core/src/network_manager/native/protocol/mod.rs +++ b/veilid-core/src/network_manager/native/protocol/mod.rs @@ -1,4 +1,3 @@ -pub mod sockets; pub mod tcp; pub mod udp; pub mod wrtc; diff --git a/veilid-core/src/network_manager/native/protocol/sockets.rs b/veilid-core/src/network_manager/native/protocol/sockets.rs deleted file mode 100644 index 755ec6d1..00000000 --- a/veilid-core/src/network_manager/native/protocol/sockets.rs +++ /dev/null @@ -1,191 +0,0 @@ -use crate::*; -use async_io::Async; -use std::io; - -cfg_if! { - if #[cfg(feature="rt-async-std")] { - pub use async_std::net::{TcpStream, TcpListener, UdpSocket}; - } else if #[cfg(feature="rt-tokio")] { - pub use tokio::net::{TcpStream, TcpListener, UdpSocket}; - pub use tokio_util::compat::*; - } else { - compile_error!("needs executor implementation"); - } -} - -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; - -// cfg_if! { -// if #[cfg(windows)] { -// use winapi::shared::ws2def::{ SOL_SOCKET, SO_EXCLUSIVEADDRUSE}; -// use winapi::um::winsock2::{SOCKET_ERROR, setsockopt}; -// use winapi::ctypes::c_int; -// use std::os::windows::io::AsRawSocket; - -// fn set_exclusiveaddruse(socket: &Socket) -> io::Result<()> { -// unsafe { -// let optval:c_int = 1; -// if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(), -// std::mem::size_of::() as c_int) == SOCKET_ERROR { -// return Err(io::Error::last_os_error()); -// } -// Ok(()) -// } -// } -// } -// } - -#[instrument(level = "trace", ret)] -pub fn new_shared_udp_socket(domain: Domain) -> io::Result { - let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; - if domain == Domain::IPV6 { - socket.set_only_v6(true)?; - } - socket.set_reuse_address(true)?; - - cfg_if! { - if #[cfg(unix)] { - socket.set_reuse_port(true)?; - } - } - Ok(socket) -} - -#[instrument(level = "trace", ret)] -pub fn new_default_udp_socket(domain: Domain) -> io::Result { - let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; - if domain == Domain::IPV6 { - socket.set_only_v6(true)?; - } - - Ok(socket) -} - -#[instrument(level = "trace", ret)] -pub fn new_bound_default_udp_socket(local_address: SocketAddr) -> io::Result> { - let domain = Domain::for_address(local_address); - let socket = new_default_udp_socket(domain)?; - let socket2_addr = SockAddr::from(local_address); - - if socket.bind(&socket2_addr).is_err() { - return Ok(None); - } - - log_net!("created bound default udp socket on {:?}", &local_address); - - Ok(Some(socket)) -} - -#[instrument(level = "trace", ret)] -pub fn new_default_tcp_socket(domain: Domain) -> io::Result { - let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; - if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) { - log_net!(error "Couldn't set TCP linger: {}", e); - } - if let Err(e) = socket.set_nodelay(true) { - log_net!(error "Couldn't set TCP nodelay: {}", e); - } - if domain == Domain::IPV6 { - socket.set_only_v6(true)?; - } - Ok(socket) -} - -#[instrument(level = "trace", ret)] -pub fn new_shared_tcp_socket(domain: Domain) -> io::Result { - let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; - if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) { - log_net!(error "Couldn't set TCP linger: {}", e); - } - if let Err(e) = socket.set_nodelay(true) { - log_net!(error "Couldn't set TCP nodelay: {}", e); - } - if domain == Domain::IPV6 { - socket.set_only_v6(true)?; - } - socket.set_reuse_address(true)?; - cfg_if! { - if #[cfg(unix)] { - socket.set_reuse_port(true)?; - } - } - - Ok(socket) -} -#[instrument(level = "trace", ret)] -pub fn new_bound_default_tcp_socket(local_address: SocketAddr) -> io::Result> { - let domain = Domain::for_address(local_address); - let socket = new_default_tcp_socket(domain)?; - let socket2_addr = SockAddr::from(local_address); - if socket.bind(&socket2_addr).is_err() { - return Ok(None); - } - - log_net!("created bound default tcp socket on {:?}", &local_address); - - Ok(Some(socket)) -} - -#[instrument(level = "trace", ret)] -pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> io::Result> { - let domain = Domain::for_address(local_address); - let socket = new_shared_tcp_socket(domain)?; - let socket2_addr = SockAddr::from(local_address); - if socket.bind(&socket2_addr).is_err() { - return Ok(None); - } - - log_net!("created bound shared tcp socket on {:?}", &local_address); - - Ok(Some(socket)) -} - -// Non-blocking connect is tricky when you want to start with a prepared socket -// Errors should not be logged as they are valid conditions for this function -#[instrument(level = "trace", ret)] -pub async fn nonblocking_connect( - socket: Socket, - addr: SocketAddr, - timeout_ms: u32, -) -> io::Result> { - // Set for non blocking connect - socket.set_nonblocking(true)?; - - // Make socket2 SockAddr - let socket2_addr = socket2::SockAddr::from(addr); - - // Connect to the remote address - match socket.connect(&socket2_addr) { - Ok(()) => Ok(()), - #[cfg(unix)] - Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()), - Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()), - Err(e) => Err(e), - }?; - let async_stream = Async::new(std::net::TcpStream::from(socket))?; - - // The stream becomes writable when connected - timeout_or_try!( - timeout(timeout_ms, async_stream.writable().in_current_span()) - .await - .into_timeout_or() - .into_result()? - ); - - // Check low level error - let async_stream = match async_stream.get_ref().take_error()? { - None => Ok(async_stream), - Some(err) => Err(err), - }?; - - // Convert back to inner and then return async version - cfg_if! { - if #[cfg(feature="rt-async-std")] { - Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?))) - } else if #[cfg(feature="rt-tokio")] { - Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?)) - } else { - compile_error!("needs executor implementation"); - } - } -} diff --git a/veilid-core/src/network_manager/native/protocol/tcp.rs b/veilid-core/src/network_manager/native/protocol/tcp.rs index 4b741d85..4a032d49 100644 --- a/veilid-core/src/network_manager/native/protocol/tcp.rs +++ b/veilid-core/src/network_manager/native/protocol/tcp.rs @@ -1,6 +1,5 @@ use super::*; use futures_util::{AsyncReadExt, AsyncWriteExt}; -use sockets::*; pub struct RawTcpNetworkConnection { flow: Flow, @@ -157,32 +156,28 @@ impl RawTcpProtocolHandler { #[instrument(level = "trace", target = "protocol", err)] pub async fn connect( local_address: Option, - socket_addr: SocketAddr, + remote_address: SocketAddr, timeout_ms: u32, ) -> io::Result> { - // Make a shared socket - let socket = match local_address { - Some(a) => { - new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))? - } - None => new_default_tcp_socket(socket2::Domain::for_address(socket_addr))?, - }; - // Non-blocking connect to remote address - let ts = network_result_try!(nonblocking_connect(socket, socket_addr, timeout_ms) - .await - .folded()?); + let tcp_stream = network_result_try!(connect_async_tcp_stream( + local_address, + remote_address, + timeout_ms + ) + .await + .folded()?); // See what local address we ended up with and turn this into a stream - let actual_local_address = ts.local_addr()?; + let actual_local_address = tcp_stream.local_addr()?; #[cfg(feature = "rt-tokio")] - let ts = ts.compat(); - let ps = AsyncPeekStream::new(ts); + let tcp_stream = tcp_stream.compat(); + let ps = AsyncPeekStream::new(tcp_stream); // Wrap the stream in a network connection and return it let flow = Flow::new( PeerAddress::new( - SocketAddress::from_socket_addr(socket_addr), + SocketAddress::from_socket_addr(remote_address), ProtocolType::TCP, ), SocketAddress::from_socket_addr(actual_local_address), diff --git a/veilid-core/src/network_manager/native/protocol/udp.rs b/veilid-core/src/network_manager/native/protocol/udp.rs index 37d888ba..fb69bc93 100644 --- a/veilid-core/src/network_manager/native/protocol/udp.rs +++ b/veilid-core/src/network_manager/native/protocol/udp.rs @@ -1,5 +1,4 @@ use super::*; -use sockets::*; #[derive(Clone)] pub struct RawUdpProtocolHandler { @@ -141,7 +140,8 @@ impl RawUdpProtocolHandler { ) -> io::Result { // get local wildcard address for bind let local_socket_addr = compatible_unspecified_socket_addr(socket_addr); - let socket = UdpSocket::bind(local_socket_addr).await?; + let socket = bind_async_udp_socket(local_socket_addr)? + .ok_or(io::Error::from(io::ErrorKind::AddrInUse))?; Ok(RawUdpProtocolHandler::new(Arc::new(socket), None)) } } diff --git a/veilid-core/src/network_manager/native/protocol/ws.rs b/veilid-core/src/network_manager/native/protocol/ws.rs index 96a97b11..5178a5ef 100644 --- a/veilid-core/src/network_manager/native/protocol/ws.rs +++ b/veilid-core/src/network_manager/native/protocol/ws.rs @@ -10,7 +10,6 @@ use async_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFr use async_tungstenite::tungstenite::Error; use async_tungstenite::{accept_hdr_async, client_async, WebSocketStream}; use futures_util::{AsyncRead, AsyncWrite, SinkExt}; -use sockets::*; // Maximum number of websocket request headers to permit const MAX_WS_HEADERS: usize = 24; @@ -316,21 +315,16 @@ impl WebsocketProtocolHandler { let domain = split_url.host.clone(); // Resolve remote address - let remote_socket_addr = dial_info.to_socket_addr(); - - // Make a shared socket - let socket = match local_address { - Some(a) => { - new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))? - } - None => new_default_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?, - }; + let remote_address = dial_info.to_socket_addr(); // Non-blocking connect to remote address - let tcp_stream = - network_result_try!(nonblocking_connect(socket, remote_socket_addr, timeout_ms) - .await - .folded()?); + let tcp_stream = network_result_try!(connect_async_tcp_stream( + local_address, + remote_address, + timeout_ms + ) + .await + .folded()?); // See what local address we ended up with let actual_local_addr = tcp_stream.local_addr()?; diff --git a/veilid-core/src/routing_table/debug.rs b/veilid-core/src/routing_table/debug.rs index 2db4e8c6..b2fe525e 100644 --- a/veilid-core/src/routing_table/debug.rs +++ b/veilid-core/src/routing_table/debug.rs @@ -333,9 +333,10 @@ impl RoutingTable { relaying_count += 1; } + let best_node_id = node.best_node_id(); + out += " "; - out += &node - .operate(|_rti, e| Self::format_entry(cur_ts, node.best_node_id(), e, &relay_tag)); + out += &node.operate(|_rti, e| Self::format_entry(cur_ts, best_node_id, e, &relay_tag)); out += "\n"; } diff --git a/veilid-server/src/client_api.rs b/veilid-server/src/client_api.rs index 47dc0886..d8097a9a 100644 --- a/veilid-server/src/client_api.rs +++ b/veilid-server/src/client_api.rs @@ -165,17 +165,12 @@ impl ClientApi { } async fn handle_tcp_incoming(self, bind_addr: SocketAddr) -> std::io::Result<()> { - let listener = TcpListener::bind(bind_addr).await?; + let listener = bind_async_tcp_listener(bind_addr)? + .ok_or(std::io::Error::from(std::io::ErrorKind::AddrInUse))?; debug!(target: "client_api", "TCPClient API listening on: {:?}", bind_addr); // Process the incoming accept stream - cfg_if! { - if #[cfg(feature="rt-async-std")] { - let mut incoming_stream = listener.incoming(); - } else { - let mut incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener); - } - } + let mut incoming_stream = async_tcp_listener_incoming(listener); // Make wait group for all incoming connections let awg = AsyncWaitGroup::new(); diff --git a/veilid-server/src/tools.rs b/veilid-server/src/tools.rs index 6d5a4778..f1d073a3 100644 --- a/veilid-server/src/tools.rs +++ b/veilid-server/src/tools.rs @@ -7,8 +7,6 @@ pub use tracing::*; cfg_if! { if #[cfg(feature="rt-async-std")] { // pub use async_std::task::JoinHandle; - pub use async_std::net::TcpListener; - pub use async_std::net::TcpStream; pub use async_std::io::BufReader; //pub use async_std::future::TimeoutError; //pub fn spawn_detached + Send + 'static, T: Send + 'static>(f: F) -> JoinHandle { @@ -27,8 +25,6 @@ cfg_if! { } } else if #[cfg(feature="rt-tokio")] { //pub use tokio::task::JoinHandle; - pub use tokio::net::TcpListener; - pub use tokio::net::TcpStream; pub use tokio::io::BufReader; //pub use tokio_util::compat::*; //pub use tokio::time::error::Elapsed as TimeoutError; diff --git a/veilid-tools/Cargo.toml b/veilid-tools/Cargo.toml index 5b9a6b00..9a91209c 100644 --- a/veilid-tools/Cargo.toml +++ b/veilid-tools/Cargo.toml @@ -76,20 +76,21 @@ flume = { version = "0.11.0", features = ["async"] } # Dependencies for native builds only # Linux, Windows, Mac, iOS, Android [target.'cfg(not(target_arch = "wasm32"))'.dependencies] +async-io = { version = "1.13.0" } async-std = { version = "1.12.0", features = ["unstable"], optional = true } -tokio = { version = "1.38.1", features = ["full"], optional = true } -tokio-util = { version = "0.7.11", features = ["compat"], optional = true } -tokio-stream = { version = "0.1.15", features = ["net"], optional = true } +chrono = "0.4.38" futures-util = { version = "0.3.30", default-features = false, features = [ "async-await", "sink", "std", "io", ] } -chrono = "0.4.38" - libc = "0.2.155" nix = { version = "0.27.1", features = ["user"] } +socket2 = { version = "0.5.7", features = ["all"] } +tokio = { version = "1.38.1", features = ["full"], optional = true } +tokio-util = { version = "0.7.11", features = ["compat"], optional = true } +tokio-stream = { version = "0.1.15", features = ["net"], optional = true } # Dependencies for WASM builds only [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/veilid-tools/src/bump_port.rs b/veilid-tools/src/bump_port.rs deleted file mode 100644 index ebdd2628..00000000 --- a/veilid-tools/src/bump_port.rs +++ /dev/null @@ -1,108 +0,0 @@ -use super::*; - -cfg_if! { - if #[cfg(target_arch = "wasm32")] { - - } else { - use std::net::{TcpListener, UdpSocket}; - } -} - -#[derive(ThisError, Debug, Clone, PartialEq, Eq)] -pub enum BumpPortError { - #[error("Unsupported architecture")] - Unsupported, - #[error("Failure: {0}")] - Failed(String), -} - -pub enum BumpPortType { - UDP, - TCP, -} - -pub fn tcp_port_available(addr: &SocketAddr) -> bool { - cfg_if! { - if #[cfg(target_arch = "wasm32")] { - true - } else { - match TcpListener::bind(addr) { - Ok(_) => true, - Err(_) => false, - } - } - } -} - -pub fn udp_port_available(addr: &SocketAddr) -> bool { - cfg_if! { - if #[cfg(target_arch = "wasm32")] { - true - } else { - match UdpSocket::bind(addr) { - Ok(_) => true, - Err(_) => false, - } - } - } -} - -pub fn bump_port(addr: &mut SocketAddr, bpt: BumpPortType) -> Result { - cfg_if! { - if #[cfg(target_arch = "wasm32")] { - Err(BumpPortError::Unsupported) - } - else - { - let mut bumped = false; - let mut port = addr.port(); - let mut addr_bump = addr.clone(); - loop { - - if match bpt { - BumpPortType::TCP => tcp_port_available(&addr_bump), - BumpPortType::UDP => udp_port_available(&addr_bump), - } { - *addr = addr_bump; - return Ok(bumped); - } - if port == u16::MAX { - break; - } - port += 1; - addr_bump.set_port(port); - bumped = true; - } - - Err(BumpPortError::Failure("no ports remaining".to_owned())) - } - } -} - -pub fn bump_port_string(addr: &mut String, bpt: BumpPortType) -> Result { - cfg_if! { - if #[cfg(target_arch = "wasm32")] { - return Err(BumpPortError::Unsupported); - } - else - { - let savec: Vec = addr - .to_socket_addrs() - .map_err(|x| BumpPortError::Failure(format!("failed to resolve socket address: {}", x)))? - .collect(); - - if savec.len() == 0 { - return Err(BumpPortError::Failure("No socket addresses resolved".to_owned())); - } - let mut sa = savec.first().unwrap().clone(); - - if !bump_port(&mut sa, bpt)? { - return Ok(false); - } - - *addr = sa.to_string(); - - Ok(true) - } - } -} diff --git a/veilid-tools/src/lib.rs b/veilid-tools/src/lib.rs index 2c3dd20f..f97b6ea9 100644 --- a/veilid-tools/src/lib.rs +++ b/veilid-tools/src/lib.rs @@ -24,7 +24,6 @@ #![allow(clippy::comparison_chain, clippy::upper_case_acronyms)] #![deny(unused_must_use)] -// pub mod bump_port; pub mod assembly_buffer; pub mod async_peek_stream; pub mod async_tag_lock; @@ -49,6 +48,8 @@ pub mod network_result; pub mod random; pub mod single_shot_eventual; pub mod sleep; +#[cfg(not(target_arch = "wasm32"))] +pub mod socket_tools; pub mod spawn; pub mod split_url; pub mod startup_lock; @@ -184,7 +185,6 @@ cfg_if! { } } -// pub use bump_port::*; #[doc(inline)] pub use assembly_buffer::*; #[doc(inline)] @@ -233,6 +233,9 @@ pub use single_shot_eventual::*; #[doc(inline)] pub use sleep::*; #[doc(inline)] +#[cfg(not(target_arch = "wasm32"))] +pub use socket_tools::*; +#[doc(inline)] pub use spawn::*; #[doc(inline)] pub use split_url::*; diff --git a/veilid-tools/src/socket_tools.rs b/veilid-tools/src/socket_tools.rs new file mode 100644 index 00000000..a2a567d5 --- /dev/null +++ b/veilid-tools/src/socket_tools.rs @@ -0,0 +1,284 @@ +use super::*; +use async_io::Async; +use std::io; + +cfg_if! { + if #[cfg(feature="rt-async-std")] { + pub use async_std::net::{TcpStream, TcpListener, UdpSocket}; + } else if #[cfg(feature="rt-tokio")] { + pub use tokio::net::{TcpStream, TcpListener, UdpSocket}; + pub use tokio_util::compat::*; + } else { + compile_error!("needs executor implementation"); + } +} + +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; + +////////////////////////////////////////////////////////////////////////////////////////// + +pub fn bind_async_udp_socket(local_address: SocketAddr) -> io::Result> { + let Some(socket) = new_bound_default_socket2_udp(local_address)? else { + return Ok(None); + }; + + // Make an async UdpSocket from the socket2 socket + let std_udp_socket: std::net::UdpSocket = socket.into(); + cfg_if! { + if #[cfg(feature="rt-async-std")] { + let udp_socket = UdpSocket::from(std_udp_socket); + } else if #[cfg(feature="rt-tokio")] { + std_udp_socket.set_nonblocking(true)?; + let udp_socket = UdpSocket::from_std(std_udp_socket)?; + } else { + compile_error!("needs executor implementation"); + } + } + Ok(Some(udp_socket)) +} + +pub fn bind_async_tcp_listener(local_address: SocketAddr) -> io::Result> { + // Create a default non-shared socket and bind it + let Some(socket) = new_bound_default_socket2_tcp(local_address)? else { + return Ok(None); + }; + + // Drop the socket so we can make another shared socket in its place + drop(socket); + + // Create a shared socket and bind it now we have determined the port is free + let Some(socket) = new_bound_shared_socket2_tcp(local_address)? else { + return Ok(None); + }; + + // Listen on the socket + if socket.listen(128).is_err() { + return Ok(None); + } + + // Make an async tcplistener from the socket2 socket + let std_listener: std::net::TcpListener = socket.into(); + cfg_if! { + if #[cfg(feature="rt-async-std")] { + let listener = TcpListener::from(std_listener); + } else if #[cfg(feature="rt-tokio")] { + std_listener.set_nonblocking(true)?; + let listener = TcpListener::from_std(std_listener)?; + } else { + compile_error!("needs executor implementation"); + } + } + Ok(Some(listener)) +} + +pub async fn connect_async_tcp_stream( + local_address: Option, + remote_address: SocketAddr, + timeout_ms: u32, +) -> io::Result> { + let socket = match local_address { + Some(a) => { + new_bound_shared_socket2_tcp(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))? + } + None => new_default_socket2_tcp(domain_for_address(remote_address))?, + }; + + // Non-blocking connect to remote address + nonblocking_connect(socket, remote_address, timeout_ms).await +} + +pub fn set_tcp_stream_linger( + tcp_stream: &TcpStream, + linger: Option, +) -> io::Result<()> { + #[cfg(all(feature = "rt-async-std", unix))] + { + // async-std does not directly support linger on TcpStream yet + use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd}; + unsafe { + let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd()); + let res = s.set_linger(linger); + s.into_raw_fd(); + res + } + } + #[cfg(all(feature = "rt-async-std", windows))] + { + // async-std does not directly support linger on TcpStream yet + use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket}; + unsafe { + let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket()); + let res = s.set_linger(linger); + s.into_raw_socket(); + res + } + } + #[cfg(not(feature = "rt-async-std"))] + tcp_stream.set_linger(linger) +} + +cfg_if! { + if #[cfg(feature="rt-async-std")] { + pub type IncomingStream = Incoming; + pub type ReadHalf = futures_util::io::ReadHalf, + pub type WriteHalf = futures_util::io::WriteHalf, + } else if #[cfg(feature="rt-tokio")] { + pub type IncomingStream = tokio_stream::wrappers::TcpListenerStream; + pub type ReadHalf = tokio::net::tcp::OwnedReadHalf; + pub type WriteHalf = tokio::net::tcp::OwnedWriteHalf; + } else { + compile_error!("needs executor implementation"); + } +} + +pub fn async_tcp_listener_incoming(tcp_listener: TcpListener) -> IncomingStream { + cfg_if! { + if #[cfg(feature="rt-async-std")] { + tcp_listener.incoming() + } else if #[cfg(feature="rt-tokio")] { + tokio_stream::wrappers::TcpListenerStream::new(tcp_listener) + } else { + compile_error!("needs executor implementation"); + } + } +} + +pub fn split_async_tcp_stream(tcp_stream: TcpStream) -> (ReadHalf, WriteHalf) { + cfg_if! { + if #[cfg(feature="rt-async-std")] { + use futures_util::AsyncReadExt; + tcp_stream.split() + } else if #[cfg(feature="rt-tokio")] { + tcp_stream.into_split() + } else { + compile_error!("needs executor implementation"); + } + } +} + +////////////////////////////////////////////////////////////////////////////////////////// + +fn new_default_udp_socket(domain: core::ffi::c_int) -> io::Result { + let domain = Domain::from(domain); + let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + if domain == Domain::IPV6 { + socket.set_only_v6(true)?; + } + + Ok(socket) +} + +fn new_bound_default_socket2_udp(local_address: SocketAddr) -> io::Result> { + let domain = domain_for_address(local_address); + let socket = new_default_udp_socket(domain)?; + let socket2_addr = SockAddr::from(local_address); + + if socket.bind(&socket2_addr).is_err() { + return Ok(None); + } + + Ok(Some(socket)) +} + +pub fn new_default_socket2_tcp(domain: core::ffi::c_int) -> io::Result { + let domain = Domain::from(domain); + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + socket.set_linger(Some(core::time::Duration::from_secs(0)))?; + socket.set_nodelay(true)?; + if domain == Domain::IPV6 { + socket.set_only_v6(true)?; + } + Ok(socket) +} + +fn new_shared_socket2_tcp(domain: core::ffi::c_int) -> io::Result { + let domain = Domain::from(domain); + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + socket.set_linger(Some(core::time::Duration::from_secs(0)))?; + socket.set_nodelay(true)?; + if domain == Domain::IPV6 { + socket.set_only_v6(true)?; + } + socket.set_reuse_address(true)?; + cfg_if! { + if #[cfg(unix)] { + socket.set_reuse_port(true)?; + } + } + + Ok(socket) +} + +fn new_bound_default_socket2_tcp(local_address: SocketAddr) -> io::Result> { + let domain = domain_for_address(local_address); + let socket = new_default_socket2_tcp(domain)?; + let socket2_addr = SockAddr::from(local_address); + if socket.bind(&socket2_addr).is_err() { + return Ok(None); + } + + Ok(Some(socket)) +} + +fn new_bound_shared_socket2_tcp(local_address: SocketAddr) -> io::Result> { + // Create the reuseaddr/reuseport socket now that we've asserted the port is free + let domain = domain_for_address(local_address); + let socket = new_shared_socket2_tcp(domain)?; + let socket2_addr = SockAddr::from(local_address); + if socket.bind(&socket2_addr).is_err() { + return Ok(None); + } + + Ok(Some(socket)) +} + +// Non-blocking connect is tricky when you want to start with a prepared socket +// Errors should not be logged as they are valid conditions for this function +async fn nonblocking_connect( + socket: Socket, + addr: SocketAddr, + timeout_ms: u32, +) -> io::Result> { + // Set for non blocking connect + socket.set_nonblocking(true)?; + + // Make socket2 SockAddr + let socket2_addr = socket2::SockAddr::from(addr); + + // Connect to the remote address + match socket.connect(&socket2_addr) { + Ok(()) => Ok(()), + #[cfg(unix)] + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()), + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()), + Err(e) => Err(e), + }?; + let async_stream = Async::new(std::net::TcpStream::from(socket))?; + + // The stream becomes writable when connected + timeout_or_try!(timeout(timeout_ms, async_stream.writable()) + .await + .into_timeout_or() + .into_result()?); + + // Check low level error + let async_stream = match async_stream.get_ref().take_error()? { + None => Ok(async_stream), + Some(err) => Err(err), + }?; + + // Convert back to inner and then return async version + cfg_if! { + if #[cfg(feature="rt-async-std")] { + Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?))) + } else if #[cfg(feature="rt-tokio")] { + Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?)) + } else { + compile_error!("needs executor implementation"); + } + } +} + +pub fn domain_for_address(address: SocketAddr) -> core::ffi::c_int { + socket2::Domain::for_address(address).into() +} diff --git a/veilid-tools/src/tests/native/test_async_peek_stream.rs b/veilid-tools/src/tests/native/test_async_peek_stream.rs index 1fde13f5..64028ccf 100644 --- a/veilid-tools/src/tests/native/test_async_peek_stream.rs +++ b/veilid-tools/src/tests/native/test_async_peek_stream.rs @@ -8,7 +8,6 @@ cfg_if! { } else if #[cfg(feature="rt-tokio")] { use tokio::net::{TcpListener, TcpStream}; use tokio::time::sleep; - use tokio_util::compat::*; } } diff --git a/veilid-tools/src/virtual_network/machine.rs b/veilid-tools/src/virtual_network/machine.rs index ea4fb251..d8e727e4 100644 --- a/veilid-tools/src/virtual_network/machine.rs +++ b/veilid-tools/src/virtual_network/machine.rs @@ -2,7 +2,7 @@ use super::*; pub type MachineId = u64; -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Machine { pub router_client: RouterClient, pub id: MachineId, diff --git a/veilid-tools/src/virtual_network/mod.rs b/veilid-tools/src/virtual_network/mod.rs index d1fe8f33..069caa9c 100644 --- a/veilid-tools/src/virtual_network/mod.rs +++ b/veilid-tools/src/virtual_network/mod.rs @@ -25,6 +25,7 @@ //! //! [VirtualTcpStream] //! [VirtualUdpSocket] +//! [VirtualTcpListener] //! [VirtualTcpListenerStream] //! [VirtualWsMeta] //! [VirtualWsStream] @@ -44,6 +45,8 @@ mod router_op_table; mod router_server; mod serde_io_error; mod virtual_network_error; +mod virtual_tcp_listener; +mod virtual_tcp_listener_stream; mod virtual_tcp_stream; mod virtual_udp_socket; @@ -53,5 +56,7 @@ pub use machine::*; pub use router_client::*; pub use router_server::*; pub use virtual_network_error::*; +pub use virtual_tcp_listener::*; +pub use virtual_tcp_listener_stream::*; pub use virtual_tcp_stream::*; pub use virtual_udp_socket::*; diff --git a/veilid-tools/src/virtual_network/router_client.rs b/veilid-tools/src/virtual_network/router_client.rs index c4b9aa0f..85a59a93 100644 --- a/veilid-tools/src/virtual_network/router_client.rs +++ b/veilid-tools/src/virtual_network/router_client.rs @@ -2,8 +2,7 @@ use super::*; use core::sync::atomic::AtomicU64; use futures_codec::{Bytes, BytesCodec, FramedRead, FramedWrite}; use futures_util::{ - io::BufReader, stream::FuturesUnordered, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt, - StreamExt, TryStreamExt, + stream::FuturesUnordered, AsyncReadExt, AsyncWriteExt, StreamExt, TryStreamExt, }; use postcard::{from_bytes, to_stdvec}; use router_op_table::*; @@ -16,11 +15,29 @@ struct RouterClientInner { stop_source: Option, } +impl fmt::Debug for RouterClientInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RouterClientInner") + .field("jh_handler", &self.jh_handler) + .field("stop_source", &self.stop_source) + .finish() + } +} + struct RouterClientUnlockedInner { - receiver: flume::Receiver, sender: flume::Sender, next_message_id: AtomicU64, - router_op_waiter: RouterOpWaiter, + router_op_waiter: RouterOpWaiter, +} + +impl fmt::Debug for RouterClientUnlockedInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RouterClientUnlockedInner") + .field("sender", &self.sender) + .field("next_message_id", &self.next_message_id) + .field("router_op_waiter", &self.router_op_waiter) + .finish() + } } pub type MessageId = u64; @@ -49,9 +66,9 @@ enum ServerProcessorRequest { }, TcpAccept { machine_id: MachineId, - socket_id: SocketId, + listen_socket_id: SocketId, }, - Close { + TcpShutdown { machine_id: MachineId, socket_id: SocketId, }, @@ -84,13 +101,22 @@ enum ServerProcessorRequest { } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -struct ServerProcessorRequestMessage { +struct ServerProcessorMessage { message_id: MessageId, request: ServerProcessorRequest, } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -enum ServerProcessorResponse { +enum ServerProcessorCommand { + Message(ServerProcessorMessage), + CloseSocket { + machine_id: MachineId, + socket_id: SocketId, + }, +} + +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +enum ServerProcessorReplyValue { AllocateMachine { machine_id: MachineId, }, @@ -100,17 +126,21 @@ enum ServerProcessorResponse { }, TcpConnect { socket_id: SocketId, + local_address: SocketAddr, }, TcpBind { socket_id: SocketId, + local_address: SocketAddr, }, TcpAccept { - child_socket_id: SocketId, + socket_id: SocketId, + address: SocketAddr, }, + TcpShutdown, UdpBind { socket_id: SocketId, + local_address: SocketAddr, }, - Close, Send { len: u32, }, @@ -127,20 +157,29 @@ enum ServerProcessorResponse { } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -enum ServerProcessorResponseStatus { - Success(ServerProcessorResponse), +enum ServerProcessorReplyResult { + Value(ServerProcessorReplyValue), InvalidMachineId, InvalidSocketId, IoError(#[serde(with = "serde_io_error::SerdeIoErrorKindDef")] io::ErrorKind), } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -struct ServerProcessorResponseMessage { +struct ServerProcessorReply { message_id: MessageId, - status: ServerProcessorResponseStatus, + status: ServerProcessorReplyResult, } -#[derive(Clone)] +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +enum ServerProcessorEvent { + Reply(ServerProcessorReply), + // DeadSocket { + // machine_id: MachineId, + // socket_id: SocketId, + // }, +} + +#[derive(Debug, Clone)] pub struct RouterClient { unlocked_inner: Arc, inner: Arc>, @@ -151,7 +190,7 @@ impl RouterClient { // Public interface #[cfg(not(target_arch = "wasm32"))] - pub async fn router_connect_tcp(host: H) -> ::std::io::Result { + pub async fn router_connect_tcp(host: H) -> io::Result { let addrs = host.to_socket_addrs()?.collect::>(); // Connect to RouterServer @@ -173,37 +212,37 @@ impl RouterClient { } } - let ts_buf_reader = BufReader::new(ts_reader); - // Create channels let (client_sender, server_receiver) = flume::unbounded::(); - let (server_sender, client_receiver) = flume::unbounded::(); // Create stopper let stop_source = StopSource::new(); + // Create router operation waiter + let router_op_waiter = RouterOpWaiter::new(); + // Spawn a client connection handler let jh_handler = spawn( "RouterClient server processor", Self::run_server_processor( - ts_buf_reader, + ts_reader, ts_writer, server_receiver, - server_sender, + router_op_waiter.clone(), stop_source.token(), ), ); Ok(Self::new( - client_receiver, client_sender, + router_op_waiter, jh_handler, stop_source, )) } #[cfg(target_arch = "wasm32")] - pub async fn router_connect_ws>(host: H) -> ::std::io::Result { + pub async fn router_connect_ws>(host: H) -> io::Result { let host = host.as_ref(); Ok(RouterClient {}) @@ -219,7 +258,7 @@ impl RouterClient { pub async fn allocate_machine(self) -> VirtualNetworkResult { let request = ServerProcessorRequest::AllocateMachine; - let ServerProcessorResponse::AllocateMachine { machine_id } = + let ServerProcessorReplyValue::AllocateMachine { machine_id } = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); @@ -229,7 +268,7 @@ impl RouterClient { pub async fn release_machine(self, machine_id: MachineId) -> VirtualNetworkResult<()> { let request = ServerProcessorRequest::ReleaseMachine { machine_id }; - let ServerProcessorResponse::ReleaseMachine = self.perform_request(request).await? else { + let ServerProcessorReplyValue::ReleaseMachine = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); }; Ok(()) @@ -240,7 +279,7 @@ impl RouterClient { machine_id: MachineId, ) -> VirtualNetworkResult> { let request = ServerProcessorRequest::GetInterfaces { machine_id }; - let ServerProcessorResponse::GetInterfaces { interfaces } = + let ServerProcessorReplyValue::GetInterfaces { interfaces } = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); @@ -252,91 +291,99 @@ impl RouterClient { self, machine_id: MachineId, remote_address: SocketAddr, - local_address: Option, + opt_local_address: Option, timeout_ms: u32, options: VirtualTcpOptions, - ) -> VirtualNetworkResult { + ) -> VirtualNetworkResult<(SocketId, SocketAddr)> { let request = ServerProcessorRequest::TcpConnect { machine_id, - local_address, + local_address: opt_local_address, remote_address, timeout_ms, options, }; - let ServerProcessorResponse::TcpConnect { socket_id } = - self.perform_request(request).await? + let ServerProcessorReplyValue::TcpConnect { + socket_id, + local_address, + } = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); }; - Ok(socket_id) + Ok((socket_id, local_address)) } pub async fn tcp_bind( self, machine_id: MachineId, - local_address: Option, + opt_local_address: Option, options: VirtualTcpOptions, - ) -> VirtualNetworkResult { + ) -> VirtualNetworkResult<(SocketId, SocketAddr)> { let request = ServerProcessorRequest::TcpBind { machine_id, - local_address, + local_address: opt_local_address, options, }; - let ServerProcessorResponse::TcpBind { socket_id } = self.perform_request(request).await? + let ServerProcessorReplyValue::TcpBind { + socket_id, + local_address, + } = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); }; - Ok(socket_id) + Ok((socket_id, local_address)) } pub async fn tcp_accept( self, machine_id: MachineId, - socket_id: SocketId, - ) -> VirtualNetworkResult { + listen_socket_id: SocketId, + ) -> VirtualNetworkResult<(SocketId, SocketAddr)> { let request = ServerProcessorRequest::TcpAccept { machine_id, - socket_id, + listen_socket_id, }; - let ServerProcessorResponse::TcpAccept { child_socket_id } = + let ServerProcessorReplyValue::TcpAccept { socket_id, address } = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); }; - Ok(child_socket_id) + Ok((socket_id, address)) + } + + pub async fn tcp_shutdown( + self, + machine_id: MachineId, + socket_id: SocketId, + ) -> VirtualNetworkResult<()> { + let request = ServerProcessorRequest::TcpShutdown { + machine_id, + socket_id, + }; + let ServerProcessorReplyValue::TcpShutdown = self.perform_request(request).await? else { + return Err(VirtualNetworkError::ResponseMismatch); + }; + Ok(()) } pub async fn udp_bind( self, machine_id: MachineId, - local_address: Option, + opt_local_address: Option, options: VirtualUdpOptions, - ) -> VirtualNetworkResult { + ) -> VirtualNetworkResult<(SocketId, SocketAddr)> { let request = ServerProcessorRequest::UdpBind { machine_id, - local_address, + local_address: opt_local_address, options, }; - let ServerProcessorResponse::UdpBind { socket_id } = self.perform_request(request).await? + let ServerProcessorReplyValue::UdpBind { + socket_id, + local_address, + } = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); }; - Ok(socket_id) - } - - pub async fn close( - self, - machine_id: MachineId, - socket_id: SocketId, - ) -> VirtualNetworkResult<()> { - let request = ServerProcessorRequest::Close { - machine_id, - socket_id, - }; - let ServerProcessorResponse::Close = self.perform_request(request).await? else { - return Err(VirtualNetworkError::ResponseMismatch); - }; - Ok(()) + Ok((socket_id, local_address)) } pub async fn send( @@ -350,7 +397,7 @@ impl RouterClient { socket_id, data, }; - let ServerProcessorResponse::Send { len } = self.perform_request(request).await? else { + let ServerProcessorReplyValue::Send { len } = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); }; Ok(len as usize) @@ -369,7 +416,7 @@ impl RouterClient { data, remote_address, }; - let ServerProcessorResponse::SendTo { len } = self.perform_request(request).await? else { + let ServerProcessorReplyValue::SendTo { len } = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); }; Ok(len as usize) @@ -379,14 +426,14 @@ impl RouterClient { self, machine_id: MachineId, socket_id: u64, - len: u32, + len: usize, ) -> VirtualNetworkResult> { let request = ServerProcessorRequest::Recv { machine_id, socket_id, - len, + len: len as u32, }; - let ServerProcessorResponse::Recv { data } = self.perform_request(request).await? else { + let ServerProcessorReplyValue::Recv { data } = self.perform_request(request).await? else { return Err(VirtualNetworkError::ResponseMismatch); }; Ok(data) @@ -396,14 +443,14 @@ impl RouterClient { self, machine_id: MachineId, socket_id: u64, - len: u32, + len: usize, ) -> VirtualNetworkResult<(Vec, SocketAddr)> { let request = ServerProcessorRequest::RecvFrom { machine_id, socket_id, - len, + len: len as u32, }; - let ServerProcessorResponse::RecvFrom { + let ServerProcessorReplyValue::RecvFrom { data, remote_address, } = self.perform_request(request).await? @@ -417,17 +464,16 @@ impl RouterClient { // Private implementation fn new( - receiver: flume::Receiver, sender: flume::Sender, + router_op_waiter: RouterOpWaiter, jh_handler: MustJoinHandle<()>, stop_source: StopSource, ) -> RouterClient { RouterClient { unlocked_inner: Arc::new(RouterClientUnlockedInner { - receiver, sender, next_message_id: AtomicU64::new(0), - router_op_waiter: RouterOpWaiter::new(), + router_op_waiter, }), inner: Arc::new(Mutex::new(RouterClientInner { jh_handler: Some(jh_handler), @@ -436,24 +482,60 @@ impl RouterClient { } } + fn report_closed_socket(&self, machine_id: MachineId, socket_id: SocketId) { + let command = ServerProcessorCommand::CloseSocket { + machine_id, + socket_id, + }; + let command_vec = match to_stdvec(&command).map_err(VirtualNetworkError::SerializationError) + { + Ok(v) => Bytes::from(v), + Err(e) => { + error!("{}", e); + return; + } + }; + + if let Err(e) = self + .unlocked_inner + .sender + .send(command_vec) + .map_err(|_| VirtualNetworkError::IoError(io::ErrorKind::BrokenPipe)) + { + error!("{}", e); + } + } + + pub(super) fn drop_tcp_stream(&self, machine_id: MachineId, socket_id: SocketId) { + self.report_closed_socket(machine_id, socket_id); + } + + pub(super) fn drop_tcp_listener(&self, machine_id: MachineId, socket_id: SocketId) { + self.report_closed_socket(machine_id, socket_id); + } + + pub(super) fn drop_udp_socket(&self, machine_id: MachineId, socket_id: SocketId) { + self.report_closed_socket(machine_id, socket_id); + } + async fn perform_request( &self, request: ServerProcessorRequest, - ) -> VirtualNetworkResult { + ) -> VirtualNetworkResult { let message_id = self .unlocked_inner .next_message_id .fetch_add(1, Ordering::AcqRel); - let msg = ServerProcessorRequestMessage { + let command = ServerProcessorCommand::Message(ServerProcessorMessage { message_id, request, - }; - let msg_vec = - Bytes::from(to_stdvec(&msg).map_err(VirtualNetworkError::SerializationError)?); + }); + let command_vec = + Bytes::from(to_stdvec(&command).map_err(VirtualNetworkError::SerializationError)?); self.unlocked_inner .sender - .send_async(msg_vec) + .send_async(command_vec) .await .map_err(|_| VirtualNetworkError::IoError(io::ErrorKind::BrokenPipe))?; let handle = self @@ -469,16 +551,16 @@ impl RouterClient { .map_err(|_| VirtualNetworkError::WaitError)?; match status { - ServerProcessorResponseStatus::Success(server_processor_response) => { + ServerProcessorReplyResult::Value(server_processor_response) => { Ok(server_processor_response) } - ServerProcessorResponseStatus::InvalidMachineId => { + ServerProcessorReplyResult::InvalidMachineId => { Err(VirtualNetworkError::InvalidMachineId) } - ServerProcessorResponseStatus::InvalidSocketId => { + ServerProcessorReplyResult::InvalidSocketId => { Err(VirtualNetworkError::InvalidSocketId) } - ServerProcessorResponseStatus::IoError(k) => Err(VirtualNetworkError::IoError(k)), + ServerProcessorReplyResult::IoError(k) => Err(VirtualNetworkError::IoError(k)), } } @@ -486,7 +568,7 @@ impl RouterClient { reader: R, writer: W, receiver: flume::Receiver, - sender: flume::Sender, + router_op_waiter: RouterOpWaiter, stop_token: StopToken, ) where R: AsyncReadExt + Unpin + Send, @@ -503,10 +585,27 @@ impl RouterClient { } }); let framed_reader_fut = system_boxed(async move { - if let Err(e) = framed_reader - .forward(sender.into_sink().sink_map_err(::std::io::Error::other)) - .await - { + let fut = framed_reader.try_for_each(|x| async { + let x = x; + let evt = from_bytes::(&x) + .map_err(VirtualNetworkError::SerializationError)?; + + match evt { + ServerProcessorEvent::Reply(reply) => { + router_op_waiter + .complete_op_waiter(reply.message_id, reply.status) + .map_err(io::Error::other)?; + } // ServerProcessorEvent::DeadSocket { + // machine_id, + // socket_id, + // } => { + // // + // } + } + + Ok(()) + }); + if let Err(e) = fut.await { error!("{}", e); } }); diff --git a/veilid-tools/src/virtual_network/router_op_table.rs b/veilid-tools/src/virtual_network/router_op_table.rs index 122b9fa8..b403700d 100644 --- a/veilid-tools/src/virtual_network/router_op_table.rs +++ b/veilid-tools/src/virtual_network/router_op_table.rs @@ -25,16 +25,6 @@ where result_receiver: Option>, } -impl RouterOpWaitHandle -where - T: Unpin, - C: Unpin + Clone, -{ - pub fn id(&self) -> RouterOpId { - self.op_id - } -} - impl Drop for RouterOpWaitHandle where T: Unpin, @@ -123,6 +113,7 @@ where } /// Get operation context + #[expect(dead_code)] pub fn get_op_context(&self, op_id: RouterOpId) -> Result> { let inner = self.inner.lock(); let Some(waiting_op) = inner.waiting_op_table.get(&op_id) else { @@ -132,14 +123,12 @@ where } /// Remove wait for op - #[instrument(level = "trace", target = "rpc", skip_all)] fn cancel_op_waiter(&self, op_id: RouterOpId) { let mut inner = self.inner.lock(); inner.waiting_op_table.remove(&op_id); } /// Complete the waiting op - #[instrument(level = "trace", target = "rpc", skip_all)] pub fn complete_op_waiter( &self, op_id: RouterOpId, @@ -159,7 +148,6 @@ where } /// Wait for operation to complete - #[instrument(level = "trace", target = "rpc", skip_all)] pub async fn wait_for_op( &self, mut handle: RouterOpWaitHandle, diff --git a/veilid-tools/src/virtual_network/virtual_tcp_listener.rs b/veilid-tools/src/virtual_network/virtual_tcp_listener.rs new file mode 100644 index 00000000..19555cd5 --- /dev/null +++ b/veilid-tools/src/virtual_network/virtual_tcp_listener.rs @@ -0,0 +1,67 @@ +use super::*; + +#[derive(Debug)] +pub struct VirtualTcpListener { + pub(super) machine: Machine, + pub(super) socket_id: SocketId, + pub(super) local_address: SocketAddr, +} + +impl VirtualTcpListener { + ///////////////////////////////////////////////////////////// + // Public Interface + + pub async fn bind( + opt_local_address: Option, + options: VirtualTcpOptions, + ) -> VirtualNetworkResult { + let machine = default_machine().unwrap(); + Self::bind_with_machine(machine, opt_local_address, options).await + } + + pub async fn bind_with_machine( + machine: Machine, + opt_local_address: Option, + options: VirtualTcpOptions, + ) -> VirtualNetworkResult { + machine + .router_client + .clone() + .tcp_bind(machine.id, opt_local_address, options) + .await + .map(|(socket_id, local_address)| Self::new(machine, socket_id, local_address)) + } + + pub async fn accept(&self) -> VirtualNetworkResult<(VirtualTcpStream, SocketAddr)> { + self.machine + .router_client + .clone() + .tcp_accept(self.machine.id, self.socket_id) + .await + .map(|v| { + ( + VirtualTcpStream::new(self.machine.clone(), v.0, self.local_address, v.1), + v.1, + ) + }) + } + + ///////////////////////////////////////////////////////////// + // Private Implementation + + fn new(machine: Machine, socket_id: SocketId, local_address: SocketAddr) -> Self { + Self { + machine, + socket_id, + local_address, + } + } +} + +impl Drop for VirtualTcpListener { + fn drop(&mut self) { + self.machine + .router_client + .drop_tcp_listener(self.machine.id, self.socket_id); + } +} diff --git a/veilid-tools/src/virtual_network/virtual_tcp_listener_stream.rs b/veilid-tools/src/virtual_network/virtual_tcp_listener_stream.rs new file mode 100644 index 00000000..31d79315 --- /dev/null +++ b/veilid-tools/src/virtual_network/virtual_tcp_listener_stream.rs @@ -0,0 +1,86 @@ +use super::*; + +use core::pin::Pin; +use core::task::{Context, Poll}; +use futures_util::{stream::Stream, FutureExt}; +use std::io; + +/// A wrapper around [`VirtualTcpListener`] that implements [`Stream`]. +/// +/// [`VirtualTcpListener`]: struct@crate::VirtualTcpListener +/// [`Stream`]: trait@futures_util::stream::Stream +pub struct VirtualTcpListenerStream { + inner: VirtualTcpListener, + current_accept_fut: Option>>, +} + +impl fmt::Debug for VirtualTcpListenerStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("VirtualTcpListenerStream") + .field("inner", &self.inner) + .field( + "current_accept_fut", + if self.current_accept_fut.is_some() { + &"Some(...)" + } else { + &"None" + }, + ) + .finish() + } +} + +impl VirtualTcpListenerStream { + /// Create a new `VirtualTcpListenerStream`. + pub fn new(listener: VirtualTcpListener) -> Self { + Self { + inner: listener, + current_accept_fut: None, + } + } + + /// Get back the inner `VirtualTcpListener`. + pub fn into_inner(self) -> VirtualTcpListener { + self.inner + } +} + +impl Stream for VirtualTcpListenerStream { + type Item = io::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.current_accept_fut.is_none() { + let machine_id = self.inner.machine.id; + let router_client = self.inner.machine.router_client.clone(); + let socket_id = self.inner.socket_id; + + self.current_accept_fut = + Some(Box::pin(router_client.tcp_accept(machine_id, socket_id))); + } + let fut = self.current_accept_fut.as_mut().unwrap(); + fut.poll_unpin(cx).map(|v| match v { + Ok(v) => Some(Ok(VirtualTcpStream::new( + self.inner.machine.clone(), + v.0, + self.inner.local_address, + v.1, + ))), + Err(e) => Some(Err(e.into())), + }) + } +} + +impl AsRef for VirtualTcpListenerStream { + fn as_ref(&self) -> &VirtualTcpListener { + &self.inner + } +} + +impl AsMut for VirtualTcpListenerStream { + fn as_mut(&mut self) -> &mut VirtualTcpListener { + &mut self.inner + } +} diff --git a/veilid-tools/src/virtual_network/virtual_tcp_stream.rs b/veilid-tools/src/virtual_network/virtual_tcp_stream.rs index f9f88bbb..6cb63dc8 100644 --- a/veilid-tools/src/virtual_network/virtual_tcp_stream.rs +++ b/veilid-tools/src/virtual_network/virtual_tcp_stream.rs @@ -13,12 +13,52 @@ pub struct VirtualTcpOptions { pub struct VirtualTcpStream { machine: Machine, socket_id: SocketId, + local_address: SocketAddr, + remote_address: SocketAddr, current_recv_fut: Option, VirtualNetworkError>>>, current_send_fut: Option>>, - current_close_fut: Option>>, + current_tcp_shutdown_fut: Option>>, +} + +impl fmt::Debug for VirtualTcpStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("VirtualTcpStream") + .field("machine", &self.machine) + .field("socket_id", &self.socket_id) + .field("local_address", &self.local_address) + .field("remote_address", &self.remote_address) + .field( + "current_recv_fut", + if self.current_recv_fut.is_some() { + &"Some(...)" + } else { + &"None" + }, + ) + .field( + "current_send_fut", + if self.current_send_fut.is_some() { + &"Some(...)" + } else { + &"None" + }, + ) + .field( + "current_close_fut", + if self.current_tcp_shutdown_fut.is_some() { + &"Some(...)" + } else { + &"None" + }, + ) + .finish() + } } impl VirtualTcpStream { + ////////////////////////////////////////////////////////////////////////// + // Public Interface + pub async fn connect( remote_address: SocketAddr, local_address: Option, @@ -48,14 +88,46 @@ impl VirtualTcpStream { options, ) .await - .map(|socket_id| Self { - machine, - socket_id, - current_recv_fut: None, - current_send_fut: None, - current_close_fut: None, + .map(|(socket_id, local_address)| { + Self::new(machine, socket_id, local_address, remote_address) }) } + + pub fn local_addr(&self) -> VirtualNetworkResult { + Ok(self.local_address) + } + + pub fn peer_addr(&self) -> VirtualNetworkResult { + Ok(self.remote_address) + } + + ////////////////////////////////////////////////////////////////////////// + // Private Implementation + + pub(super) fn new( + machine: Machine, + socket_id: SocketId, + local_address: SocketAddr, + remote_address: SocketAddr, + ) -> Self { + Self { + machine, + socket_id, + local_address, + remote_address, + current_recv_fut: None, + current_send_fut: None, + current_tcp_shutdown_fut: None, + } + } +} + +impl Drop for VirtualTcpStream { + fn drop(&mut self) { + self.machine + .router_client + .drop_tcp_stream(self.machine.id, self.socket_id); + } } impl futures_util::AsyncRead for VirtualTcpStream { @@ -68,7 +140,7 @@ impl futures_util::AsyncRead for VirtualTcpStream { self.current_recv_fut = Some(Box::pin(self.machine.router_client.clone().recv( self.machine.id, self.socket_id, - buf.len() as u32, + buf.len(), ))); } let fut = self.current_recv_fut.as_mut().unwrap(); @@ -118,18 +190,18 @@ impl futures_util::AsyncWrite for VirtualTcpStream { mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> task::Poll> { - if self.current_close_fut.is_none() { - self.current_close_fut = Some(Box::pin( + if self.current_tcp_shutdown_fut.is_none() { + self.current_tcp_shutdown_fut = Some(Box::pin( self.machine .router_client .clone() - .close(self.machine.id, self.socket_id), + .tcp_shutdown(self.machine.id, self.socket_id), )); } - let fut = self.current_close_fut.as_mut().unwrap(); + let fut = self.current_tcp_shutdown_fut.as_mut().unwrap(); fut.poll_unpin(cx).map(|v| match v { Ok(v) => { - self.current_close_fut = None; + self.current_tcp_shutdown_fut = None; Ok(v) } Err(e) => Err(e.into()), diff --git a/veilid-tools/src/virtual_network/virtual_udp_socket.rs b/veilid-tools/src/virtual_network/virtual_udp_socket.rs index ed3d3f21..926196f4 100644 --- a/veilid-tools/src/virtual_network/virtual_udp_socket.rs +++ b/veilid-tools/src/virtual_network/virtual_udp_socket.rs @@ -7,50 +7,80 @@ pub struct VirtualUdpOptions { reuse_address_port: bool, } +#[derive(Debug)] pub struct VirtualUdpSocket { machine: Machine, socket_id: SocketId, + local_address: SocketAddr, } impl VirtualUdpSocket { - // pub async fn connect( - // remote_address: SocketAddr, - // local_address: Option, - // timeout_ms: u32, - // options: VirtualTcpOptions, - // ) -> VirtualNetworkResult { - // let machine = default_machine().unwrap(); - // Self::connect_with_machine(machine, remote_address, local_address, timeout_ms, options) - // .await - // } + ///////////////////////////////////////////////////////////// + // Public Interface - // pub async fn connect_with_machine( - // machine: Machine, - // remote_address: SocketAddr, - // local_address: Option, - // timeout_ms: u32, - // options: VirtualTcpOptions, - // ) -> VirtualNetworkResult { - // machine - // .router_client - // .tcp_connect( - // machine.id, - // remote_address, - // local_address, - // timeout_ms, - // options, - // ) - // .await - // .map(|socket_id| Self { machine, socket_id }) - // } + pub async fn bind( + opt_local_address: Option, + options: VirtualUdpOptions, + ) -> VirtualNetworkResult { + let machine = default_machine().unwrap(); + Self::bind_with_machine(machine, opt_local_address, options).await + } + + pub async fn bind_with_machine( + machine: Machine, + opt_local_address: Option, + options: VirtualUdpOptions, + ) -> VirtualNetworkResult { + machine + .router_client + .clone() + .udp_bind(machine.id, opt_local_address, options) + .await + .map(|(socket_id, local_address)| Self::new(machine, socket_id, local_address)) + } + + pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> VirtualNetworkResult { + self.machine + .router_client + .clone() + .send_to(self.machine.id, self.socket_id, target, buf.to_vec()) + .await + } + + pub async fn recv_from(&self, buf: &mut [u8]) -> VirtualNetworkResult<(usize, SocketAddr)> { + let (v, addr) = self + .machine + .router_client + .clone() + .recv_from(self.machine.id, self.socket_id, buf.len()) + .await?; + + let len = usize::min(buf.len(), v.len()); + buf[0..len].copy_from_slice(&v[0..len]); + + Ok((len, addr)) + } + + pub fn local_addr(&self) -> VirtualNetworkResult { + Ok(self.local_address) + } + + ///////////////////////////////////////////////////////////// + // Private Implementation + + fn new(machine: Machine, socket_id: SocketId, local_address: SocketAddr) -> Self { + Self { + machine, + socket_id, + local_address, + } + } } -// impl futures_util::AsyncRead for VirtualUdpSocket { -// fn poll_read( -// self: Pin<&mut Self>, -// cx: &mut task::Context<'_>, -// buf: &mut [u8], -// ) -> task::Poll> { -// todo!() -// } -// } +impl Drop for VirtualUdpSocket { + fn drop(&mut self) { + self.machine + .router_client + .drop_udp_socket(self.machine.id, self.socket_id); + } +} diff --git a/veilid-tools/src/wasm.rs b/veilid-tools/src/wasm.rs index d658a609..108e4f0c 100644 --- a/veilid-tools/src/wasm.rs +++ b/veilid-tools/src/wasm.rs @@ -66,13 +66,6 @@ pub fn is_ipv6_supported() -> bool { if let Some(supp) = *opt_supp { return supp; } - // let supp = match UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 0, 0, 0)) { - // Ok(_) => true, - // Err(e) => !matches!( - // e.kind(), - // std::io::ErrorKind::AddrNotAvailable | std::io::ErrorKind::Unsupported - // ), - // }; // XXX: See issue #92 let supp = false;