refactor network into veilid-tools

fix deadlock in debug command 'entries fastest'
This commit is contained in:
Christien Rioux 2024-11-06 15:44:43 -05:00
parent ad747d7831
commit 547427271c
27 changed files with 837 additions and 619 deletions

4
Cargo.lock generated
View File

@ -6299,7 +6299,6 @@ name = "veilid-core"
version = "0.4.1"
dependencies = [
"argon2",
"async-io 1.13.0",
"async-std",
"async-std-resolver",
"async-tls",
@ -6366,7 +6365,6 @@ dependencies = [
"sha2 0.10.8",
"shell-words",
"simplelog",
"socket2 0.5.7",
"static_assertions",
"stop-token",
"sysinfo",
@ -6518,6 +6516,7 @@ name = "veilid-tools"
version = "0.4.1"
dependencies = [
"android_logger 0.13.3",
"async-io 1.13.0",
"async-lock 3.4.0",
"async-std",
"async_executors",
@ -6556,6 +6555,7 @@ dependencies = [
"serde",
"serial_test 2.0.0",
"simplelog",
"socket2 0.5.7",
"static_assertions",
"stop-token",
"thiserror",

View File

@ -223,13 +223,12 @@ impl ClientApiConnection {
trace!("ClientApiConnection::handle_tcp_connection");
// Connect the TCP socket
let stream = TcpStream::connect(connect_addr)
let stream = connect_async_tcp_stream(None, connect_addr, 10_000)
.await
.map_err(map_to_string)?
.into_timeout_error()
.map_err(map_to_string)?;
// If it succeed, disable nagle algorithm
stream.set_nodelay(true).map_err(map_to_string)?;
// State we connected
let comproc = self.inner.lock().comproc.clone();
comproc.set_connection_state(ConnectionState::ConnectedTCP(
@ -239,16 +238,8 @@ impl ClientApiConnection {
// Split into reader and writer halves
// with line buffering on the reader
cfg_if! {
if #[cfg(feature="rt-async-std")] {
use futures::AsyncReadExt;
let (reader, writer) = stream.split();
let reader = BufReader::new(reader);
} else {
let (reader, writer) = stream.into_split();
let reader = BufReader::new(reader);
}
}
let (reader, writer) = split_async_tcp_stream(stream);
let reader = BufReader::new(reader);
self.clone().run_json_api_processor(reader, writer).await
}

View File

@ -56,6 +56,7 @@ veilid_core_ios_tests = ["dep:tracing-oslog"]
debug-locks = ["veilid-tools/debug-locks"]
unstable-blockstore = []
unstable-tunnels = []
virtual-network = []
# GeoIP
geolocation = ["maxminddb", "reqwest"]
@ -164,7 +165,6 @@ sysinfo = { version = "^0.30.13", default-features = false }
tokio = { version = "1.38.1", features = ["full"], optional = true }
tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
async-io = { version = "1.13.0" }
futures-util = { version = "0.3.30", default-features = false, features = [
"async-await",
"sink",
@ -184,7 +184,6 @@ webpki = "0.22.4"
webpki-roots = "0.25.4"
rustls = "0.21.12"
rustls-pemfile = "1.0.4"
socket2 = { version = "0.5.7", features = ["all"] }
# Dependencies for WASM builds only
[target.'cfg(target_arch = "wasm32")'.dependencies]

View File

@ -1,6 +1,5 @@
use super::*;
use async_tls::TlsAcceptor;
use sockets::*;
use stop_token::future::FutureExt;
/////////////////////////////////////////////////////////////////
@ -135,39 +134,12 @@ impl Network {
}
};
#[cfg(all(feature = "rt-async-std", unix))]
if let Err(e) = set_tcp_stream_linger(&tcp_stream, Some(core::time::Duration::from_secs(0)))
{
// async-std does not directly support linger on TcpStream yet
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
if let Err(e) = unsafe {
let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
s.into_raw_fd();
res
} {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
}
#[cfg(all(feature = "rt-async-std", windows))]
{
// async-std does not directly support linger on TcpStream yet
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
if let Err(e) = unsafe {
let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
s.into_raw_socket();
res
} {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
}
#[cfg(not(feature = "rt-async-std"))]
if let Err(e) = tcp_stream.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(debug "Couldn't set TCP linger: {}", e);
return;
}
if let Err(e) = tcp_stream.set_nodelay(true) {
log_net!(debug "Couldn't set TCP nodelay: {}", e);
return;
@ -257,41 +229,11 @@ impl Network {
)
};
// Create a socket and bind it
let Some(socket) = new_bound_default_tcp_socket(addr)
.wrap_err("failed to create default socket listener")?
else {
return Ok(false);
};
// Drop the socket
drop(socket);
// Create a shared socket and bind it once we have determined the port is free
let Some(socket) = new_bound_shared_tcp_socket(addr)
.wrap_err("failed to create shared socket listener")?
else {
let Some(listener) = bind_async_tcp_listener(addr)? else {
return Ok(false);
};
// Listen on the socket
if socket.listen(128).is_err() {
return Ok(false);
}
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let listener = TcpListener::from(std_listener);
} else if #[cfg(feature="rt-tokio")] {
std_listener.set_nonblocking(true).expect("failed to set nonblocking");
let listener = TcpListener::from_std(std_listener).wrap_err("failed to create tokio tcp listener")?;
} else {
compile_error!("needs executor implementation");
}
}
log_net!(debug "spawn_socket_listener: binding successful to {}", addr);
// Create protocol handler records
@ -311,15 +253,7 @@ impl Network {
// moves listener object in and get incoming iterator
// when this task exists, the listener will close the socket
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let incoming_stream = listener.incoming();
} else if #[cfg(feature="rt-tokio")] {
let incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener);
} else {
compile_error!("needs executor implementation");
}
}
let incoming_stream = async_tcp_listener_incoming(listener);
let _ = incoming_stream
.for_each_concurrent(None, |tcp_stream| {

View File

@ -1,5 +1,4 @@
use super::*;
use sockets::*;
use stop_token::future::FutureExt;
impl Network {
@ -114,23 +113,10 @@ impl Network {
async fn create_udp_protocol_handler(&self, addr: SocketAddr) -> EyreResult<bool> {
log_net!(debug "create_udp_protocol_handler on {:?}", &addr);
// Create a reusable socket
let Some(socket) = new_bound_default_udp_socket(addr)? else {
// Create a single-address-family UDP socket with default options bound to an address
let Some(udp_socket) = bind_async_udp_socket(addr)? else {
return Ok(false);
};
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let udp_socket = UdpSocket::from(std_udp_socket);
} else if #[cfg(feature="rt-tokio")] {
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make inbound tokio udpsocket")?;
} else {
compile_error!("needs executor implementation");
}
}
let socket_arc = Arc::new(udp_socket);
// Create protocol handler

View File

@ -1,4 +1,3 @@
pub mod sockets;
pub mod tcp;
pub mod udp;
pub mod wrtc;

View File

@ -1,191 +0,0 @@
use crate::*;
use async_io::Async;
use std::io;
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
} else if #[cfg(feature="rt-tokio")] {
pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
pub use tokio_util::compat::*;
} else {
compile_error!("needs executor implementation");
}
}
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
// cfg_if! {
// if #[cfg(windows)] {
// use winapi::shared::ws2def::{ SOL_SOCKET, SO_EXCLUSIVEADDRUSE};
// use winapi::um::winsock2::{SOCKET_ERROR, setsockopt};
// use winapi::ctypes::c_int;
// use std::os::windows::io::AsRawSocket;
// fn set_exclusiveaddruse(socket: &Socket) -> io::Result<()> {
// unsafe {
// let optval:c_int = 1;
// if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
// std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
// return Err(io::Error::last_os_error());
// }
// Ok(())
// }
// }
// }
// }
#[instrument(level = "trace", ret)]
pub fn new_shared_udp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_default_udp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_bound_default_udp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_default_udp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound default udp socket on {:?}", &local_address);
Ok(Some(socket))
}
#[instrument(level = "trace", ret)]
pub fn new_default_tcp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_shared_tcp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
#[instrument(level = "trace", ret)]
pub fn new_bound_default_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_default_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound default tcp socket on {:?}", &local_address);
Ok(Some(socket))
}
#[instrument(level = "trace", ret)]
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = Domain::for_address(local_address);
let socket = new_shared_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
log_net!("created bound shared tcp socket on {:?}", &local_address);
Ok(Some(socket))
}
// Non-blocking connect is tricky when you want to start with a prepared socket
// Errors should not be logged as they are valid conditions for this function
#[instrument(level = "trace", ret)]
pub async fn nonblocking_connect(
socket: Socket,
addr: SocketAddr,
timeout_ms: u32,
) -> io::Result<TimeoutOr<TcpStream>> {
// Set for non blocking connect
socket.set_nonblocking(true)?;
// Make socket2 SockAddr
let socket2_addr = socket2::SockAddr::from(addr);
// Connect to the remote address
match socket.connect(&socket2_addr) {
Ok(()) => Ok(()),
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
Err(e) => Err(e),
}?;
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
// The stream becomes writable when connected
timeout_or_try!(
timeout(timeout_ms, async_stream.writable().in_current_span())
.await
.into_timeout_or()
.into_result()?
);
// Check low level error
let async_stream = match async_stream.get_ref().take_error()? {
None => Ok(async_stream),
Some(err) => Err(err),
}?;
// Convert back to inner and then return async version
cfg_if! {
if #[cfg(feature="rt-async-std")] {
Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
} else if #[cfg(feature="rt-tokio")] {
Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
} else {
compile_error!("needs executor implementation");
}
}
}

View File

@ -1,6 +1,5 @@
use super::*;
use futures_util::{AsyncReadExt, AsyncWriteExt};
use sockets::*;
pub struct RawTcpNetworkConnection {
flow: Flow,
@ -157,32 +156,28 @@ impl RawTcpProtocolHandler {
#[instrument(level = "trace", target = "protocol", err)]
pub async fn connect(
local_address: Option<SocketAddr>,
socket_addr: SocketAddr,
remote_address: SocketAddr,
timeout_ms: u32,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
// Make a shared socket
let socket = match local_address {
Some(a) => {
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_tcp_socket(socket2::Domain::for_address(socket_addr))?,
};
// Non-blocking connect to remote address
let ts = network_result_try!(nonblocking_connect(socket, socket_addr, timeout_ms)
.await
.folded()?);
let tcp_stream = network_result_try!(connect_async_tcp_stream(
local_address,
remote_address,
timeout_ms
)
.await
.folded()?);
// See what local address we ended up with and turn this into a stream
let actual_local_address = ts.local_addr()?;
let actual_local_address = tcp_stream.local_addr()?;
#[cfg(feature = "rt-tokio")]
let ts = ts.compat();
let ps = AsyncPeekStream::new(ts);
let tcp_stream = tcp_stream.compat();
let ps = AsyncPeekStream::new(tcp_stream);
// Wrap the stream in a network connection and return it
let flow = Flow::new(
PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
SocketAddress::from_socket_addr(remote_address),
ProtocolType::TCP,
),
SocketAddress::from_socket_addr(actual_local_address),

View File

@ -1,5 +1,4 @@
use super::*;
use sockets::*;
#[derive(Clone)]
pub struct RawUdpProtocolHandler {
@ -141,7 +140,8 @@ impl RawUdpProtocolHandler {
) -> io::Result<RawUdpProtocolHandler> {
// get local wildcard address for bind
let local_socket_addr = compatible_unspecified_socket_addr(socket_addr);
let socket = UdpSocket::bind(local_socket_addr).await?;
let socket = bind_async_udp_socket(local_socket_addr)?
.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?;
Ok(RawUdpProtocolHandler::new(Arc::new(socket), None))
}
}

View File

@ -10,7 +10,6 @@ use async_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFr
use async_tungstenite::tungstenite::Error;
use async_tungstenite::{accept_hdr_async, client_async, WebSocketStream};
use futures_util::{AsyncRead, AsyncWrite, SinkExt};
use sockets::*;
// Maximum number of websocket request headers to permit
const MAX_WS_HEADERS: usize = 24;
@ -316,21 +315,16 @@ impl WebsocketProtocolHandler {
let domain = split_url.host.clone();
// Resolve remote address
let remote_socket_addr = dial_info.to_socket_addr();
// Make a shared socket
let socket = match local_address {
Some(a) => {
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?,
};
let remote_address = dial_info.to_socket_addr();
// Non-blocking connect to remote address
let tcp_stream =
network_result_try!(nonblocking_connect(socket, remote_socket_addr, timeout_ms)
.await
.folded()?);
let tcp_stream = network_result_try!(connect_async_tcp_stream(
local_address,
remote_address,
timeout_ms
)
.await
.folded()?);
// See what local address we ended up with
let actual_local_addr = tcp_stream.local_addr()?;

View File

@ -333,9 +333,10 @@ impl RoutingTable {
relaying_count += 1;
}
let best_node_id = node.best_node_id();
out += " ";
out += &node
.operate(|_rti, e| Self::format_entry(cur_ts, node.best_node_id(), e, &relay_tag));
out += &node.operate(|_rti, e| Self::format_entry(cur_ts, best_node_id, e, &relay_tag));
out += "\n";
}

View File

@ -165,17 +165,12 @@ impl ClientApi {
}
async fn handle_tcp_incoming(self, bind_addr: SocketAddr) -> std::io::Result<()> {
let listener = TcpListener::bind(bind_addr).await?;
let listener = bind_async_tcp_listener(bind_addr)?
.ok_or(std::io::Error::from(std::io::ErrorKind::AddrInUse))?;
debug!(target: "client_api", "TCPClient API listening on: {:?}", bind_addr);
// Process the incoming accept stream
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let mut incoming_stream = listener.incoming();
} else {
let mut incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener);
}
}
let mut incoming_stream = async_tcp_listener_incoming(listener);
// Make wait group for all incoming connections
let awg = AsyncWaitGroup::new();

View File

@ -7,8 +7,6 @@ pub use tracing::*;
cfg_if! {
if #[cfg(feature="rt-async-std")] {
// pub use async_std::task::JoinHandle;
pub use async_std::net::TcpListener;
pub use async_std::net::TcpStream;
pub use async_std::io::BufReader;
//pub use async_std::future::TimeoutError;
//pub fn spawn_detached<F: Future<Output = T> + Send + 'static, T: Send + 'static>(f: F) -> JoinHandle<T> {
@ -27,8 +25,6 @@ cfg_if! {
}
} else if #[cfg(feature="rt-tokio")] {
//pub use tokio::task::JoinHandle;
pub use tokio::net::TcpListener;
pub use tokio::net::TcpStream;
pub use tokio::io::BufReader;
//pub use tokio_util::compat::*;
//pub use tokio::time::error::Elapsed as TimeoutError;

View File

@ -76,20 +76,21 @@ flume = { version = "0.11.0", features = ["async"] }
# Dependencies for native builds only
# Linux, Windows, Mac, iOS, Android
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
async-io = { version = "1.13.0" }
async-std = { version = "1.12.0", features = ["unstable"], optional = true }
tokio = { version = "1.38.1", features = ["full"], optional = true }
tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
chrono = "0.4.38"
futures-util = { version = "0.3.30", default-features = false, features = [
"async-await",
"sink",
"std",
"io",
] }
chrono = "0.4.38"
libc = "0.2.155"
nix = { version = "0.27.1", features = ["user"] }
socket2 = { version = "0.5.7", features = ["all"] }
tokio = { version = "1.38.1", features = ["full"], optional = true }
tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
# Dependencies for WASM builds only
[target.'cfg(target_arch = "wasm32")'.dependencies]

View File

@ -1,108 +0,0 @@
use super::*;
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
} else {
use std::net::{TcpListener, UdpSocket};
}
}
#[derive(ThisError, Debug, Clone, PartialEq, Eq)]
pub enum BumpPortError {
#[error("Unsupported architecture")]
Unsupported,
#[error("Failure: {0}")]
Failed(String),
}
pub enum BumpPortType {
UDP,
TCP,
}
pub fn tcp_port_available(addr: &SocketAddr) -> bool {
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
true
} else {
match TcpListener::bind(addr) {
Ok(_) => true,
Err(_) => false,
}
}
}
}
pub fn udp_port_available(addr: &SocketAddr) -> bool {
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
true
} else {
match UdpSocket::bind(addr) {
Ok(_) => true,
Err(_) => false,
}
}
}
}
pub fn bump_port(addr: &mut SocketAddr, bpt: BumpPortType) -> Result<bool, BumpPortError> {
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
Err(BumpPortError::Unsupported)
}
else
{
let mut bumped = false;
let mut port = addr.port();
let mut addr_bump = addr.clone();
loop {
if match bpt {
BumpPortType::TCP => tcp_port_available(&addr_bump),
BumpPortType::UDP => udp_port_available(&addr_bump),
} {
*addr = addr_bump;
return Ok(bumped);
}
if port == u16::MAX {
break;
}
port += 1;
addr_bump.set_port(port);
bumped = true;
}
Err(BumpPortError::Failure("no ports remaining".to_owned()))
}
}
}
pub fn bump_port_string(addr: &mut String, bpt: BumpPortType) -> Result<bool, BumpPortError> {
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
return Err(BumpPortError::Unsupported);
}
else
{
let savec: Vec<SocketAddr> = addr
.to_socket_addrs()
.map_err(|x| BumpPortError::Failure(format!("failed to resolve socket address: {}", x)))?
.collect();
if savec.len() == 0 {
return Err(BumpPortError::Failure("No socket addresses resolved".to_owned()));
}
let mut sa = savec.first().unwrap().clone();
if !bump_port(&mut sa, bpt)? {
return Ok(false);
}
*addr = sa.to_string();
Ok(true)
}
}
}

View File

@ -24,7 +24,6 @@
#![allow(clippy::comparison_chain, clippy::upper_case_acronyms)]
#![deny(unused_must_use)]
// pub mod bump_port;
pub mod assembly_buffer;
pub mod async_peek_stream;
pub mod async_tag_lock;
@ -49,6 +48,8 @@ pub mod network_result;
pub mod random;
pub mod single_shot_eventual;
pub mod sleep;
#[cfg(not(target_arch = "wasm32"))]
pub mod socket_tools;
pub mod spawn;
pub mod split_url;
pub mod startup_lock;
@ -184,7 +185,6 @@ cfg_if! {
}
}
// pub use bump_port::*;
#[doc(inline)]
pub use assembly_buffer::*;
#[doc(inline)]
@ -233,6 +233,9 @@ pub use single_shot_eventual::*;
#[doc(inline)]
pub use sleep::*;
#[doc(inline)]
#[cfg(not(target_arch = "wasm32"))]
pub use socket_tools::*;
#[doc(inline)]
pub use spawn::*;
#[doc(inline)]
pub use split_url::*;

View File

@ -0,0 +1,284 @@
use super::*;
use async_io::Async;
use std::io;
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
} else if #[cfg(feature="rt-tokio")] {
pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
pub use tokio_util::compat::*;
} else {
compile_error!("needs executor implementation");
}
}
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
//////////////////////////////////////////////////////////////////////////////////////////
pub fn bind_async_udp_socket(local_address: SocketAddr) -> io::Result<Option<UdpSocket>> {
let Some(socket) = new_bound_default_socket2_udp(local_address)? else {
return Ok(None);
};
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let udp_socket = UdpSocket::from(std_udp_socket);
} else if #[cfg(feature="rt-tokio")] {
std_udp_socket.set_nonblocking(true)?;
let udp_socket = UdpSocket::from_std(std_udp_socket)?;
} else {
compile_error!("needs executor implementation");
}
}
Ok(Some(udp_socket))
}
pub fn bind_async_tcp_listener(local_address: SocketAddr) -> io::Result<Option<TcpListener>> {
// Create a default non-shared socket and bind it
let Some(socket) = new_bound_default_socket2_tcp(local_address)? else {
return Ok(None);
};
// Drop the socket so we can make another shared socket in its place
drop(socket);
// Create a shared socket and bind it now we have determined the port is free
let Some(socket) = new_bound_shared_socket2_tcp(local_address)? else {
return Ok(None);
};
// Listen on the socket
if socket.listen(128).is_err() {
return Ok(None);
}
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
cfg_if! {
if #[cfg(feature="rt-async-std")] {
let listener = TcpListener::from(std_listener);
} else if #[cfg(feature="rt-tokio")] {
std_listener.set_nonblocking(true)?;
let listener = TcpListener::from_std(std_listener)?;
} else {
compile_error!("needs executor implementation");
}
}
Ok(Some(listener))
}
pub async fn connect_async_tcp_stream(
local_address: Option<SocketAddr>,
remote_address: SocketAddr,
timeout_ms: u32,
) -> io::Result<TimeoutOr<TcpStream>> {
let socket = match local_address {
Some(a) => {
new_bound_shared_socket2_tcp(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
}
None => new_default_socket2_tcp(domain_for_address(remote_address))?,
};
// Non-blocking connect to remote address
nonblocking_connect(socket, remote_address, timeout_ms).await
}
pub fn set_tcp_stream_linger(
tcp_stream: &TcpStream,
linger: Option<core::time::Duration>,
) -> io::Result<()> {
#[cfg(all(feature = "rt-async-std", unix))]
{
// async-std does not directly support linger on TcpStream yet
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
unsafe {
let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
let res = s.set_linger(linger);
s.into_raw_fd();
res
}
}
#[cfg(all(feature = "rt-async-std", windows))]
{
// async-std does not directly support linger on TcpStream yet
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
unsafe {
let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
let res = s.set_linger(linger);
s.into_raw_socket();
res
}
}
#[cfg(not(feature = "rt-async-std"))]
tcp_stream.set_linger(linger)
}
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub type IncomingStream = Incoming;
pub type ReadHalf = futures_util::io::ReadHalf<TcpStream>,
pub type WriteHalf = futures_util::io::WriteHalf<TcpStream>,
} else if #[cfg(feature="rt-tokio")] {
pub type IncomingStream = tokio_stream::wrappers::TcpListenerStream;
pub type ReadHalf = tokio::net::tcp::OwnedReadHalf;
pub type WriteHalf = tokio::net::tcp::OwnedWriteHalf;
} else {
compile_error!("needs executor implementation");
}
}
pub fn async_tcp_listener_incoming(tcp_listener: TcpListener) -> IncomingStream {
cfg_if! {
if #[cfg(feature="rt-async-std")] {
tcp_listener.incoming()
} else if #[cfg(feature="rt-tokio")] {
tokio_stream::wrappers::TcpListenerStream::new(tcp_listener)
} else {
compile_error!("needs executor implementation");
}
}
}
pub fn split_async_tcp_stream(tcp_stream: TcpStream) -> (ReadHalf, WriteHalf) {
cfg_if! {
if #[cfg(feature="rt-async-std")] {
use futures_util::AsyncReadExt;
tcp_stream.split()
} else if #[cfg(feature="rt-tokio")] {
tcp_stream.into_split()
} else {
compile_error!("needs executor implementation");
}
}
}
//////////////////////////////////////////////////////////////////////////////////////////
fn new_default_udp_socket(domain: core::ffi::c_int) -> io::Result<Socket> {
let domain = Domain::from(domain);
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
fn new_bound_default_socket2_udp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = domain_for_address(local_address);
let socket = new_default_udp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
Ok(Some(socket))
}
pub fn new_default_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
let domain = Domain::from(domain);
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
socket.set_nodelay(true)?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
Ok(socket)
}
fn new_shared_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
let domain = Domain::from(domain);
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
socket.set_nodelay(true)?;
if domain == Domain::IPV6 {
socket.set_only_v6(true)?;
}
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
fn new_bound_default_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
let domain = domain_for_address(local_address);
let socket = new_default_socket2_tcp(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
Ok(Some(socket))
}
fn new_bound_shared_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
// Create the reuseaddr/reuseport socket now that we've asserted the port is free
let domain = domain_for_address(local_address);
let socket = new_shared_socket2_tcp(domain)?;
let socket2_addr = SockAddr::from(local_address);
if socket.bind(&socket2_addr).is_err() {
return Ok(None);
}
Ok(Some(socket))
}
// Non-blocking connect is tricky when you want to start with a prepared socket
// Errors should not be logged as they are valid conditions for this function
async fn nonblocking_connect(
socket: Socket,
addr: SocketAddr,
timeout_ms: u32,
) -> io::Result<TimeoutOr<TcpStream>> {
// Set for non blocking connect
socket.set_nonblocking(true)?;
// Make socket2 SockAddr
let socket2_addr = socket2::SockAddr::from(addr);
// Connect to the remote address
match socket.connect(&socket2_addr) {
Ok(()) => Ok(()),
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
Err(e) => Err(e),
}?;
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
// The stream becomes writable when connected
timeout_or_try!(timeout(timeout_ms, async_stream.writable())
.await
.into_timeout_or()
.into_result()?);
// Check low level error
let async_stream = match async_stream.get_ref().take_error()? {
None => Ok(async_stream),
Some(err) => Err(err),
}?;
// Convert back to inner and then return async version
cfg_if! {
if #[cfg(feature="rt-async-std")] {
Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
} else if #[cfg(feature="rt-tokio")] {
Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
} else {
compile_error!("needs executor implementation");
}
}
}
pub fn domain_for_address(address: SocketAddr) -> core::ffi::c_int {
socket2::Domain::for_address(address).into()
}

View File

@ -8,7 +8,6 @@ cfg_if! {
} else if #[cfg(feature="rt-tokio")] {
use tokio::net::{TcpListener, TcpStream};
use tokio::time::sleep;
use tokio_util::compat::*;
}
}

View File

@ -2,7 +2,7 @@ use super::*;
pub type MachineId = u64;
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct Machine {
pub router_client: RouterClient,
pub id: MachineId,

View File

@ -25,6 +25,7 @@
//!
//! [VirtualTcpStream]
//! [VirtualUdpSocket]
//! [VirtualTcpListener]
//! [VirtualTcpListenerStream]
//! [VirtualWsMeta]
//! [VirtualWsStream]
@ -44,6 +45,8 @@ mod router_op_table;
mod router_server;
mod serde_io_error;
mod virtual_network_error;
mod virtual_tcp_listener;
mod virtual_tcp_listener_stream;
mod virtual_tcp_stream;
mod virtual_udp_socket;
@ -53,5 +56,7 @@ pub use machine::*;
pub use router_client::*;
pub use router_server::*;
pub use virtual_network_error::*;
pub use virtual_tcp_listener::*;
pub use virtual_tcp_listener_stream::*;
pub use virtual_tcp_stream::*;
pub use virtual_udp_socket::*;

View File

@ -2,8 +2,7 @@ use super::*;
use core::sync::atomic::AtomicU64;
use futures_codec::{Bytes, BytesCodec, FramedRead, FramedWrite};
use futures_util::{
io::BufReader, stream::FuturesUnordered, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt,
StreamExt, TryStreamExt,
stream::FuturesUnordered, AsyncReadExt, AsyncWriteExt, StreamExt, TryStreamExt,
};
use postcard::{from_bytes, to_stdvec};
use router_op_table::*;
@ -16,11 +15,29 @@ struct RouterClientInner {
stop_source: Option<StopSource>,
}
impl fmt::Debug for RouterClientInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RouterClientInner")
.field("jh_handler", &self.jh_handler)
.field("stop_source", &self.stop_source)
.finish()
}
}
struct RouterClientUnlockedInner {
receiver: flume::Receiver<Bytes>,
sender: flume::Sender<Bytes>,
next_message_id: AtomicU64,
router_op_waiter: RouterOpWaiter<ServerProcessorResponseStatus, ()>,
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>,
}
impl fmt::Debug for RouterClientUnlockedInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RouterClientUnlockedInner")
.field("sender", &self.sender)
.field("next_message_id", &self.next_message_id)
.field("router_op_waiter", &self.router_op_waiter)
.finish()
}
}
pub type MessageId = u64;
@ -49,9 +66,9 @@ enum ServerProcessorRequest {
},
TcpAccept {
machine_id: MachineId,
socket_id: SocketId,
listen_socket_id: SocketId,
},
Close {
TcpShutdown {
machine_id: MachineId,
socket_id: SocketId,
},
@ -84,13 +101,22 @@ enum ServerProcessorRequest {
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
struct ServerProcessorRequestMessage {
struct ServerProcessorMessage {
message_id: MessageId,
request: ServerProcessorRequest,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
enum ServerProcessorResponse {
enum ServerProcessorCommand {
Message(ServerProcessorMessage),
CloseSocket {
machine_id: MachineId,
socket_id: SocketId,
},
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
enum ServerProcessorReplyValue {
AllocateMachine {
machine_id: MachineId,
},
@ -100,17 +126,21 @@ enum ServerProcessorResponse {
},
TcpConnect {
socket_id: SocketId,
local_address: SocketAddr,
},
TcpBind {
socket_id: SocketId,
local_address: SocketAddr,
},
TcpAccept {
child_socket_id: SocketId,
socket_id: SocketId,
address: SocketAddr,
},
TcpShutdown,
UdpBind {
socket_id: SocketId,
local_address: SocketAddr,
},
Close,
Send {
len: u32,
},
@ -127,20 +157,29 @@ enum ServerProcessorResponse {
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
enum ServerProcessorResponseStatus {
Success(ServerProcessorResponse),
enum ServerProcessorReplyResult {
Value(ServerProcessorReplyValue),
InvalidMachineId,
InvalidSocketId,
IoError(#[serde(with = "serde_io_error::SerdeIoErrorKindDef")] io::ErrorKind),
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
struct ServerProcessorResponseMessage {
struct ServerProcessorReply {
message_id: MessageId,
status: ServerProcessorResponseStatus,
status: ServerProcessorReplyResult,
}
#[derive(Clone)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
enum ServerProcessorEvent {
Reply(ServerProcessorReply),
// DeadSocket {
// machine_id: MachineId,
// socket_id: SocketId,
// },
}
#[derive(Debug, Clone)]
pub struct RouterClient {
unlocked_inner: Arc<RouterClientUnlockedInner>,
inner: Arc<Mutex<RouterClientInner>>,
@ -151,7 +190,7 @@ impl RouterClient {
// Public interface
#[cfg(not(target_arch = "wasm32"))]
pub async fn router_connect_tcp<H: ToSocketAddrs>(host: H) -> ::std::io::Result<RouterClient> {
pub async fn router_connect_tcp<H: ToSocketAddrs>(host: H) -> io::Result<RouterClient> {
let addrs = host.to_socket_addrs()?.collect::<Vec<_>>();
// Connect to RouterServer
@ -173,37 +212,37 @@ impl RouterClient {
}
}
let ts_buf_reader = BufReader::new(ts_reader);
// Create channels
let (client_sender, server_receiver) = flume::unbounded::<Bytes>();
let (server_sender, client_receiver) = flume::unbounded::<Bytes>();
// Create stopper
let stop_source = StopSource::new();
// Create router operation waiter
let router_op_waiter = RouterOpWaiter::new();
// Spawn a client connection handler
let jh_handler = spawn(
"RouterClient server processor",
Self::run_server_processor(
ts_buf_reader,
ts_reader,
ts_writer,
server_receiver,
server_sender,
router_op_waiter.clone(),
stop_source.token(),
),
);
Ok(Self::new(
client_receiver,
client_sender,
router_op_waiter,
jh_handler,
stop_source,
))
}
#[cfg(target_arch = "wasm32")]
pub async fn router_connect_ws<H: AsRef<str>>(host: H) -> ::std::io::Result<RouterClient> {
pub async fn router_connect_ws<H: AsRef<str>>(host: H) -> io::Result<RouterClient> {
let host = host.as_ref();
Ok(RouterClient {})
@ -219,7 +258,7 @@ impl RouterClient {
pub async fn allocate_machine(self) -> VirtualNetworkResult<MachineId> {
let request = ServerProcessorRequest::AllocateMachine;
let ServerProcessorResponse::AllocateMachine { machine_id } =
let ServerProcessorReplyValue::AllocateMachine { machine_id } =
self.perform_request(request).await?
else {
return Err(VirtualNetworkError::ResponseMismatch);
@ -229,7 +268,7 @@ impl RouterClient {
pub async fn release_machine(self, machine_id: MachineId) -> VirtualNetworkResult<()> {
let request = ServerProcessorRequest::ReleaseMachine { machine_id };
let ServerProcessorResponse::ReleaseMachine = self.perform_request(request).await? else {
let ServerProcessorReplyValue::ReleaseMachine = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(())
@ -240,7 +279,7 @@ impl RouterClient {
machine_id: MachineId,
) -> VirtualNetworkResult<BTreeMap<String, NetworkInterface>> {
let request = ServerProcessorRequest::GetInterfaces { machine_id };
let ServerProcessorResponse::GetInterfaces { interfaces } =
let ServerProcessorReplyValue::GetInterfaces { interfaces } =
self.perform_request(request).await?
else {
return Err(VirtualNetworkError::ResponseMismatch);
@ -252,91 +291,99 @@ impl RouterClient {
self,
machine_id: MachineId,
remote_address: SocketAddr,
local_address: Option<SocketAddr>,
opt_local_address: Option<SocketAddr>,
timeout_ms: u32,
options: VirtualTcpOptions,
) -> VirtualNetworkResult<SocketId> {
) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
let request = ServerProcessorRequest::TcpConnect {
machine_id,
local_address,
local_address: opt_local_address,
remote_address,
timeout_ms,
options,
};
let ServerProcessorResponse::TcpConnect { socket_id } =
self.perform_request(request).await?
let ServerProcessorReplyValue::TcpConnect {
socket_id,
local_address,
} = self.perform_request(request).await?
else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(socket_id)
Ok((socket_id, local_address))
}
pub async fn tcp_bind(
self,
machine_id: MachineId,
local_address: Option<SocketAddr>,
opt_local_address: Option<SocketAddr>,
options: VirtualTcpOptions,
) -> VirtualNetworkResult<SocketId> {
) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
let request = ServerProcessorRequest::TcpBind {
machine_id,
local_address,
local_address: opt_local_address,
options,
};
let ServerProcessorResponse::TcpBind { socket_id } = self.perform_request(request).await?
let ServerProcessorReplyValue::TcpBind {
socket_id,
local_address,
} = self.perform_request(request).await?
else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(socket_id)
Ok((socket_id, local_address))
}
pub async fn tcp_accept(
self,
machine_id: MachineId,
socket_id: SocketId,
) -> VirtualNetworkResult<SocketId> {
listen_socket_id: SocketId,
) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
let request = ServerProcessorRequest::TcpAccept {
machine_id,
socket_id,
listen_socket_id,
};
let ServerProcessorResponse::TcpAccept { child_socket_id } =
let ServerProcessorReplyValue::TcpAccept { socket_id, address } =
self.perform_request(request).await?
else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(child_socket_id)
Ok((socket_id, address))
}
pub async fn tcp_shutdown(
self,
machine_id: MachineId,
socket_id: SocketId,
) -> VirtualNetworkResult<()> {
let request = ServerProcessorRequest::TcpShutdown {
machine_id,
socket_id,
};
let ServerProcessorReplyValue::TcpShutdown = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(())
}
pub async fn udp_bind(
self,
machine_id: MachineId,
local_address: Option<SocketAddr>,
opt_local_address: Option<SocketAddr>,
options: VirtualUdpOptions,
) -> VirtualNetworkResult<SocketId> {
) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
let request = ServerProcessorRequest::UdpBind {
machine_id,
local_address,
local_address: opt_local_address,
options,
};
let ServerProcessorResponse::UdpBind { socket_id } = self.perform_request(request).await?
let ServerProcessorReplyValue::UdpBind {
socket_id,
local_address,
} = self.perform_request(request).await?
else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(socket_id)
}
pub async fn close(
self,
machine_id: MachineId,
socket_id: SocketId,
) -> VirtualNetworkResult<()> {
let request = ServerProcessorRequest::Close {
machine_id,
socket_id,
};
let ServerProcessorResponse::Close = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(())
Ok((socket_id, local_address))
}
pub async fn send(
@ -350,7 +397,7 @@ impl RouterClient {
socket_id,
data,
};
let ServerProcessorResponse::Send { len } = self.perform_request(request).await? else {
let ServerProcessorReplyValue::Send { len } = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(len as usize)
@ -369,7 +416,7 @@ impl RouterClient {
data,
remote_address,
};
let ServerProcessorResponse::SendTo { len } = self.perform_request(request).await? else {
let ServerProcessorReplyValue::SendTo { len } = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(len as usize)
@ -379,14 +426,14 @@ impl RouterClient {
self,
machine_id: MachineId,
socket_id: u64,
len: u32,
len: usize,
) -> VirtualNetworkResult<Vec<u8>> {
let request = ServerProcessorRequest::Recv {
machine_id,
socket_id,
len,
len: len as u32,
};
let ServerProcessorResponse::Recv { data } = self.perform_request(request).await? else {
let ServerProcessorReplyValue::Recv { data } = self.perform_request(request).await? else {
return Err(VirtualNetworkError::ResponseMismatch);
};
Ok(data)
@ -396,14 +443,14 @@ impl RouterClient {
self,
machine_id: MachineId,
socket_id: u64,
len: u32,
len: usize,
) -> VirtualNetworkResult<(Vec<u8>, SocketAddr)> {
let request = ServerProcessorRequest::RecvFrom {
machine_id,
socket_id,
len,
len: len as u32,
};
let ServerProcessorResponse::RecvFrom {
let ServerProcessorReplyValue::RecvFrom {
data,
remote_address,
} = self.perform_request(request).await?
@ -417,17 +464,16 @@ impl RouterClient {
// Private implementation
fn new(
receiver: flume::Receiver<Bytes>,
sender: flume::Sender<Bytes>,
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>,
jh_handler: MustJoinHandle<()>,
stop_source: StopSource,
) -> RouterClient {
RouterClient {
unlocked_inner: Arc::new(RouterClientUnlockedInner {
receiver,
sender,
next_message_id: AtomicU64::new(0),
router_op_waiter: RouterOpWaiter::new(),
router_op_waiter,
}),
inner: Arc::new(Mutex::new(RouterClientInner {
jh_handler: Some(jh_handler),
@ -436,24 +482,60 @@ impl RouterClient {
}
}
fn report_closed_socket(&self, machine_id: MachineId, socket_id: SocketId) {
let command = ServerProcessorCommand::CloseSocket {
machine_id,
socket_id,
};
let command_vec = match to_stdvec(&command).map_err(VirtualNetworkError::SerializationError)
{
Ok(v) => Bytes::from(v),
Err(e) => {
error!("{}", e);
return;
}
};
if let Err(e) = self
.unlocked_inner
.sender
.send(command_vec)
.map_err(|_| VirtualNetworkError::IoError(io::ErrorKind::BrokenPipe))
{
error!("{}", e);
}
}
pub(super) fn drop_tcp_stream(&self, machine_id: MachineId, socket_id: SocketId) {
self.report_closed_socket(machine_id, socket_id);
}
pub(super) fn drop_tcp_listener(&self, machine_id: MachineId, socket_id: SocketId) {
self.report_closed_socket(machine_id, socket_id);
}
pub(super) fn drop_udp_socket(&self, machine_id: MachineId, socket_id: SocketId) {
self.report_closed_socket(machine_id, socket_id);
}
async fn perform_request(
&self,
request: ServerProcessorRequest,
) -> VirtualNetworkResult<ServerProcessorResponse> {
) -> VirtualNetworkResult<ServerProcessorReplyValue> {
let message_id = self
.unlocked_inner
.next_message_id
.fetch_add(1, Ordering::AcqRel);
let msg = ServerProcessorRequestMessage {
let command = ServerProcessorCommand::Message(ServerProcessorMessage {
message_id,
request,
};
let msg_vec =
Bytes::from(to_stdvec(&msg).map_err(VirtualNetworkError::SerializationError)?);
});
let command_vec =
Bytes::from(to_stdvec(&command).map_err(VirtualNetworkError::SerializationError)?);
self.unlocked_inner
.sender
.send_async(msg_vec)
.send_async(command_vec)
.await
.map_err(|_| VirtualNetworkError::IoError(io::ErrorKind::BrokenPipe))?;
let handle = self
@ -469,16 +551,16 @@ impl RouterClient {
.map_err(|_| VirtualNetworkError::WaitError)?;
match status {
ServerProcessorResponseStatus::Success(server_processor_response) => {
ServerProcessorReplyResult::Value(server_processor_response) => {
Ok(server_processor_response)
}
ServerProcessorResponseStatus::InvalidMachineId => {
ServerProcessorReplyResult::InvalidMachineId => {
Err(VirtualNetworkError::InvalidMachineId)
}
ServerProcessorResponseStatus::InvalidSocketId => {
ServerProcessorReplyResult::InvalidSocketId => {
Err(VirtualNetworkError::InvalidSocketId)
}
ServerProcessorResponseStatus::IoError(k) => Err(VirtualNetworkError::IoError(k)),
ServerProcessorReplyResult::IoError(k) => Err(VirtualNetworkError::IoError(k)),
}
}
@ -486,7 +568,7 @@ impl RouterClient {
reader: R,
writer: W,
receiver: flume::Receiver<Bytes>,
sender: flume::Sender<Bytes>,
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>,
stop_token: StopToken,
) where
R: AsyncReadExt + Unpin + Send,
@ -503,10 +585,27 @@ impl RouterClient {
}
});
let framed_reader_fut = system_boxed(async move {
if let Err(e) = framed_reader
.forward(sender.into_sink().sink_map_err(::std::io::Error::other))
.await
{
let fut = framed_reader.try_for_each(|x| async {
let x = x;
let evt = from_bytes::<ServerProcessorEvent>(&x)
.map_err(VirtualNetworkError::SerializationError)?;
match evt {
ServerProcessorEvent::Reply(reply) => {
router_op_waiter
.complete_op_waiter(reply.message_id, reply.status)
.map_err(io::Error::other)?;
} // ServerProcessorEvent::DeadSocket {
// machine_id,
// socket_id,
// } => {
// //
// }
}
Ok(())
});
if let Err(e) = fut.await {
error!("{}", e);
}
});

View File

@ -25,16 +25,6 @@ where
result_receiver: Option<flume::Receiver<T>>,
}
impl<T, C> RouterOpWaitHandle<T, C>
where
T: Unpin,
C: Unpin + Clone,
{
pub fn id(&self) -> RouterOpId {
self.op_id
}
}
impl<T, C> Drop for RouterOpWaitHandle<T, C>
where
T: Unpin,
@ -123,6 +113,7 @@ where
}
/// Get operation context
#[expect(dead_code)]
pub fn get_op_context(&self, op_id: RouterOpId) -> Result<C, RouterOpWaitError<T>> {
let inner = self.inner.lock();
let Some(waiting_op) = inner.waiting_op_table.get(&op_id) else {
@ -132,14 +123,12 @@ where
}
/// Remove wait for op
#[instrument(level = "trace", target = "rpc", skip_all)]
fn cancel_op_waiter(&self, op_id: RouterOpId) {
let mut inner = self.inner.lock();
inner.waiting_op_table.remove(&op_id);
}
/// Complete the waiting op
#[instrument(level = "trace", target = "rpc", skip_all)]
pub fn complete_op_waiter(
&self,
op_id: RouterOpId,
@ -159,7 +148,6 @@ where
}
/// Wait for operation to complete
#[instrument(level = "trace", target = "rpc", skip_all)]
pub async fn wait_for_op(
&self,
mut handle: RouterOpWaitHandle<T, C>,

View File

@ -0,0 +1,67 @@
use super::*;
#[derive(Debug)]
pub struct VirtualTcpListener {
pub(super) machine: Machine,
pub(super) socket_id: SocketId,
pub(super) local_address: SocketAddr,
}
impl VirtualTcpListener {
/////////////////////////////////////////////////////////////
// Public Interface
pub async fn bind(
opt_local_address: Option<SocketAddr>,
options: VirtualTcpOptions,
) -> VirtualNetworkResult<Self> {
let machine = default_machine().unwrap();
Self::bind_with_machine(machine, opt_local_address, options).await
}
pub async fn bind_with_machine(
machine: Machine,
opt_local_address: Option<SocketAddr>,
options: VirtualTcpOptions,
) -> VirtualNetworkResult<Self> {
machine
.router_client
.clone()
.tcp_bind(machine.id, opt_local_address, options)
.await
.map(|(socket_id, local_address)| Self::new(machine, socket_id, local_address))
}
pub async fn accept(&self) -> VirtualNetworkResult<(VirtualTcpStream, SocketAddr)> {
self.machine
.router_client
.clone()
.tcp_accept(self.machine.id, self.socket_id)
.await
.map(|v| {
(
VirtualTcpStream::new(self.machine.clone(), v.0, self.local_address, v.1),
v.1,
)
})
}
/////////////////////////////////////////////////////////////
// Private Implementation
fn new(machine: Machine, socket_id: SocketId, local_address: SocketAddr) -> Self {
Self {
machine,
socket_id,
local_address,
}
}
}
impl Drop for VirtualTcpListener {
fn drop(&mut self) {
self.machine
.router_client
.drop_tcp_listener(self.machine.id, self.socket_id);
}
}

View File

@ -0,0 +1,86 @@
use super::*;
use core::pin::Pin;
use core::task::{Context, Poll};
use futures_util::{stream::Stream, FutureExt};
use std::io;
/// A wrapper around [`VirtualTcpListener`] that implements [`Stream`].
///
/// [`VirtualTcpListener`]: struct@crate::VirtualTcpListener
/// [`Stream`]: trait@futures_util::stream::Stream
pub struct VirtualTcpListenerStream {
inner: VirtualTcpListener,
current_accept_fut: Option<SendPinBoxFuture<VirtualNetworkResult<(SocketId, SocketAddr)>>>,
}
impl fmt::Debug for VirtualTcpListenerStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("VirtualTcpListenerStream")
.field("inner", &self.inner)
.field(
"current_accept_fut",
if self.current_accept_fut.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.finish()
}
}
impl VirtualTcpListenerStream {
/// Create a new `VirtualTcpListenerStream`.
pub fn new(listener: VirtualTcpListener) -> Self {
Self {
inner: listener,
current_accept_fut: None,
}
}
/// Get back the inner `VirtualTcpListener`.
pub fn into_inner(self) -> VirtualTcpListener {
self.inner
}
}
impl Stream for VirtualTcpListenerStream {
type Item = io::Result<VirtualTcpStream>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<io::Result<VirtualTcpStream>>> {
if self.current_accept_fut.is_none() {
let machine_id = self.inner.machine.id;
let router_client = self.inner.machine.router_client.clone();
let socket_id = self.inner.socket_id;
self.current_accept_fut =
Some(Box::pin(router_client.tcp_accept(machine_id, socket_id)));
}
let fut = self.current_accept_fut.as_mut().unwrap();
fut.poll_unpin(cx).map(|v| match v {
Ok(v) => Some(Ok(VirtualTcpStream::new(
self.inner.machine.clone(),
v.0,
self.inner.local_address,
v.1,
))),
Err(e) => Some(Err(e.into())),
})
}
}
impl AsRef<VirtualTcpListener> for VirtualTcpListenerStream {
fn as_ref(&self) -> &VirtualTcpListener {
&self.inner
}
}
impl AsMut<VirtualTcpListener> for VirtualTcpListenerStream {
fn as_mut(&mut self) -> &mut VirtualTcpListener {
&mut self.inner
}
}

View File

@ -13,12 +13,52 @@ pub struct VirtualTcpOptions {
pub struct VirtualTcpStream {
machine: Machine,
socket_id: SocketId,
local_address: SocketAddr,
remote_address: SocketAddr,
current_recv_fut: Option<SendPinBoxFuture<Result<Vec<u8>, VirtualNetworkError>>>,
current_send_fut: Option<SendPinBoxFuture<Result<usize, VirtualNetworkError>>>,
current_close_fut: Option<SendPinBoxFuture<Result<(), VirtualNetworkError>>>,
current_tcp_shutdown_fut: Option<SendPinBoxFuture<Result<(), VirtualNetworkError>>>,
}
impl fmt::Debug for VirtualTcpStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("VirtualTcpStream")
.field("machine", &self.machine)
.field("socket_id", &self.socket_id)
.field("local_address", &self.local_address)
.field("remote_address", &self.remote_address)
.field(
"current_recv_fut",
if self.current_recv_fut.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.field(
"current_send_fut",
if self.current_send_fut.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.field(
"current_close_fut",
if self.current_tcp_shutdown_fut.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.finish()
}
}
impl VirtualTcpStream {
//////////////////////////////////////////////////////////////////////////
// Public Interface
pub async fn connect(
remote_address: SocketAddr,
local_address: Option<SocketAddr>,
@ -48,14 +88,46 @@ impl VirtualTcpStream {
options,
)
.await
.map(|socket_id| Self {
machine,
socket_id,
current_recv_fut: None,
current_send_fut: None,
current_close_fut: None,
.map(|(socket_id, local_address)| {
Self::new(machine, socket_id, local_address, remote_address)
})
}
pub fn local_addr(&self) -> VirtualNetworkResult<SocketAddr> {
Ok(self.local_address)
}
pub fn peer_addr(&self) -> VirtualNetworkResult<SocketAddr> {
Ok(self.remote_address)
}
//////////////////////////////////////////////////////////////////////////
// Private Implementation
pub(super) fn new(
machine: Machine,
socket_id: SocketId,
local_address: SocketAddr,
remote_address: SocketAddr,
) -> Self {
Self {
machine,
socket_id,
local_address,
remote_address,
current_recv_fut: None,
current_send_fut: None,
current_tcp_shutdown_fut: None,
}
}
}
impl Drop for VirtualTcpStream {
fn drop(&mut self) {
self.machine
.router_client
.drop_tcp_stream(self.machine.id, self.socket_id);
}
}
impl futures_util::AsyncRead for VirtualTcpStream {
@ -68,7 +140,7 @@ impl futures_util::AsyncRead for VirtualTcpStream {
self.current_recv_fut = Some(Box::pin(self.machine.router_client.clone().recv(
self.machine.id,
self.socket_id,
buf.len() as u32,
buf.len(),
)));
}
let fut = self.current_recv_fut.as_mut().unwrap();
@ -118,18 +190,18 @@ impl futures_util::AsyncWrite for VirtualTcpStream {
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<std::io::Result<()>> {
if self.current_close_fut.is_none() {
self.current_close_fut = Some(Box::pin(
if self.current_tcp_shutdown_fut.is_none() {
self.current_tcp_shutdown_fut = Some(Box::pin(
self.machine
.router_client
.clone()
.close(self.machine.id, self.socket_id),
.tcp_shutdown(self.machine.id, self.socket_id),
));
}
let fut = self.current_close_fut.as_mut().unwrap();
let fut = self.current_tcp_shutdown_fut.as_mut().unwrap();
fut.poll_unpin(cx).map(|v| match v {
Ok(v) => {
self.current_close_fut = None;
self.current_tcp_shutdown_fut = None;
Ok(v)
}
Err(e) => Err(e.into()),

View File

@ -7,50 +7,80 @@ pub struct VirtualUdpOptions {
reuse_address_port: bool,
}
#[derive(Debug)]
pub struct VirtualUdpSocket {
machine: Machine,
socket_id: SocketId,
local_address: SocketAddr,
}
impl VirtualUdpSocket {
// pub async fn connect(
// remote_address: SocketAddr,
// local_address: Option<SocketAddr>,
// timeout_ms: u32,
// options: VirtualTcpOptions,
// ) -> VirtualNetworkResult<Self> {
// let machine = default_machine().unwrap();
// Self::connect_with_machine(machine, remote_address, local_address, timeout_ms, options)
// .await
// }
/////////////////////////////////////////////////////////////
// Public Interface
// pub async fn connect_with_machine(
// machine: Machine,
// remote_address: SocketAddr,
// local_address: Option<SocketAddr>,
// timeout_ms: u32,
// options: VirtualTcpOptions,
// ) -> VirtualNetworkResult<Self> {
// machine
// .router_client
// .tcp_connect(
// machine.id,
// remote_address,
// local_address,
// timeout_ms,
// options,
// )
// .await
// .map(|socket_id| Self { machine, socket_id })
// }
pub async fn bind(
opt_local_address: Option<SocketAddr>,
options: VirtualUdpOptions,
) -> VirtualNetworkResult<Self> {
let machine = default_machine().unwrap();
Self::bind_with_machine(machine, opt_local_address, options).await
}
pub async fn bind_with_machine(
machine: Machine,
opt_local_address: Option<SocketAddr>,
options: VirtualUdpOptions,
) -> VirtualNetworkResult<Self> {
machine
.router_client
.clone()
.udp_bind(machine.id, opt_local_address, options)
.await
.map(|(socket_id, local_address)| Self::new(machine, socket_id, local_address))
}
pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> VirtualNetworkResult<usize> {
self.machine
.router_client
.clone()
.send_to(self.machine.id, self.socket_id, target, buf.to_vec())
.await
}
pub async fn recv_from(&self, buf: &mut [u8]) -> VirtualNetworkResult<(usize, SocketAddr)> {
let (v, addr) = self
.machine
.router_client
.clone()
.recv_from(self.machine.id, self.socket_id, buf.len())
.await?;
let len = usize::min(buf.len(), v.len());
buf[0..len].copy_from_slice(&v[0..len]);
Ok((len, addr))
}
pub fn local_addr(&self) -> VirtualNetworkResult<SocketAddr> {
Ok(self.local_address)
}
/////////////////////////////////////////////////////////////
// Private Implementation
fn new(machine: Machine, socket_id: SocketId, local_address: SocketAddr) -> Self {
Self {
machine,
socket_id,
local_address,
}
}
}
// impl futures_util::AsyncRead for VirtualUdpSocket {
// fn poll_read(
// self: Pin<&mut Self>,
// cx: &mut task::Context<'_>,
// buf: &mut [u8],
// ) -> task::Poll<std::io::Result<usize>> {
// todo!()
// }
// }
impl Drop for VirtualUdpSocket {
fn drop(&mut self) {
self.machine
.router_client
.drop_udp_socket(self.machine.id, self.socket_id);
}
}

View File

@ -66,13 +66,6 @@ pub fn is_ipv6_supported() -> bool {
if let Some(supp) = *opt_supp {
return supp;
}
// let supp = match UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 0, 0, 0)) {
// Ok(_) => true,
// Err(e) => !matches!(
// e.kind(),
// std::io::ErrorKind::AddrNotAvailable | std::io::ErrorKind::Unsupported
// ),
// };
// XXX: See issue #92
let supp = false;