refactor network into veilid-tools

fix deadlock in debug command 'entries fastest'
This commit is contained in:
Christien Rioux 2024-11-06 15:44:43 -05:00
parent ad747d7831
commit 547427271c
27 changed files with 837 additions and 619 deletions

4
Cargo.lock generated
View File

@ -6299,7 +6299,6 @@ name = "veilid-core"
version = "0.4.1" version = "0.4.1"
dependencies = [ dependencies = [
"argon2", "argon2",
"async-io 1.13.0",
"async-std", "async-std",
"async-std-resolver", "async-std-resolver",
"async-tls", "async-tls",
@ -6366,7 +6365,6 @@ dependencies = [
"sha2 0.10.8", "sha2 0.10.8",
"shell-words", "shell-words",
"simplelog", "simplelog",
"socket2 0.5.7",
"static_assertions", "static_assertions",
"stop-token", "stop-token",
"sysinfo", "sysinfo",
@ -6518,6 +6516,7 @@ name = "veilid-tools"
version = "0.4.1" version = "0.4.1"
dependencies = [ dependencies = [
"android_logger 0.13.3", "android_logger 0.13.3",
"async-io 1.13.0",
"async-lock 3.4.0", "async-lock 3.4.0",
"async-std", "async-std",
"async_executors", "async_executors",
@ -6556,6 +6555,7 @@ dependencies = [
"serde", "serde",
"serial_test 2.0.0", "serial_test 2.0.0",
"simplelog", "simplelog",
"socket2 0.5.7",
"static_assertions", "static_assertions",
"stop-token", "stop-token",
"thiserror", "thiserror",

View File

@ -223,13 +223,12 @@ impl ClientApiConnection {
trace!("ClientApiConnection::handle_tcp_connection"); trace!("ClientApiConnection::handle_tcp_connection");
// Connect the TCP socket // Connect the TCP socket
let stream = TcpStream::connect(connect_addr) let stream = connect_async_tcp_stream(None, connect_addr, 10_000)
.await .await
.map_err(map_to_string)?
.into_timeout_error()
.map_err(map_to_string)?; .map_err(map_to_string)?;
// If it succeed, disable nagle algorithm
stream.set_nodelay(true).map_err(map_to_string)?;
// State we connected // State we connected
let comproc = self.inner.lock().comproc.clone(); let comproc = self.inner.lock().comproc.clone();
comproc.set_connection_state(ConnectionState::ConnectedTCP( comproc.set_connection_state(ConnectionState::ConnectedTCP(
@ -239,16 +238,8 @@ impl ClientApiConnection {
// Split into reader and writer halves // Split into reader and writer halves
// with line buffering on the reader // with line buffering on the reader
cfg_if! { let (reader, writer) = split_async_tcp_stream(stream);
if #[cfg(feature="rt-async-std")] {
use futures::AsyncReadExt;
let (reader, writer) = stream.split();
let reader = BufReader::new(reader); let reader = BufReader::new(reader);
} else {
let (reader, writer) = stream.into_split();
let reader = BufReader::new(reader);
}
}
self.clone().run_json_api_processor(reader, writer).await self.clone().run_json_api_processor(reader, writer).await
} }

View File

@ -56,6 +56,7 @@ veilid_core_ios_tests = ["dep:tracing-oslog"]
debug-locks = ["veilid-tools/debug-locks"] debug-locks = ["veilid-tools/debug-locks"]
unstable-blockstore = [] unstable-blockstore = []
unstable-tunnels = [] unstable-tunnels = []
virtual-network = []
# GeoIP # GeoIP
geolocation = ["maxminddb", "reqwest"] geolocation = ["maxminddb", "reqwest"]
@ -164,7 +165,6 @@ sysinfo = { version = "^0.30.13", default-features = false }
tokio = { version = "1.38.1", features = ["full"], optional = true } tokio = { version = "1.38.1", features = ["full"], optional = true }
tokio-util = { version = "0.7.11", features = ["compat"], optional = true } tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
tokio-stream = { version = "0.1.15", features = ["net"], optional = true } tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
async-io = { version = "1.13.0" }
futures-util = { version = "0.3.30", default-features = false, features = [ futures-util = { version = "0.3.30", default-features = false, features = [
"async-await", "async-await",
"sink", "sink",
@ -184,7 +184,6 @@ webpki = "0.22.4"
webpki-roots = "0.25.4" webpki-roots = "0.25.4"
rustls = "0.21.12" rustls = "0.21.12"
rustls-pemfile = "1.0.4" rustls-pemfile = "1.0.4"
socket2 = { version = "0.5.7", features = ["all"] }
# Dependencies for WASM builds only # Dependencies for WASM builds only
[target.'cfg(target_arch = "wasm32")'.dependencies] [target.'cfg(target_arch = "wasm32")'.dependencies]

View File

@ -1,6 +1,5 @@
use super::*; use super::*;
use async_tls::TlsAcceptor; use async_tls::TlsAcceptor;
use sockets::*;
use stop_token::future::FutureExt; use stop_token::future::FutureExt;
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
@ -135,39 +134,12 @@ impl Network {
} }
}; };
#[cfg(all(feature = "rt-async-std", unix))] if let Err(e) = set_tcp_stream_linger(&tcp_stream, Some(core::time::Duration::from_secs(0)))
{ {
// async-std does not directly support linger on TcpStream yet
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
if let Err(e) = unsafe {
let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
s.into_raw_fd();
res
} {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
}
#[cfg(all(feature = "rt-async-std", windows))]
{
// async-std does not directly support linger on TcpStream yet
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
if let Err(e) = unsafe {
let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
s.into_raw_socket();
res
} {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
}
#[cfg(not(feature = "rt-async-std"))]
if let Err(e) = tcp_stream.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(debug "Couldn't set TCP linger: {}", e); log_net!(debug "Couldn't set TCP linger: {}", e);
return; return;
} }
if let Err(e) = tcp_stream.set_nodelay(true) { if let Err(e) = tcp_stream.set_nodelay(true) {
log_net!(debug "Couldn't set TCP nodelay: {}", e); log_net!(debug "Couldn't set TCP nodelay: {}", e);
return; return;
@ -257,41 +229,11 @@ impl Network {
) )
}; };
// Create a socket and bind it
let Some(socket) = new_bound_default_tcp_socket(addr)
.wrap_err("failed to create default socket listener")?
else {
return Ok(false);
};
// Drop the socket
drop(socket);
// Create a shared socket and bind it once we have determined the port is free // Create a shared socket and bind it once we have determined the port is free
let Some(socket) = new_bound_shared_tcp_socket(addr) let Some(listener) = bind_async_tcp_listener(addr)? else {
.wrap_err("failed to create shared socket listener")?
else {
return Ok(false); return Ok(false);
}; };
// Listen on the socket
if socket.listen(128).is_err() {
return Ok(false);
}
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let listener = TcpListener::from(std_listener);
} else if #[cfg(feature="rt-tokio")] {
std_listener.set_nonblocking(true).expect("failed to set nonblocking");
let listener = TcpListener::from_std(std_listener).wrap_err("failed to create tokio tcp listener")?;
} else {
compile_error!("needs executor implementation");
}
}
log_net!(debug "spawn_socket_listener: binding successful to {}", addr); log_net!(debug "spawn_socket_listener: binding successful to {}", addr);
// Create protocol handler records // Create protocol handler records
@ -311,15 +253,7 @@ impl Network {
// moves listener object in and get incoming iterator // moves listener object in and get incoming iterator
// when this task exists, the listener will close the socket // when this task exists, the listener will close the socket
cfg_if! { let incoming_stream = async_tcp_listener_incoming(listener);
if #[cfg(feature="rt-async-std")] {
let incoming_stream = listener.incoming();
} else if #[cfg(feature="rt-tokio")] {
let incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener);
} else {
compile_error!("needs executor implementation");
}
}
let _ = incoming_stream let _ = incoming_stream
.for_each_concurrent(None, |tcp_stream| { .for_each_concurrent(None, |tcp_stream| {

View File

@ -1,5 +1,4 @@
use super::*; use super::*;
use sockets::*;
use stop_token::future::FutureExt; use stop_token::future::FutureExt;
impl Network { impl Network {
@ -114,23 +113,10 @@ impl Network {
async fn create_udp_protocol_handler(&self, addr: SocketAddr) -> EyreResult<bool> { async fn create_udp_protocol_handler(&self, addr: SocketAddr) -> EyreResult<bool> {
log_net!(debug "create_udp_protocol_handler on {:?}", &addr); log_net!(debug "create_udp_protocol_handler on {:?}", &addr);
// Create a reusable socket // Create a single-address-family UDP socket with default options bound to an address
let Some(socket) = new_bound_default_udp_socket(addr)? else { let Some(udp_socket) = bind_async_udp_socket(addr)? else {
return Ok(false); return Ok(false);
}; };
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let udp_socket = UdpSocket::from(std_udp_socket);
} else if #[cfg(feature="rt-tokio")] {
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make inbound tokio udpsocket")?;
} else {
compile_error!("needs executor implementation");
}
}
let socket_arc = Arc::new(udp_socket); let socket_arc = Arc::new(udp_socket);
// Create protocol handler // Create protocol handler

View File

@ -1,4 +1,3 @@
pub mod sockets;
pub mod tcp; pub mod tcp;
pub mod udp; pub mod udp;
pub mod wrtc; pub mod wrtc;

View File

@ -1,191 +0,0 @@
use crate::*;
use async_io::Async;
use std::io;
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
} else if #[cfg(feature="rt-tokio")] {
pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
pub use tokio_util::compat::*;
} else {
compile_error!("needs executor implementation");
}
}
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
// cfg_if! {
// if #[cfg(windows)] {
// use winapi::shared::ws2def::{ SOL_SOCKET, SO_EXCLUSIVEADDRUSE};
// use winapi::um::winsock2::{SOCKET_ERROR, setsockopt};
// use winapi::ctypes::c_int;
// use std::os::windows::io::AsRawSocket;
// fn set_exclusiveaddruse(socket: &Socket) -> io::Result<()> {
// unsafe {
// let optval:c_int = 1;
// if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
// std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
// return Err(io::Error::last_os_error());
// }
// Ok(())
// }
// }
// }
// }
#[instrument(level = "trace", ret)]
pub fn new_shared_udp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_default_udp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_bound_default_udp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_default_udp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound default udp socket on {:?}", &local_address);
Ok(Some(socket))
}
#[instrument(level = "trace", ret)]
pub fn new_default_tcp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_shared_tcp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_bound_default_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_default_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound default tcp socket on {:?}", &local_address);
Ok(Some(socket))
}
#[instrument(level = "trace", ret)]
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_shared_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound shared tcp socket on {:?}", &local_address);
Ok(Some(socket))
}
// Non-blocking connect is tricky when you want to start with a prepared socket
// Errors should not be logged as they are valid conditions for this function
#[instrument(level = "trace", ret)]
pub async fn nonblocking_connect(
socket: Socket,
addr: SocketAddr,
timeout_ms: u32,
) -> io::Result<TimeoutOr<TcpStream>> {
// Set for non blocking connect
socket.set_nonblocking(true)?;
// Make socket2 SockAddr
let socket2_addr = socket2::SockAddr::from(addr);
// Connect to the remote address
match socket.connect(&socket2_addr) {
Ok(()) => Ok(()),
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
Err(e) => Err(e),
}?;
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
// The stream becomes writable when connected
timeout_or_try!(
timeout(timeout_ms, async_stream.writable().in_current_span())
.await
.into_timeout_or()
.into_result()?
);
// Check low level error
let async_stream = match async_stream.get_ref().take_error()? {
None => Ok(async_stream),
Some(err) => Err(err),
}?;
// Convert back to inner and then return async version
cfg_if! {
if #[cfg(feature="rt-async-std")] {
Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
} else if #[cfg(feature="rt-tokio")] {
Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
} else {
compile_error!("needs executor implementation");
}
}
}

View File

@ -1,6 +1,5 @@
use super::*; use super::*;
use futures_util::{AsyncReadExt, AsyncWriteExt}; use futures_util::{AsyncReadExt, AsyncWriteExt};
use sockets::*;
pub struct RawTcpNetworkConnection { pub struct RawTcpNetworkConnection {
flow: Flow, flow: Flow,
@ -157,32 +156,28 @@ impl RawTcpProtocolHandler {
#[instrument(level = "trace", target = "protocol", err)] #[instrument(level = "trace", target = "protocol", err)]
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
socket_addr: SocketAddr, remote_address: SocketAddr,
timeout_ms: u32, timeout_ms: u32,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> { ) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
// Make a shared socket
let socket = match local_address {
Some(a) => {
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_tcp_socket(socket2::Domain::for_address(socket_addr))?,
};
// Non-blocking connect to remote address // Non-blocking connect to remote address
let ts = network_result_try!(nonblocking_connect(socket, socket_addr, timeout_ms) let tcp_stream = network_result_try!(connect_async_tcp_stream(
local_address,
remote_address,
timeout_ms
)
.await .await
.folded()?); .folded()?);
// See what local address we ended up with and turn this into a stream // See what local address we ended up with and turn this into a stream
let actual_local_address = ts.local_addr()?; let actual_local_address = tcp_stream.local_addr()?;
#[cfg(feature = "rt-tokio")] #[cfg(feature = "rt-tokio")]
let ts = ts.compat(); let tcp_stream = tcp_stream.compat();
let ps = AsyncPeekStream::new(ts); let ps = AsyncPeekStream::new(tcp_stream);
// Wrap the stream in a network connection and return it // Wrap the stream in a network connection and return it
let flow = Flow::new( let flow = Flow::new(
PeerAddress::new( PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr), SocketAddress::from_socket_addr(remote_address),
ProtocolType::TCP, ProtocolType::TCP,
), ),
SocketAddress::from_socket_addr(actual_local_address), SocketAddress::from_socket_addr(actual_local_address),

View File

@ -1,5 +1,4 @@
use super::*; use super::*;
use sockets::*;
#[derive(Clone)] #[derive(Clone)]
pub struct RawUdpProtocolHandler { pub struct RawUdpProtocolHandler {
@ -141,7 +140,8 @@ impl RawUdpProtocolHandler {
) -> io::Result<RawUdpProtocolHandler> { ) -> io::Result<RawUdpProtocolHandler> {
// get local wildcard address for bind // get local wildcard address for bind
let local_socket_addr = compatible_unspecified_socket_addr(socket_addr); let local_socket_addr = compatible_unspecified_socket_addr(socket_addr);
let socket = UdpSocket::bind(local_socket_addr).await?; let socket = bind_async_udp_socket(local_socket_addr)?
.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?;
Ok(RawUdpProtocolHandler::new(Arc::new(socket), None)) Ok(RawUdpProtocolHandler::new(Arc::new(socket), None))
} }
} }

View File

@ -10,7 +10,6 @@ use async_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFr
use async_tungstenite::tungstenite::Error; use async_tungstenite::tungstenite::Error;
use async_tungstenite::{accept_hdr_async, client_async, WebSocketStream}; use async_tungstenite::{accept_hdr_async, client_async, WebSocketStream};
use futures_util::{AsyncRead, AsyncWrite, SinkExt}; use futures_util::{AsyncRead, AsyncWrite, SinkExt};
use sockets::*;
// Maximum number of websocket request headers to permit // Maximum number of websocket request headers to permit
const MAX_WS_HEADERS: usize = 24; const MAX_WS_HEADERS: usize = 24;
@ -316,19 +315,14 @@ impl WebsocketProtocolHandler {
let domain = split_url.host.clone(); let domain = split_url.host.clone();
// Resolve remote address // Resolve remote address
let remote_socket_addr = dial_info.to_socket_addr(); let remote_address = dial_info.to_socket_addr();
// Make a shared socket
let socket = match local_address {
Some(a) => {
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?,
};
// Non-blocking connect to remote address // Non-blocking connect to remote address
let tcp_stream = let tcp_stream = network_result_try!(connect_async_tcp_stream(
network_result_try!(nonblocking_connect(socket, remote_socket_addr, timeout_ms) local_address,
remote_address,
timeout_ms
)
.await .await
.folded()?); .folded()?);

View File

@ -333,9 +333,10 @@ impl RoutingTable {
relaying_count += 1; relaying_count += 1;
} }
let best_node_id = node.best_node_id();
out += " "; out += " ";
out += &node out += &node.operate(|_rti, e| Self::format_entry(cur_ts, best_node_id, e, &relay_tag));
.operate(|_rti, e| Self::format_entry(cur_ts, node.best_node_id(), e, &relay_tag));
out += "\n"; out += "\n";
} }

View File

@ -165,17 +165,12 @@ impl ClientApi {
} }
async fn handle_tcp_incoming(self, bind_addr: SocketAddr) -> std::io::Result<()> { async fn handle_tcp_incoming(self, bind_addr: SocketAddr) -> std::io::Result<()> {
let listener = TcpListener::bind(bind_addr).await?; let listener = bind_async_tcp_listener(bind_addr)?
.ok_or(std::io::Error::from(std::io::ErrorKind::AddrInUse))?;
debug!(target: "client_api", "TCPClient API listening on: {:?}", bind_addr); debug!(target: "client_api", "TCPClient API listening on: {:?}", bind_addr);
// Process the incoming accept stream // Process the incoming accept stream
cfg_if! { let mut incoming_stream = async_tcp_listener_incoming(listener);
if #[cfg(feature="rt-async-std")] {
let mut incoming_stream = listener.incoming();
} else {
let mut incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener);
}
}
// Make wait group for all incoming connections // Make wait group for all incoming connections
let awg = AsyncWaitGroup::new(); let awg = AsyncWaitGroup::new();

View File

@ -7,8 +7,6 @@ pub use tracing::*;
cfg_if! { cfg_if! {
if #[cfg(feature="rt-async-std")] { if #[cfg(feature="rt-async-std")] {
// pub use async_std::task::JoinHandle; // pub use async_std::task::JoinHandle;
pub use async_std::net::TcpListener;
pub use async_std::net::TcpStream;
pub use async_std::io::BufReader; pub use async_std::io::BufReader;
//pub use async_std::future::TimeoutError; //pub use async_std::future::TimeoutError;
//pub fn spawn_detached<F: Future<Output = T> + Send + 'static, T: Send + 'static>(f: F) -> JoinHandle<T> { //pub fn spawn_detached<F: Future<Output = T> + Send + 'static, T: Send + 'static>(f: F) -> JoinHandle<T> {
@ -27,8 +25,6 @@ cfg_if! {
} }
} else if #[cfg(feature="rt-tokio")] { } else if #[cfg(feature="rt-tokio")] {
//pub use tokio::task::JoinHandle; //pub use tokio::task::JoinHandle;
pub use tokio::net::TcpListener;
pub use tokio::net::TcpStream;
pub use tokio::io::BufReader; pub use tokio::io::BufReader;
//pub use tokio_util::compat::*; //pub use tokio_util::compat::*;
//pub use tokio::time::error::Elapsed as TimeoutError; //pub use tokio::time::error::Elapsed as TimeoutError;

View File

@ -76,20 +76,21 @@ flume = { version = "0.11.0", features = ["async"] }
# Dependencies for native builds only # Dependencies for native builds only
# Linux, Windows, Mac, iOS, Android # Linux, Windows, Mac, iOS, Android
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
async-io = { version = "1.13.0" }
async-std = { version = "1.12.0", features = ["unstable"], optional = true } async-std = { version = "1.12.0", features = ["unstable"], optional = true }
tokio = { version = "1.38.1", features = ["full"], optional = true } chrono = "0.4.38"
tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
futures-util = { version = "0.3.30", default-features = false, features = [ futures-util = { version = "0.3.30", default-features = false, features = [
"async-await", "async-await",
"sink", "sink",
"std", "std",
"io", "io",
] } ] }
chrono = "0.4.38"
libc = "0.2.155" libc = "0.2.155"
nix = { version = "0.27.1", features = ["user"] } nix = { version = "0.27.1", features = ["user"] }
socket2 = { version = "0.5.7", features = ["all"] }
tokio = { version = "1.38.1", features = ["full"], optional = true }
tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
# Dependencies for WASM builds only # Dependencies for WASM builds only
[target.'cfg(target_arch = "wasm32")'.dependencies] [target.'cfg(target_arch = "wasm32")'.dependencies]

View File

@ -1,108 +0,0 @@
use super::*;
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
} else {
use std::net::{TcpListener, UdpSocket};
}
}
#[derive(ThisError, Debug, Clone, PartialEq, Eq)]
pub enum BumpPortError {
#[error("Unsupported architecture")]
Unsupported,
#[error("Failure: {0}")]
Failed(String),
}
pub enum BumpPortType {
UDP,
TCP,
}
pub fn tcp_port_available(addr: &SocketAddr) -> bool {
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
true
} else {
match TcpListener::bind(addr) {
Ok(_) => true,
Err(_) => false,
}
}
}
}
pub fn udp_port_available(addr: &SocketAddr) -> bool {
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
true
} else {
match UdpSocket::bind(addr) {
Ok(_) => true,
Err(_) => false,
}
}
}
}
pub fn bump_port(addr: &mut SocketAddr, bpt: BumpPortType) -> Result<bool, BumpPortError> {
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
Err(BumpPortError::Unsupported)
}
else
{
let mut bumped = false;
let mut port = addr.port();
let mut addr_bump = addr.clone();
loop {
if match bpt {
BumpPortType::TCP => tcp_port_available(&addr_bump),
BumpPortType::UDP => udp_port_available(&addr_bump),
} {
*addr = addr_bump;
return Ok(bumped);
}
if port == u16::MAX {
break;
}
port += 1;
addr_bump.set_port(port);
bumped = true;
}
Err(BumpPortError::Failure("no ports remaining".to_owned()))
}
}
}
pub fn bump_port_string(addr: &mut String, bpt: BumpPortType) -> Result<bool, BumpPortError> {
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
return Err(BumpPortError::Unsupported);
}
else
{
let savec: Vec<SocketAddr> = addr
.to_socket_addrs()
.map_err(|x| BumpPortError::Failure(format!("failed to resolve socket address: {}", x)))?
.collect();
if savec.len() == 0 {
return Err(BumpPortError::Failure("No socket addresses resolved".to_owned()));
}
let mut sa = savec.first().unwrap().clone();
if !bump_port(&mut sa, bpt)? {
return Ok(false);
}
*addr = sa.to_string();
Ok(true)
}
}
}

View File

@ -24,7 +24,6 @@
#![allow(clippy::comparison_chain, clippy::upper_case_acronyms)] #![allow(clippy::comparison_chain, clippy::upper_case_acronyms)]
#![deny(unused_must_use)] #![deny(unused_must_use)]
// pub mod bump_port;
pub mod assembly_buffer; pub mod assembly_buffer;
pub mod async_peek_stream; pub mod async_peek_stream;
pub mod async_tag_lock; pub mod async_tag_lock;
@ -49,6 +48,8 @@ pub mod network_result;
pub mod random; pub mod random;
pub mod single_shot_eventual; pub mod single_shot_eventual;
pub mod sleep; pub mod sleep;
#[cfg(not(target_arch = "wasm32"))]
pub mod socket_tools;
pub mod spawn; pub mod spawn;
pub mod split_url; pub mod split_url;
pub mod startup_lock; pub mod startup_lock;
@ -184,7 +185,6 @@ cfg_if! {
} }
} }
// pub use bump_port::*;
#[doc(inline)] #[doc(inline)]
pub use assembly_buffer::*; pub use assembly_buffer::*;
#[doc(inline)] #[doc(inline)]
@ -233,6 +233,9 @@ pub use single_shot_eventual::*;
#[doc(inline)] #[doc(inline)]
pub use sleep::*; pub use sleep::*;
#[doc(inline)] #[doc(inline)]
#[cfg(not(target_arch = "wasm32"))]
pub use socket_tools::*;
#[doc(inline)]
pub use spawn::*; pub use spawn::*;
#[doc(inline)] #[doc(inline)]
pub use split_url::*; pub use split_url::*;

View File

@ -0,0 +1,284 @@
use super::*;
use async_io::Async;
use std::io;
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
} else if #[cfg(feature="rt-tokio")] {
pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
pub use tokio_util::compat::*;
} else {
compile_error!("needs executor implementation");
}
}
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
//////////////////////////////////////////////////////////////////////////////////////////
pub fn bind_async_udp_socket(local_address: SocketAddr) -> io::Result<Option<UdpSocket>> {
let Some(socket) = new_bound_default_socket2_udp(local_address)? else {
return Ok(None);
};
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let udp_socket = UdpSocket::from(std_udp_socket);
} else if #[cfg(feature="rt-tokio")] {
std_udp_socket.set_nonblocking(true)?;
let udp_socket = UdpSocket::from_std(std_udp_socket)?;
} else {
compile_error!("needs executor implementation");
}
}
Ok(Some(udp_socket))
}
pub fn bind_async_tcp_listener(local_address: SocketAddr) -> io::Result<Option<TcpListener>> {
// Create a default non-shared socket and bind it
let Some(socket) = new_bound_default_socket2_tcp(local_address)? else {
return Ok(None);
};
// Drop the socket so we can make another shared socket in its place
drop(socket);
// Create a shared socket and bind it now we have determined the port is free
let Some(socket) = new_bound_shared_socket2_tcp(local_address)? else {
return Ok(None);
};
// Listen on the socket
if socket.listen(128).is_err() {
return Ok(None);
}
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let listener = TcpListener::from(std_listener);
} else if #[cfg(feature="rt-tokio")] {
std_listener.set_nonblocking(true)?;
let listener = TcpListener::from_std(std_listener)?;
} else {
compile_error!("needs executor implementation");
}
}
Ok(Some(listener))
}
pub async fn connect_async_tcp_stream(
local_address: Option<SocketAddr>,
remote_address: SocketAddr,
timeout_ms: u32,
) -> io::Result<TimeoutOr<TcpStream>> {
let socket = match local_address {
Some(a) => {
new_bound_shared_socket2_tcp(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_socket2_tcp(domain_for_address(remote_address))?,
};
// Non-blocking connect to remote address
nonblocking_connect(socket, remote_address, timeout_ms).await
}
pub fn set_tcp_stream_linger(
tcp_stream: &TcpStream,
linger: Option<core::time::Duration>,
) -> io::Result<()> {
#[cfg(all(feature = "rt-async-std", unix))]
{
// async-std does not directly support linger on TcpStream yet
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
unsafe {
let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
let res = s.set_linger(linger);
s.into_raw_fd();
res
}
}
#[cfg(all(feature = "rt-async-std", windows))]
{
// async-std does not directly support linger on TcpStream yet
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
unsafe {
let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
let res = s.set_linger(linger);
s.into_raw_socket();
res
}
}
#[cfg(not(feature = "rt-async-std"))]
tcp_stream.set_linger(linger)
}
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub type IncomingStream = Incoming;
pub type ReadHalf = futures_util::io::ReadHalf<TcpStream>,
pub type WriteHalf = futures_util::io::WriteHalf<TcpStream>,
} else if #[cfg(feature="rt-tokio")] {
pub type IncomingStream = tokio_stream::wrappers::TcpListenerStream;
pub type ReadHalf = tokio::net::tcp::OwnedReadHalf;
pub type WriteHalf = tokio::net::tcp::OwnedWriteHalf;
} else {
compile_error!("needs executor implementation");
}
}
pub fn async_tcp_listener_incoming(tcp_listener: TcpListener) -> IncomingStream {
cfg_if! {
if #[cfg(feature="rt-async-std")] {
tcp_listener.incoming()
} else if #[cfg(feature="rt-tokio")] {
tokio_stream::wrappers::TcpListenerStream::new(tcp_listener)
} else {
compile_error!("needs executor implementation");
}
}
}
pub fn split_async_tcp_stream(tcp_stream: TcpStream) -> (ReadHalf, WriteHalf) {
cfg_if! {
if #[cfg(feature="rt-async-std")] {
use futures_util::AsyncReadExt;
tcp_stream.split()
} else if #[cfg(feature="rt-tokio")] {
tcp_stream.into_split()
} else {
compile_error!("needs executor implementation");
}
}
}
//////////////////////////////////////////////////////////////////////////////////////////
fn new_default_udp_socket(domain: core::ffi::c_int) -> io::Result<Socket> {
let domain = Domain::from(domain);
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
fn new_bound_default_socket2_udp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = domain_for_address(local_address);
let socket = new_default_udp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
Ok(Some(socket))
}
pub fn new_default_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
let domain = Domain::from(domain);
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
socket.set_nodelay(true)?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
fn new_shared_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
let domain = Domain::from(domain);
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
socket.set_nodelay(true)?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
fn new_bound_default_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = domain_for_address(local_address);
let socket = new_default_socket2_tcp(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
Ok(Some(socket))
}
fn new_bound_shared_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
// Create the reuseaddr/reuseport socket now that we've asserted the port is free
let domain = domain_for_address(local_address);
let socket = new_shared_socket2_tcp(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
Ok(Some(socket))
}
// Non-blocking connect is tricky when you want to start with a prepared socket
// Errors should not be logged as they are valid conditions for this function
async fn nonblocking_connect(
socket: Socket,
addr: SocketAddr,
timeout_ms: u32,
) -> io::Result<TimeoutOr<TcpStream>> {
// Set for non blocking connect
socket.set_nonblocking(true)?;
// Make socket2 SockAddr
let socket2_addr = socket2::SockAddr::from(addr);
// Connect to the remote address
match socket.connect(&socket2_addr) {
Ok(()) => Ok(()),
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
Err(e) => Err(e),
}?;
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
// The stream becomes writable when connected
timeout_or_try!(timeout(timeout_ms, async_stream.writable())
.await
.into_timeout_or()
.into_result()?);
// Check low level error
let async_stream = match async_stream.get_ref().take_error()? {
None => Ok(async_stream),
Some(err) => Err(err),
}?;
// Convert back to inner and then return async version
cfg_if! {
if #[cfg(feature="rt-async-std")] {
Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
} else if #[cfg(feature="rt-tokio")] {
Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
} else {
compile_error!("needs executor implementation");
}
}
}
pub fn domain_for_address(address: SocketAddr) -> core::ffi::c_int {
socket2::Domain::for_address(address).into()
}

View File

@ -8,7 +8,6 @@ cfg_if! {
} else if #[cfg(feature="rt-tokio")] { } else if #[cfg(feature="rt-tokio")] {
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::time::sleep; use tokio::time::sleep;
use tokio_util::compat::*;
} }
} }

View File

@ -2,7 +2,7 @@ use super::*;
pub type MachineId = u64; pub type MachineId = u64;
#[derive(Clone)] #[derive(Debug, Clone)]
pub struct Machine { pub struct Machine {
pub router_client: RouterClient, pub router_client: RouterClient,
pub id: MachineId, pub id: MachineId,

View File

@ -25,6 +25,7 @@
//! //!
//! [VirtualTcpStream] //! [VirtualTcpStream]
//! [VirtualUdpSocket] //! [VirtualUdpSocket]
//! [VirtualTcpListener]
//! [VirtualTcpListenerStream] //! [VirtualTcpListenerStream]
//! [VirtualWsMeta] //! [VirtualWsMeta]
//! [VirtualWsStream] //! [VirtualWsStream]
@ -44,6 +45,8 @@ mod router_op_table;
mod router_server; mod router_server;
mod serde_io_error; mod serde_io_error;
mod virtual_network_error; mod virtual_network_error;
mod virtual_tcp_listener;
mod virtual_tcp_listener_stream;
mod virtual_tcp_stream; mod virtual_tcp_stream;
mod virtual_udp_socket; mod virtual_udp_socket;
@ -53,5 +56,7 @@ pub use machine::*;
pub use router_client::*; pub use router_client::*;
pub use router_server::*; pub use router_server::*;
pub use virtual_network_error::*; pub use virtual_network_error::*;
pub use virtual_tcp_listener::*;
pub use virtual_tcp_listener_stream::*;
pub use virtual_tcp_stream::*; pub use virtual_tcp_stream::*;
pub use virtual_udp_socket::*; pub use virtual_udp_socket::*;

View File

@ -2,8 +2,7 @@ use super::*;
use core::sync::atomic::AtomicU64; use core::sync::atomic::AtomicU64;
use futures_codec::{Bytes, BytesCodec, FramedRead, FramedWrite}; use futures_codec::{Bytes, BytesCodec, FramedRead, FramedWrite};
use futures_util::{ use futures_util::{
io::BufReader, stream::FuturesUnordered, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt, stream::FuturesUnordered, AsyncReadExt, AsyncWriteExt, StreamExt, TryStreamExt,
StreamExt, TryStreamExt,
}; };
use postcard::{from_bytes, to_stdvec}; use postcard::{from_bytes, to_stdvec};
use router_op_table::*; use router_op_table::*;
@ -16,11 +15,29 @@ struct RouterClientInner {
stop_source: Option<StopSource>, stop_source: Option<StopSource>,
} }
impl fmt::Debug for RouterClientInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RouterClientInner")
.field("jh_handler", &self.jh_handler)
.field("stop_source", &self.stop_source)
.finish()
}
}
struct RouterClientUnlockedInner { struct RouterClientUnlockedInner {
receiver: flume::Receiver<Bytes>,
sender: flume::Sender<Bytes>, sender: flume::Sender<Bytes>,
next_message_id: AtomicU64, next_message_id: AtomicU64,
router_op_waiter: RouterOpWaiter<ServerProcessorResponseStatus, ()>, router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>,
}
impl fmt::Debug for RouterClientUnlockedInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RouterClientUnlockedInner")
.field("sender", &self.sender)
.field("next_message_id", &self.next_message_id)
.field("router_op_waiter", &self.router_op_waiter)
.finish()
}
} }
pub type MessageId = u64; pub type MessageId = u64;
@ -49,9 +66,9 @@ enum ServerProcessorRequest {
}, },
TcpAccept { TcpAccept {
machine_id: MachineId, machine_id: MachineId,
socket_id: SocketId, listen_socket_id: SocketId,
}, },
Close { TcpShutdown {
machine_id: MachineId, machine_id: MachineId,
socket_id: SocketId, socket_id: SocketId,
}, },
@ -84,13 +101,22 @@ enum ServerProcessorRequest {
} }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
struct ServerProcessorRequestMessage { struct ServerProcessorMessage {
message_id: MessageId, message_id: MessageId,
request: ServerProcessorRequest, request: ServerProcessorRequest,
} }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
enum ServerProcessorResponse { enum ServerProcessorCommand {
Message(ServerProcessorMessage),
CloseSocket {
machine_id: MachineId,
socket_id: SocketId,
},
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
enum ServerProcessorReplyValue {
AllocateMachine { AllocateMachine {
machine_id: MachineId, machine_id: MachineId,
}, },
@ -100,17 +126,21 @@ enum ServerProcessorResponse {
}, },
TcpConnect { TcpConnect {
socket_id: SocketId, socket_id: SocketId,
local_address: SocketAddr,
}, },
TcpBind { TcpBind {
socket_id: SocketId, socket_id: SocketId,
local_address: SocketAddr,
}, },
TcpAccept { TcpAccept {
child_socket_id: SocketId, socket_id: SocketId,
address: SocketAddr,
}, },
TcpShutdown,
UdpBind { UdpBind {
socket_id: SocketId, socket_id: SocketId,
local_address: SocketAddr,
}, },
Close,
Send { Send {
len: u32, len: u32,
}, },
@ -127,20 +157,29 @@ enum ServerProcessorResponse {
} }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
enum ServerProcessorResponseStatus { enum ServerProcessorReplyResult {
Success(ServerProcessorResponse), Value(ServerProcessorReplyValue),
InvalidMachineId, InvalidMachineId,
InvalidSocketId, InvalidSocketId,
IoError(#[serde(with = "serde_io_error::SerdeIoErrorKindDef")] io::ErrorKind), IoError(#[serde(with = "serde_io_error::SerdeIoErrorKindDef")] io::ErrorKind),
} }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
struct ServerProcessorResponseMessage { struct ServerProcessorReply {
message_id: MessageId, message_id: MessageId,
status: ServerProcessorResponseStatus, status: ServerProcessorReplyResult,
} }
#[derive(Clone)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
enum ServerProcessorEvent {
Reply(ServerProcessorReply),
// DeadSocket {
// machine_id: MachineId,
// socket_id: SocketId,
// },
}
#[derive(Debug, Clone)]
pub struct RouterClient { pub struct RouterClient {
unlocked_inner: Arc<RouterClientUnlockedInner>, unlocked_inner: Arc<RouterClientUnlockedInner>,
inner: Arc<Mutex<RouterClientInner>>, inner: Arc<Mutex<RouterClientInner>>,
@ -151,7 +190,7 @@ impl RouterClient {
// Public interface // Public interface
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
pub async fn router_connect_tcp<H: ToSocketAddrs>(host: H) -> ::std::io::Result<RouterClient> { pub async fn router_connect_tcp<H: ToSocketAddrs>(host: H) -> io::Result<RouterClient> {
let addrs = host.to_socket_addrs()?.collect::<Vec<_>>(); let addrs = host.to_socket_addrs()?.collect::<Vec<_>>();
// Connect to RouterServer // Connect to RouterServer
@ -173,37 +212,37 @@ impl RouterClient {
} }
} }
let ts_buf_reader = BufReader::new(ts_reader);
// Create channels // Create channels
let (client_sender, server_receiver) = flume::unbounded::<Bytes>(); let (client_sender, server_receiver) = flume::unbounded::<Bytes>();
let (server_sender, client_receiver) = flume::unbounded::<Bytes>();
// Create stopper // Create stopper
let stop_source = StopSource::new(); let stop_source = StopSource::new();
// Create router operation waiter
let router_op_waiter = RouterOpWaiter::new();
// Spawn a client connection handler // Spawn a client connection handler
let jh_handler = spawn( let jh_handler = spawn(
"RouterClient server processor", "RouterClient server processor",
Self::run_server_processor( Self::run_server_processor(
ts_buf_reader, ts_reader,
ts_writer, ts_writer,
server_receiver, server_receiver,
server_sender, router_op_waiter.clone(),
stop_source.token(), stop_source.token(),
), ),
); );
Ok(Self::new( Ok(Self::new(
client_receiver,
client_sender, client_sender,
router_op_waiter,
jh_handler, jh_handler,
stop_source, stop_source,
)) ))
} }
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
pub async fn router_connect_ws<H: AsRef<str>>(host: H) -> ::std::io::Result<RouterClient> { pub async fn router_connect_ws<H: AsRef<str>>(host: H) -> io::Result<RouterClient> {
let host = host.as_ref(); let host = host.as_ref();
Ok(RouterClient {}) Ok(RouterClient {})
@ -219,7 +258,7 @@ impl RouterClient {
pub async fn allocate_machine(self) -> VirtualNetworkResult<MachineId> { pub async fn allocate_machine(self) -> VirtualNetworkResult<MachineId> {
let request = ServerProcessorRequest::AllocateMachine; let request = ServerProcessorRequest::AllocateMachine;
let ServerProcessorResponse::AllocateMachine { machine_id } = let ServerProcessorReplyValue::AllocateMachine { machine_id } =
self.perform_request(request).await? self.perform_request(request).await?
else { else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
@ -229,7 +268,7 @@ impl RouterClient {
pub async fn release_machine(self, machine_id: MachineId) -> VirtualNetworkResult<()> { pub async fn release_machine(self, machine_id: MachineId) -> VirtualNetworkResult<()> {
let request = ServerProcessorRequest::ReleaseMachine { machine_id }; let request = ServerProcessorRequest::ReleaseMachine { machine_id };
let ServerProcessorResponse::ReleaseMachine = self.perform_request(request).await? else { let ServerProcessorReplyValue::ReleaseMachine = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
}; };
Ok(()) Ok(())
@ -240,7 +279,7 @@ impl RouterClient {
machine_id: MachineId, machine_id: MachineId,
) -> VirtualNetworkResult<BTreeMap<String, NetworkInterface>> { ) -> VirtualNetworkResult<BTreeMap<String, NetworkInterface>> {
let request = ServerProcessorRequest::GetInterfaces { machine_id }; let request = ServerProcessorRequest::GetInterfaces { machine_id };
let ServerProcessorResponse::GetInterfaces { interfaces } = let ServerProcessorReplyValue::GetInterfaces { interfaces } =
self.perform_request(request).await? self.perform_request(request).await?
else { else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
@ -252,91 +291,99 @@ impl RouterClient {
self, self,
machine_id: MachineId, machine_id: MachineId,
remote_address: SocketAddr, remote_address: SocketAddr,
local_address: Option<SocketAddr>, opt_local_address: Option<SocketAddr>,
timeout_ms: u32, timeout_ms: u32,
options: VirtualTcpOptions, options: VirtualTcpOptions,
) -> VirtualNetworkResult<SocketId> { ) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
let request = ServerProcessorRequest::TcpConnect { let request = ServerProcessorRequest::TcpConnect {
machine_id, machine_id,
local_address, local_address: opt_local_address,
remote_address, remote_address,
timeout_ms, timeout_ms,
options, options,
}; };
let ServerProcessorResponse::TcpConnect { socket_id } = let ServerProcessorReplyValue::TcpConnect {
self.perform_request(request).await? socket_id,
local_address,
} = self.perform_request(request).await?
else { else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
}; };
Ok(socket_id) Ok((socket_id, local_address))
} }
pub async fn tcp_bind( pub async fn tcp_bind(
self, self,
machine_id: MachineId, machine_id: MachineId,
local_address: Option<SocketAddr>, opt_local_address: Option<SocketAddr>,
options: VirtualTcpOptions, options: VirtualTcpOptions,
) -> VirtualNetworkResult<SocketId> { ) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
let request = ServerProcessorRequest::TcpBind { let request = ServerProcessorRequest::TcpBind {
machine_id, machine_id,
local_address, local_address: opt_local_address,
options, options,
}; };
let ServerProcessorResponse::TcpBind { socket_id } = self.perform_request(request).await? let ServerProcessorReplyValue::TcpBind {
socket_id,
local_address,
} = self.perform_request(request).await?
else { else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
}; };
Ok(socket_id) Ok((socket_id, local_address))
} }
pub async fn tcp_accept( pub async fn tcp_accept(
self, self,
machine_id: MachineId, machine_id: MachineId,
socket_id: SocketId, listen_socket_id: SocketId,
) -> VirtualNetworkResult<SocketId> { ) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
let request = ServerProcessorRequest::TcpAccept { let request = ServerProcessorRequest::TcpAccept {
machine_id, machine_id,
socket_id, listen_socket_id,
}; };
let ServerProcessorResponse::TcpAccept { child_socket_id } = let ServerProcessorReplyValue::TcpAccept { socket_id, address } =
self.perform_request(request).await? self.perform_request(request).await?
else { else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
}; };
Ok(child_socket_id) Ok((socket_id, address))
}
pub async fn tcp_shutdown(
self,
machine_id: MachineId,
socket_id: SocketId,
) -> VirtualNetworkResult<()> {
let request = ServerProcessorRequest::TcpShutdown {
machine_id,
socket_id,
};
let ServerProcessorReplyValue::TcpShutdown = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(())
} }
pub async fn udp_bind( pub async fn udp_bind(
self, self,
machine_id: MachineId, machine_id: MachineId,
local_address: Option<SocketAddr>, opt_local_address: Option<SocketAddr>,
options: VirtualUdpOptions, options: VirtualUdpOptions,
) -> VirtualNetworkResult<SocketId> { ) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
let request = ServerProcessorRequest::UdpBind { let request = ServerProcessorRequest::UdpBind {
machine_id, machine_id,
local_address, local_address: opt_local_address,
options, options,
}; };
let ServerProcessorResponse::UdpBind { socket_id } = self.perform_request(request).await? let ServerProcessorReplyValue::UdpBind {
socket_id,
local_address,
} = self.perform_request(request).await?
else { else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
}; };
Ok(socket_id) Ok((socket_id, local_address))
}
pub async fn close(
self,
machine_id: MachineId,
socket_id: SocketId,
) -> VirtualNetworkResult<()> {
let request = ServerProcessorRequest::Close {
machine_id,
socket_id,
};
let ServerProcessorResponse::Close = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(())
} }
pub async fn send( pub async fn send(
@ -350,7 +397,7 @@ impl RouterClient {
socket_id, socket_id,
data, data,
}; };
let ServerProcessorResponse::Send { len } = self.perform_request(request).await? else { let ServerProcessorReplyValue::Send { len } = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
}; };
Ok(len as usize) Ok(len as usize)
@ -369,7 +416,7 @@ impl RouterClient {
data, data,
remote_address, remote_address,
}; };
let ServerProcessorResponse::SendTo { len } = self.perform_request(request).await? else { let ServerProcessorReplyValue::SendTo { len } = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
}; };
Ok(len as usize) Ok(len as usize)
@ -379,14 +426,14 @@ impl RouterClient {
self, self,
machine_id: MachineId, machine_id: MachineId,
socket_id: u64, socket_id: u64,
len: u32, len: usize,
) -> VirtualNetworkResult<Vec<u8>> { ) -> VirtualNetworkResult<Vec<u8>> {
let request = ServerProcessorRequest::Recv { let request = ServerProcessorRequest::Recv {
machine_id, machine_id,
socket_id, socket_id,
len, len: len as u32,
}; };
let ServerProcessorResponse::Recv { data } = self.perform_request(request).await? else { let ServerProcessorReplyValue::Recv { data } = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch); return Err(VirtualNetworkError::ResponseMismatch);
}; };
Ok(data) Ok(data)
@ -396,14 +443,14 @@ impl RouterClient {
self, self,
machine_id: MachineId, machine_id: MachineId,
socket_id: u64, socket_id: u64,
len: u32, len: usize,
) -> VirtualNetworkResult<(Vec<u8>, SocketAddr)> { ) -> VirtualNetworkResult<(Vec<u8>, SocketAddr)> {
let request = ServerProcessorRequest::RecvFrom { let request = ServerProcessorRequest::RecvFrom {
machine_id, machine_id,
socket_id, socket_id,
len, len: len as u32,
}; };
let ServerProcessorResponse::RecvFrom { let ServerProcessorReplyValue::RecvFrom {
data, data,
remote_address, remote_address,
} = self.perform_request(request).await? } = self.perform_request(request).await?
@ -417,17 +464,16 @@ impl RouterClient {
// Private implementation // Private implementation
fn new( fn new(
receiver: flume::Receiver<Bytes>,
sender: flume::Sender<Bytes>, sender: flume::Sender<Bytes>,
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>,
jh_handler: MustJoinHandle<()>, jh_handler: MustJoinHandle<()>,
stop_source: StopSource, stop_source: StopSource,
) -> RouterClient { ) -> RouterClient {
RouterClient { RouterClient {
unlocked_inner: Arc::new(RouterClientUnlockedInner { unlocked_inner: Arc::new(RouterClientUnlockedInner {
receiver,
sender, sender,
next_message_id: AtomicU64::new(0), next_message_id: AtomicU64::new(0),
router_op_waiter: RouterOpWaiter::new(), router_op_waiter,
}), }),
inner: Arc::new(Mutex::new(RouterClientInner { inner: Arc::new(Mutex::new(RouterClientInner {
jh_handler: Some(jh_handler), jh_handler: Some(jh_handler),
@ -436,24 +482,60 @@ impl RouterClient {
} }
} }
fn report_closed_socket(&self, machine_id: MachineId, socket_id: SocketId) {
let command = ServerProcessorCommand::CloseSocket {
machine_id,
socket_id,
};
let command_vec = match to_stdvec(&command).map_err(VirtualNetworkError::SerializationError)
{
Ok(v) => Bytes::from(v),
Err(e) => {
error!("{}", e);
return;
}
};
if let Err(e) = self
.unlocked_inner
.sender
.send(command_vec)
.map_err(|_| VirtualNetworkError::IoError(io::ErrorKind::BrokenPipe))
{
error!("{}", e);
}
}
pub(super) fn drop_tcp_stream(&self, machine_id: MachineId, socket_id: SocketId) {
self.report_closed_socket(machine_id, socket_id);
}
pub(super) fn drop_tcp_listener(&self, machine_id: MachineId, socket_id: SocketId) {
self.report_closed_socket(machine_id, socket_id);
}
pub(super) fn drop_udp_socket(&self, machine_id: MachineId, socket_id: SocketId) {
self.report_closed_socket(machine_id, socket_id);
}
async fn perform_request( async fn perform_request(
&self, &self,
request: ServerProcessorRequest, request: ServerProcessorRequest,
) -> VirtualNetworkResult<ServerProcessorResponse> { ) -> VirtualNetworkResult<ServerProcessorReplyValue> {
let message_id = self let message_id = self
.unlocked_inner .unlocked_inner
.next_message_id .next_message_id
.fetch_add(1, Ordering::AcqRel); .fetch_add(1, Ordering::AcqRel);
let msg = ServerProcessorRequestMessage { let command = ServerProcessorCommand::Message(ServerProcessorMessage {
message_id, message_id,
request, request,
}; });
let msg_vec = let command_vec =
Bytes::from(to_stdvec(&msg).map_err(VirtualNetworkError::SerializationError)?); Bytes::from(to_stdvec(&command).map_err(VirtualNetworkError::SerializationError)?);
self.unlocked_inner self.unlocked_inner
.sender .sender
.send_async(msg_vec) .send_async(command_vec)
.await .await
.map_err(|_| VirtualNetworkError::IoError(io::ErrorKind::BrokenPipe))?; .map_err(|_| VirtualNetworkError::IoError(io::ErrorKind::BrokenPipe))?;
let handle = self let handle = self
@ -469,16 +551,16 @@ impl RouterClient {
.map_err(|_| VirtualNetworkError::WaitError)?; .map_err(|_| VirtualNetworkError::WaitError)?;
match status { match status {
ServerProcessorResponseStatus::Success(server_processor_response) => { ServerProcessorReplyResult::Value(server_processor_response) => {
Ok(server_processor_response) Ok(server_processor_response)
} }
ServerProcessorResponseStatus::InvalidMachineId => { ServerProcessorReplyResult::InvalidMachineId => {
Err(VirtualNetworkError::InvalidMachineId) Err(VirtualNetworkError::InvalidMachineId)
} }
ServerProcessorResponseStatus::InvalidSocketId => { ServerProcessorReplyResult::InvalidSocketId => {
Err(VirtualNetworkError::InvalidSocketId) Err(VirtualNetworkError::InvalidSocketId)
} }
ServerProcessorResponseStatus::IoError(k) => Err(VirtualNetworkError::IoError(k)), ServerProcessorReplyResult::IoError(k) => Err(VirtualNetworkError::IoError(k)),
} }
} }
@ -486,7 +568,7 @@ impl RouterClient {
reader: R, reader: R,
writer: W, writer: W,
receiver: flume::Receiver<Bytes>, receiver: flume::Receiver<Bytes>,
sender: flume::Sender<Bytes>, router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>,
stop_token: StopToken, stop_token: StopToken,
) where ) where
R: AsyncReadExt + Unpin + Send, R: AsyncReadExt + Unpin + Send,
@ -503,10 +585,27 @@ impl RouterClient {
} }
}); });
let framed_reader_fut = system_boxed(async move { let framed_reader_fut = system_boxed(async move {
if let Err(e) = framed_reader let fut = framed_reader.try_for_each(|x| async {
.forward(sender.into_sink().sink_map_err(::std::io::Error::other)) let x = x;
.await let evt = from_bytes::<ServerProcessorEvent>(&x)
{ .map_err(VirtualNetworkError::SerializationError)?;
match evt {
ServerProcessorEvent::Reply(reply) => {
router_op_waiter
.complete_op_waiter(reply.message_id, reply.status)
.map_err(io::Error::other)?;
} // ServerProcessorEvent::DeadSocket {
// machine_id,
// socket_id,
// } => {
// //
// }
}
Ok(())
});
if let Err(e) = fut.await {
error!("{}", e); error!("{}", e);
} }
}); });

View File

@ -25,16 +25,6 @@ where
result_receiver: Option<flume::Receiver<T>>, result_receiver: Option<flume::Receiver<T>>,
} }
impl<T, C> RouterOpWaitHandle<T, C>
where
T: Unpin,
C: Unpin + Clone,
{
pub fn id(&self) -> RouterOpId {
self.op_id
}
}
impl<T, C> Drop for RouterOpWaitHandle<T, C> impl<T, C> Drop for RouterOpWaitHandle<T, C>
where where
T: Unpin, T: Unpin,
@ -123,6 +113,7 @@ where
} }
/// Get operation context /// Get operation context
#[expect(dead_code)]
pub fn get_op_context(&self, op_id: RouterOpId) -> Result<C, RouterOpWaitError<T>> { pub fn get_op_context(&self, op_id: RouterOpId) -> Result<C, RouterOpWaitError<T>> {
let inner = self.inner.lock(); let inner = self.inner.lock();
let Some(waiting_op) = inner.waiting_op_table.get(&op_id) else { let Some(waiting_op) = inner.waiting_op_table.get(&op_id) else {
@ -132,14 +123,12 @@ where
} }
/// Remove wait for op /// Remove wait for op
#[instrument(level = "trace", target = "rpc", skip_all)]
fn cancel_op_waiter(&self, op_id: RouterOpId) { fn cancel_op_waiter(&self, op_id: RouterOpId) {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.waiting_op_table.remove(&op_id); inner.waiting_op_table.remove(&op_id);
} }
/// Complete the waiting op /// Complete the waiting op
#[instrument(level = "trace", target = "rpc", skip_all)]
pub fn complete_op_waiter( pub fn complete_op_waiter(
&self, &self,
op_id: RouterOpId, op_id: RouterOpId,
@ -159,7 +148,6 @@ where
} }
/// Wait for operation to complete /// Wait for operation to complete
#[instrument(level = "trace", target = "rpc", skip_all)]
pub async fn wait_for_op( pub async fn wait_for_op(
&self, &self,
mut handle: RouterOpWaitHandle<T, C>, mut handle: RouterOpWaitHandle<T, C>,

View File

@ -0,0 +1,67 @@
use super::*;
#[derive(Debug)]
pub struct VirtualTcpListener {
pub(super) machine: Machine,
pub(super) socket_id: SocketId,
pub(super) local_address: SocketAddr,
}
impl VirtualTcpListener {
/////////////////////////////////////////////////////////////
// Public Interface
pub async fn bind(
opt_local_address: Option<SocketAddr>,
options: VirtualTcpOptions,
) -> VirtualNetworkResult<Self> {
let machine = default_machine().unwrap();
Self::bind_with_machine(machine, opt_local_address, options).await
}
pub async fn bind_with_machine(
machine: Machine,
opt_local_address: Option<SocketAddr>,
options: VirtualTcpOptions,
) -> VirtualNetworkResult<Self> {
machine
.router_client
.clone()
.tcp_bind(machine.id, opt_local_address, options)
.await
.map(|(socket_id, local_address)| Self::new(machine, socket_id, local_address))
}
pub async fn accept(&self) -> VirtualNetworkResult<(VirtualTcpStream, SocketAddr)> {
self.machine
.router_client
.clone()
.tcp_accept(self.machine.id, self.socket_id)
.await
.map(|v| {
(
VirtualTcpStream::new(self.machine.clone(), v.0, self.local_address, v.1),
v.1,
)
})
}
/////////////////////////////////////////////////////////////
// Private Implementation
fn new(machine: Machine, socket_id: SocketId, local_address: SocketAddr) -> Self {
Self {
machine,
socket_id,
local_address,
}
}
}
impl Drop for VirtualTcpListener {
fn drop(&mut self) {
self.machine
.router_client
.drop_tcp_listener(self.machine.id, self.socket_id);
}
}

View File

@ -0,0 +1,86 @@
use super::*;
use core::pin::Pin;
use core::task::{Context, Poll};
use futures_util::{stream::Stream, FutureExt};
use std::io;
/// A wrapper around [`VirtualTcpListener`] that implements [`Stream`].
///
/// [`VirtualTcpListener`]: struct@crate::VirtualTcpListener
/// [`Stream`]: trait@futures_util::stream::Stream
pub struct VirtualTcpListenerStream {
inner: VirtualTcpListener,
current_accept_fut: Option<SendPinBoxFuture<VirtualNetworkResult<(SocketId, SocketAddr)>>>,
}
impl fmt::Debug for VirtualTcpListenerStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("VirtualTcpListenerStream")
.field("inner", &self.inner)
.field(
"current_accept_fut",
if self.current_accept_fut.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.finish()
}
}
impl VirtualTcpListenerStream {
/// Create a new `VirtualTcpListenerStream`.
pub fn new(listener: VirtualTcpListener) -> Self {
Self {
inner: listener,
current_accept_fut: None,
}
}
/// Get back the inner `VirtualTcpListener`.
pub fn into_inner(self) -> VirtualTcpListener {
self.inner
}
}
impl Stream for VirtualTcpListenerStream {
type Item = io::Result<VirtualTcpStream>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<io::Result<VirtualTcpStream>>> {
if self.current_accept_fut.is_none() {
let machine_id = self.inner.machine.id;
let router_client = self.inner.machine.router_client.clone();
let socket_id = self.inner.socket_id;
self.current_accept_fut =
Some(Box::pin(router_client.tcp_accept(machine_id, socket_id)));
}
let fut = self.current_accept_fut.as_mut().unwrap();
fut.poll_unpin(cx).map(|v| match v {
Ok(v) => Some(Ok(VirtualTcpStream::new(
self.inner.machine.clone(),
v.0,
self.inner.local_address,
v.1,
))),
Err(e) => Some(Err(e.into())),
})
}
}
impl AsRef<VirtualTcpListener> for VirtualTcpListenerStream {
fn as_ref(&self) -> &VirtualTcpListener {
&self.inner
}
}
impl AsMut<VirtualTcpListener> for VirtualTcpListenerStream {
fn as_mut(&mut self) -> &mut VirtualTcpListener {
&mut self.inner
}
}

View File

@ -13,12 +13,52 @@ pub struct VirtualTcpOptions {
pub struct VirtualTcpStream { pub struct VirtualTcpStream {
machine: Machine, machine: Machine,
socket_id: SocketId, socket_id: SocketId,
local_address: SocketAddr,
remote_address: SocketAddr,
current_recv_fut: Option<SendPinBoxFuture<Result<Vec<u8>, VirtualNetworkError>>>, current_recv_fut: Option<SendPinBoxFuture<Result<Vec<u8>, VirtualNetworkError>>>,
current_send_fut: Option<SendPinBoxFuture<Result<usize, VirtualNetworkError>>>, current_send_fut: Option<SendPinBoxFuture<Result<usize, VirtualNetworkError>>>,
current_close_fut: Option<SendPinBoxFuture<Result<(), VirtualNetworkError>>>, current_tcp_shutdown_fut: Option<SendPinBoxFuture<Result<(), VirtualNetworkError>>>,
}
impl fmt::Debug for VirtualTcpStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("VirtualTcpStream")
.field("machine", &self.machine)
.field("socket_id", &self.socket_id)
.field("local_address", &self.local_address)
.field("remote_address", &self.remote_address)
.field(
"current_recv_fut",
if self.current_recv_fut.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.field(
"current_send_fut",
if self.current_send_fut.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.field(
"current_close_fut",
if self.current_tcp_shutdown_fut.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.finish()
}
} }
impl VirtualTcpStream { impl VirtualTcpStream {
//////////////////////////////////////////////////////////////////////////
// Public Interface
pub async fn connect( pub async fn connect(
remote_address: SocketAddr, remote_address: SocketAddr,
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
@ -48,13 +88,45 @@ impl VirtualTcpStream {
options, options,
) )
.await .await
.map(|socket_id| Self { .map(|(socket_id, local_address)| {
Self::new(machine, socket_id, local_address, remote_address)
})
}
pub fn local_addr(&self) -> VirtualNetworkResult<SocketAddr> {
Ok(self.local_address)
}
pub fn peer_addr(&self) -> VirtualNetworkResult<SocketAddr> {
Ok(self.remote_address)
}
//////////////////////////////////////////////////////////////////////////
// Private Implementation
pub(super) fn new(
machine: Machine,
socket_id: SocketId,
local_address: SocketAddr,
remote_address: SocketAddr,
) -> Self {
Self {
machine, machine,
socket_id, socket_id,
local_address,
remote_address,
current_recv_fut: None, current_recv_fut: None,
current_send_fut: None, current_send_fut: None,
current_close_fut: None, current_tcp_shutdown_fut: None,
}) }
}
}
impl Drop for VirtualTcpStream {
fn drop(&mut self) {
self.machine
.router_client
.drop_tcp_stream(self.machine.id, self.socket_id);
} }
} }
@ -68,7 +140,7 @@ impl futures_util::AsyncRead for VirtualTcpStream {
self.current_recv_fut = Some(Box::pin(self.machine.router_client.clone().recv( self.current_recv_fut = Some(Box::pin(self.machine.router_client.clone().recv(
self.machine.id, self.machine.id,
self.socket_id, self.socket_id,
buf.len() as u32, buf.len(),
))); )));
} }
let fut = self.current_recv_fut.as_mut().unwrap(); let fut = self.current_recv_fut.as_mut().unwrap();
@ -118,18 +190,18 @@ impl futures_util::AsyncWrite for VirtualTcpStream {
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>, cx: &mut task::Context<'_>,
) -> task::Poll<std::io::Result<()>> { ) -> task::Poll<std::io::Result<()>> {
if self.current_close_fut.is_none() { if self.current_tcp_shutdown_fut.is_none() {
self.current_close_fut = Some(Box::pin( self.current_tcp_shutdown_fut = Some(Box::pin(
self.machine self.machine
.router_client .router_client
.clone() .clone()
.close(self.machine.id, self.socket_id), .tcp_shutdown(self.machine.id, self.socket_id),
)); ));
} }
let fut = self.current_close_fut.as_mut().unwrap(); let fut = self.current_tcp_shutdown_fut.as_mut().unwrap();
fut.poll_unpin(cx).map(|v| match v { fut.poll_unpin(cx).map(|v| match v {
Ok(v) => { Ok(v) => {
self.current_close_fut = None; self.current_tcp_shutdown_fut = None;
Ok(v) Ok(v)
} }
Err(e) => Err(e.into()), Err(e) => Err(e.into()),

View File

@ -7,50 +7,80 @@ pub struct VirtualUdpOptions {
reuse_address_port: bool, reuse_address_port: bool,
} }
#[derive(Debug)]
pub struct VirtualUdpSocket { pub struct VirtualUdpSocket {
machine: Machine, machine: Machine,
socket_id: SocketId, socket_id: SocketId,
local_address: SocketAddr,
} }
impl VirtualUdpSocket { impl VirtualUdpSocket {
// pub async fn connect( /////////////////////////////////////////////////////////////
// remote_address: SocketAddr, // Public Interface
// local_address: Option<SocketAddr>,
// timeout_ms: u32,
// options: VirtualTcpOptions,
// ) -> VirtualNetworkResult<Self> {
// let machine = default_machine().unwrap();
// Self::connect_with_machine(machine, remote_address, local_address, timeout_ms, options)
// .await
// }
// pub async fn connect_with_machine( pub async fn bind(
// machine: Machine, opt_local_address: Option<SocketAddr>,
// remote_address: SocketAddr, options: VirtualUdpOptions,
// local_address: Option<SocketAddr>, ) -> VirtualNetworkResult<Self> {
// timeout_ms: u32, let machine = default_machine().unwrap();
// options: VirtualTcpOptions, Self::bind_with_machine(machine, opt_local_address, options).await
// ) -> VirtualNetworkResult<Self> {
// machine
// .router_client
// .tcp_connect(
// machine.id,
// remote_address,
// local_address,
// timeout_ms,
// options,
// )
// .await
// .map(|socket_id| Self { machine, socket_id })
// }
} }
// impl futures_util::AsyncRead for VirtualUdpSocket { pub async fn bind_with_machine(
// fn poll_read( machine: Machine,
// self: Pin<&mut Self>, opt_local_address: Option<SocketAddr>,
// cx: &mut task::Context<'_>, options: VirtualUdpOptions,
// buf: &mut [u8], ) -> VirtualNetworkResult<Self> {
// ) -> task::Poll<std::io::Result<usize>> { machine
// todo!() .router_client
// } .clone()
// } .udp_bind(machine.id, opt_local_address, options)
.await
.map(|(socket_id, local_address)| Self::new(machine, socket_id, local_address))
}
pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> VirtualNetworkResult<usize> {
self.machine
.router_client
.clone()
.send_to(self.machine.id, self.socket_id, target, buf.to_vec())
.await
}
pub async fn recv_from(&self, buf: &mut [u8]) -> VirtualNetworkResult<(usize, SocketAddr)> {
let (v, addr) = self
.machine
.router_client
.clone()
.recv_from(self.machine.id, self.socket_id, buf.len())
.await?;
let len = usize::min(buf.len(), v.len());
buf[0..len].copy_from_slice(&v[0..len]);
Ok((len, addr))
}
pub fn local_addr(&self) -> VirtualNetworkResult<SocketAddr> {
Ok(self.local_address)
}
/////////////////////////////////////////////////////////////
// Private Implementation
fn new(machine: Machine, socket_id: SocketId, local_address: SocketAddr) -> Self {
Self {
machine,
socket_id,
local_address,
}
}
}
impl Drop for VirtualUdpSocket {
fn drop(&mut self) {
self.machine
.router_client
.drop_udp_socket(self.machine.id, self.socket_id);
}
}

View File

@ -66,13 +66,6 @@ pub fn is_ipv6_supported() -> bool {
if let Some(supp) = *opt_supp { if let Some(supp) = *opt_supp {
return supp; return supp;
} }
// let supp = match UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 0, 0, 0)) {
// Ok(_) => true,
// Err(e) => !matches!(
// e.kind(),
// std::io::ErrorKind::AddrNotAvailable | std::io::ErrorKind::Unsupported
// ),
// };
// XXX: See issue #92 // XXX: See issue #92
let supp = false; let supp = false;