mirror of
https://gitlab.com/veilid/veilid.git
synced 2025-01-12 15:59:52 -05:00
refactor network into veilid-tools
fix deadlock in debug command 'entries fastest'
This commit is contained in:
parent
ad747d7831
commit
547427271c
4
Cargo.lock
generated
4
Cargo.lock
generated
@ -6299,7 +6299,6 @@ name = "veilid-core"
|
||||
version = "0.4.1"
|
||||
dependencies = [
|
||||
"argon2",
|
||||
"async-io 1.13.0",
|
||||
"async-std",
|
||||
"async-std-resolver",
|
||||
"async-tls",
|
||||
@ -6366,7 +6365,6 @@ dependencies = [
|
||||
"sha2 0.10.8",
|
||||
"shell-words",
|
||||
"simplelog",
|
||||
"socket2 0.5.7",
|
||||
"static_assertions",
|
||||
"stop-token",
|
||||
"sysinfo",
|
||||
@ -6518,6 +6516,7 @@ name = "veilid-tools"
|
||||
version = "0.4.1"
|
||||
dependencies = [
|
||||
"android_logger 0.13.3",
|
||||
"async-io 1.13.0",
|
||||
"async-lock 3.4.0",
|
||||
"async-std",
|
||||
"async_executors",
|
||||
@ -6556,6 +6555,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serial_test 2.0.0",
|
||||
"simplelog",
|
||||
"socket2 0.5.7",
|
||||
"static_assertions",
|
||||
"stop-token",
|
||||
"thiserror",
|
||||
|
@ -223,13 +223,12 @@ impl ClientApiConnection {
|
||||
trace!("ClientApiConnection::handle_tcp_connection");
|
||||
|
||||
// Connect the TCP socket
|
||||
let stream = TcpStream::connect(connect_addr)
|
||||
let stream = connect_async_tcp_stream(None, connect_addr, 10_000)
|
||||
.await
|
||||
.map_err(map_to_string)?
|
||||
.into_timeout_error()
|
||||
.map_err(map_to_string)?;
|
||||
|
||||
// If it succeed, disable nagle algorithm
|
||||
stream.set_nodelay(true).map_err(map_to_string)?;
|
||||
|
||||
// State we connected
|
||||
let comproc = self.inner.lock().comproc.clone();
|
||||
comproc.set_connection_state(ConnectionState::ConnectedTCP(
|
||||
@ -239,16 +238,8 @@ impl ClientApiConnection {
|
||||
|
||||
// Split into reader and writer halves
|
||||
// with line buffering on the reader
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
use futures::AsyncReadExt;
|
||||
let (reader, writer) = stream.split();
|
||||
let reader = BufReader::new(reader);
|
||||
} else {
|
||||
let (reader, writer) = stream.into_split();
|
||||
let reader = BufReader::new(reader);
|
||||
}
|
||||
}
|
||||
let (reader, writer) = split_async_tcp_stream(stream);
|
||||
let reader = BufReader::new(reader);
|
||||
|
||||
self.clone().run_json_api_processor(reader, writer).await
|
||||
}
|
||||
|
@ -56,6 +56,7 @@ veilid_core_ios_tests = ["dep:tracing-oslog"]
|
||||
debug-locks = ["veilid-tools/debug-locks"]
|
||||
unstable-blockstore = []
|
||||
unstable-tunnels = []
|
||||
virtual-network = []
|
||||
|
||||
# GeoIP
|
||||
geolocation = ["maxminddb", "reqwest"]
|
||||
@ -164,7 +165,6 @@ sysinfo = { version = "^0.30.13", default-features = false }
|
||||
tokio = { version = "1.38.1", features = ["full"], optional = true }
|
||||
tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
|
||||
tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
|
||||
async-io = { version = "1.13.0" }
|
||||
futures-util = { version = "0.3.30", default-features = false, features = [
|
||||
"async-await",
|
||||
"sink",
|
||||
@ -184,7 +184,6 @@ webpki = "0.22.4"
|
||||
webpki-roots = "0.25.4"
|
||||
rustls = "0.21.12"
|
||||
rustls-pemfile = "1.0.4"
|
||||
socket2 = { version = "0.5.7", features = ["all"] }
|
||||
|
||||
# Dependencies for WASM builds only
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
|
@ -1,6 +1,5 @@
|
||||
use super::*;
|
||||
use async_tls::TlsAcceptor;
|
||||
use sockets::*;
|
||||
use stop_token::future::FutureExt;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
@ -135,39 +134,12 @@ impl Network {
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(all(feature = "rt-async-std", unix))]
|
||||
if let Err(e) = set_tcp_stream_linger(&tcp_stream, Some(core::time::Duration::from_secs(0)))
|
||||
{
|
||||
// async-std does not directly support linger on TcpStream yet
|
||||
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
|
||||
if let Err(e) = unsafe {
|
||||
let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
|
||||
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
|
||||
s.into_raw_fd();
|
||||
res
|
||||
} {
|
||||
log_net!(debug "Couldn't set TCP linger: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#[cfg(all(feature = "rt-async-std", windows))]
|
||||
{
|
||||
// async-std does not directly support linger on TcpStream yet
|
||||
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
|
||||
if let Err(e) = unsafe {
|
||||
let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
|
||||
let res = s.set_linger(Some(core::time::Duration::from_secs(0)));
|
||||
s.into_raw_socket();
|
||||
res
|
||||
} {
|
||||
log_net!(debug "Couldn't set TCP linger: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "rt-async-std"))]
|
||||
if let Err(e) = tcp_stream.set_linger(Some(core::time::Duration::from_secs(0))) {
|
||||
log_net!(debug "Couldn't set TCP linger: {}", e);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(e) = tcp_stream.set_nodelay(true) {
|
||||
log_net!(debug "Couldn't set TCP nodelay: {}", e);
|
||||
return;
|
||||
@ -257,41 +229,11 @@ impl Network {
|
||||
)
|
||||
};
|
||||
|
||||
// Create a socket and bind it
|
||||
let Some(socket) = new_bound_default_tcp_socket(addr)
|
||||
.wrap_err("failed to create default socket listener")?
|
||||
else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
// Drop the socket
|
||||
drop(socket);
|
||||
|
||||
// Create a shared socket and bind it once we have determined the port is free
|
||||
let Some(socket) = new_bound_shared_tcp_socket(addr)
|
||||
.wrap_err("failed to create shared socket listener")?
|
||||
else {
|
||||
let Some(listener) = bind_async_tcp_listener(addr)? else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
// Listen on the socket
|
||||
if socket.listen(128).is_err() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// Make an async tcplistener from the socket2 socket
|
||||
let std_listener: std::net::TcpListener = socket.into();
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let listener = TcpListener::from(std_listener);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_listener.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let listener = TcpListener::from_std(std_listener).wrap_err("failed to create tokio tcp listener")?;
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
|
||||
log_net!(debug "spawn_socket_listener: binding successful to {}", addr);
|
||||
|
||||
// Create protocol handler records
|
||||
@ -311,15 +253,7 @@ impl Network {
|
||||
// moves listener object in and get incoming iterator
|
||||
// when this task exists, the listener will close the socket
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let incoming_stream = listener.incoming();
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
let incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
let incoming_stream = async_tcp_listener_incoming(listener);
|
||||
|
||||
let _ = incoming_stream
|
||||
.for_each_concurrent(None, |tcp_stream| {
|
||||
|
@ -1,5 +1,4 @@
|
||||
use super::*;
|
||||
use sockets::*;
|
||||
use stop_token::future::FutureExt;
|
||||
|
||||
impl Network {
|
||||
@ -114,23 +113,10 @@ impl Network {
|
||||
async fn create_udp_protocol_handler(&self, addr: SocketAddr) -> EyreResult<bool> {
|
||||
log_net!(debug "create_udp_protocol_handler on {:?}", &addr);
|
||||
|
||||
// Create a reusable socket
|
||||
let Some(socket) = new_bound_default_udp_socket(addr)? else {
|
||||
// Create a single-address-family UDP socket with default options bound to an address
|
||||
let Some(udp_socket) = bind_async_udp_socket(addr)? else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make inbound tokio udpsocket")?;
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
|
||||
// Create protocol handler
|
||||
|
@ -1,4 +1,3 @@
|
||||
pub mod sockets;
|
||||
pub mod tcp;
|
||||
pub mod udp;
|
||||
pub mod wrtc;
|
||||
|
@ -1,191 +0,0 @@
|
||||
use crate::*;
|
||||
use async_io::Async;
|
||||
use std::io;
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
|
||||
pub use tokio_util::compat::*;
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
|
||||
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
|
||||
|
||||
// cfg_if! {
|
||||
// if #[cfg(windows)] {
|
||||
// use winapi::shared::ws2def::{ SOL_SOCKET, SO_EXCLUSIVEADDRUSE};
|
||||
// use winapi::um::winsock2::{SOCKET_ERROR, setsockopt};
|
||||
// use winapi::ctypes::c_int;
|
||||
// use std::os::windows::io::AsRawSocket;
|
||||
|
||||
// fn set_exclusiveaddruse(socket: &Socket) -> io::Result<()> {
|
||||
// unsafe {
|
||||
// let optval:c_int = 1;
|
||||
// if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
|
||||
// std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
|
||||
// return Err(io::Error::last_os_error());
|
||||
// }
|
||||
// Ok(())
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
#[instrument(level = "trace", ret)]
|
||||
pub fn new_shared_udp_socket(domain: Domain) -> io::Result<Socket> {
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
socket.set_reuse_address(true)?;
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", ret)]
|
||||
pub fn new_default_udp_socket(domain: Domain) -> io::Result<Socket> {
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", ret)]
|
||||
pub fn new_bound_default_udp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = new_default_udp_socket(domain)?;
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
|
||||
if socket.bind(&socket2_addr).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
log_net!("created bound default udp socket on {:?}", &local_address);
|
||||
|
||||
Ok(Some(socket))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", ret)]
|
||||
pub fn new_default_tcp_socket(domain: Domain) -> io::Result<Socket> {
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
|
||||
log_net!(error "Couldn't set TCP linger: {}", e);
|
||||
}
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
log_net!(error "Couldn't set TCP nodelay: {}", e);
|
||||
}
|
||||
if domain == Domain::IPV6 {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", ret)]
|
||||
pub fn new_shared_tcp_socket(domain: Domain) -> io::Result<Socket> {
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
|
||||
log_net!(error "Couldn't set TCP linger: {}", e);
|
||||
}
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
log_net!(error "Couldn't set TCP nodelay: {}", e);
|
||||
}
|
||||
if domain == Domain::IPV6 {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
socket.set_reuse_address(true)?;
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
#[instrument(level = "trace", ret)]
|
||||
pub fn new_bound_default_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = new_default_tcp_socket(domain)?;
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
if socket.bind(&socket2_addr).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
log_net!("created bound default tcp socket on {:?}", &local_address);
|
||||
|
||||
Ok(Some(socket))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", ret)]
|
||||
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> io::Result<Option<Socket>> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = new_shared_tcp_socket(domain)?;
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
if socket.bind(&socket2_addr).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
log_net!("created bound shared tcp socket on {:?}", &local_address);
|
||||
|
||||
Ok(Some(socket))
|
||||
}
|
||||
|
||||
// Non-blocking connect is tricky when you want to start with a prepared socket
|
||||
// Errors should not be logged as they are valid conditions for this function
|
||||
#[instrument(level = "trace", ret)]
|
||||
pub async fn nonblocking_connect(
|
||||
socket: Socket,
|
||||
addr: SocketAddr,
|
||||
timeout_ms: u32,
|
||||
) -> io::Result<TimeoutOr<TcpStream>> {
|
||||
// Set for non blocking connect
|
||||
socket.set_nonblocking(true)?;
|
||||
|
||||
// Make socket2 SockAddr
|
||||
let socket2_addr = socket2::SockAddr::from(addr);
|
||||
|
||||
// Connect to the remote address
|
||||
match socket.connect(&socket2_addr) {
|
||||
Ok(()) => Ok(()),
|
||||
#[cfg(unix)]
|
||||
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
|
||||
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
|
||||
Err(e) => Err(e),
|
||||
}?;
|
||||
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
|
||||
|
||||
// The stream becomes writable when connected
|
||||
timeout_or_try!(
|
||||
timeout(timeout_ms, async_stream.writable().in_current_span())
|
||||
.await
|
||||
.into_timeout_or()
|
||||
.into_result()?
|
||||
);
|
||||
|
||||
// Check low level error
|
||||
let async_stream = match async_stream.get_ref().take_error()? {
|
||||
None => Ok(async_stream),
|
||||
Some(err) => Err(err),
|
||||
}?;
|
||||
|
||||
// Convert back to inner and then return async version
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
}
|
@ -1,6 +1,5 @@
|
||||
use super::*;
|
||||
use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||
use sockets::*;
|
||||
|
||||
pub struct RawTcpNetworkConnection {
|
||||
flow: Flow,
|
||||
@ -157,32 +156,28 @@ impl RawTcpProtocolHandler {
|
||||
#[instrument(level = "trace", target = "protocol", err)]
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
socket_addr: SocketAddr,
|
||||
remote_address: SocketAddr,
|
||||
timeout_ms: u32,
|
||||
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
|
||||
// Make a shared socket
|
||||
let socket = match local_address {
|
||||
Some(a) => {
|
||||
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
|
||||
}
|
||||
None => new_default_tcp_socket(socket2::Domain::for_address(socket_addr))?,
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = network_result_try!(nonblocking_connect(socket, socket_addr, timeout_ms)
|
||||
.await
|
||||
.folded()?);
|
||||
let tcp_stream = network_result_try!(connect_async_tcp_stream(
|
||||
local_address,
|
||||
remote_address,
|
||||
timeout_ms
|
||||
)
|
||||
.await
|
||||
.folded()?);
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
let actual_local_address = ts.local_addr()?;
|
||||
let actual_local_address = tcp_stream.local_addr()?;
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let ts = ts.compat();
|
||||
let ps = AsyncPeekStream::new(ts);
|
||||
let tcp_stream = tcp_stream.compat();
|
||||
let ps = AsyncPeekStream::new(tcp_stream);
|
||||
|
||||
// Wrap the stream in a network connection and return it
|
||||
let flow = Flow::new(
|
||||
PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(socket_addr),
|
||||
SocketAddress::from_socket_addr(remote_address),
|
||||
ProtocolType::TCP,
|
||||
),
|
||||
SocketAddress::from_socket_addr(actual_local_address),
|
||||
|
@ -1,5 +1,4 @@
|
||||
use super::*;
|
||||
use sockets::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RawUdpProtocolHandler {
|
||||
@ -141,7 +140,8 @@ impl RawUdpProtocolHandler {
|
||||
) -> io::Result<RawUdpProtocolHandler> {
|
||||
// get local wildcard address for bind
|
||||
let local_socket_addr = compatible_unspecified_socket_addr(socket_addr);
|
||||
let socket = UdpSocket::bind(local_socket_addr).await?;
|
||||
let socket = bind_async_udp_socket(local_socket_addr)?
|
||||
.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?;
|
||||
Ok(RawUdpProtocolHandler::new(Arc::new(socket), None))
|
||||
}
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ use async_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFr
|
||||
use async_tungstenite::tungstenite::Error;
|
||||
use async_tungstenite::{accept_hdr_async, client_async, WebSocketStream};
|
||||
use futures_util::{AsyncRead, AsyncWrite, SinkExt};
|
||||
use sockets::*;
|
||||
|
||||
// Maximum number of websocket request headers to permit
|
||||
const MAX_WS_HEADERS: usize = 24;
|
||||
@ -316,21 +315,16 @@ impl WebsocketProtocolHandler {
|
||||
let domain = split_url.host.clone();
|
||||
|
||||
// Resolve remote address
|
||||
let remote_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
// Make a shared socket
|
||||
let socket = match local_address {
|
||||
Some(a) => {
|
||||
new_bound_shared_tcp_socket(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
|
||||
}
|
||||
None => new_default_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?,
|
||||
};
|
||||
let remote_address = dial_info.to_socket_addr();
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let tcp_stream =
|
||||
network_result_try!(nonblocking_connect(socket, remote_socket_addr, timeout_ms)
|
||||
.await
|
||||
.folded()?);
|
||||
let tcp_stream = network_result_try!(connect_async_tcp_stream(
|
||||
local_address,
|
||||
remote_address,
|
||||
timeout_ms
|
||||
)
|
||||
.await
|
||||
.folded()?);
|
||||
|
||||
// See what local address we ended up with
|
||||
let actual_local_addr = tcp_stream.local_addr()?;
|
||||
|
@ -333,9 +333,10 @@ impl RoutingTable {
|
||||
relaying_count += 1;
|
||||
}
|
||||
|
||||
let best_node_id = node.best_node_id();
|
||||
|
||||
out += " ";
|
||||
out += &node
|
||||
.operate(|_rti, e| Self::format_entry(cur_ts, node.best_node_id(), e, &relay_tag));
|
||||
out += &node.operate(|_rti, e| Self::format_entry(cur_ts, best_node_id, e, &relay_tag));
|
||||
out += "\n";
|
||||
}
|
||||
|
||||
|
@ -165,17 +165,12 @@ impl ClientApi {
|
||||
}
|
||||
|
||||
async fn handle_tcp_incoming(self, bind_addr: SocketAddr) -> std::io::Result<()> {
|
||||
let listener = TcpListener::bind(bind_addr).await?;
|
||||
let listener = bind_async_tcp_listener(bind_addr)?
|
||||
.ok_or(std::io::Error::from(std::io::ErrorKind::AddrInUse))?;
|
||||
debug!(target: "client_api", "TCPClient API listening on: {:?}", bind_addr);
|
||||
|
||||
// Process the incoming accept stream
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let mut incoming_stream = listener.incoming();
|
||||
} else {
|
||||
let mut incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
||||
}
|
||||
}
|
||||
let mut incoming_stream = async_tcp_listener_incoming(listener);
|
||||
|
||||
// Make wait group for all incoming connections
|
||||
let awg = AsyncWaitGroup::new();
|
||||
|
@ -7,8 +7,6 @@ pub use tracing::*;
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
// pub use async_std::task::JoinHandle;
|
||||
pub use async_std::net::TcpListener;
|
||||
pub use async_std::net::TcpStream;
|
||||
pub use async_std::io::BufReader;
|
||||
//pub use async_std::future::TimeoutError;
|
||||
//pub fn spawn_detached<F: Future<Output = T> + Send + 'static, T: Send + 'static>(f: F) -> JoinHandle<T> {
|
||||
@ -27,8 +25,6 @@ cfg_if! {
|
||||
}
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
//pub use tokio::task::JoinHandle;
|
||||
pub use tokio::net::TcpListener;
|
||||
pub use tokio::net::TcpStream;
|
||||
pub use tokio::io::BufReader;
|
||||
//pub use tokio_util::compat::*;
|
||||
//pub use tokio::time::error::Elapsed as TimeoutError;
|
||||
|
@ -76,20 +76,21 @@ flume = { version = "0.11.0", features = ["async"] }
|
||||
# Dependencies for native builds only
|
||||
# Linux, Windows, Mac, iOS, Android
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
async-io = { version = "1.13.0" }
|
||||
async-std = { version = "1.12.0", features = ["unstable"], optional = true }
|
||||
tokio = { version = "1.38.1", features = ["full"], optional = true }
|
||||
tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
|
||||
tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
|
||||
chrono = "0.4.38"
|
||||
futures-util = { version = "0.3.30", default-features = false, features = [
|
||||
"async-await",
|
||||
"sink",
|
||||
"std",
|
||||
"io",
|
||||
] }
|
||||
chrono = "0.4.38"
|
||||
|
||||
libc = "0.2.155"
|
||||
nix = { version = "0.27.1", features = ["user"] }
|
||||
socket2 = { version = "0.5.7", features = ["all"] }
|
||||
tokio = { version = "1.38.1", features = ["full"], optional = true }
|
||||
tokio-util = { version = "0.7.11", features = ["compat"], optional = true }
|
||||
tokio-stream = { version = "0.1.15", features = ["net"], optional = true }
|
||||
|
||||
# Dependencies for WASM builds only
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
|
@ -1,108 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
|
||||
} else {
|
||||
use std::net::{TcpListener, UdpSocket};
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ThisError, Debug, Clone, PartialEq, Eq)]
|
||||
pub enum BumpPortError {
|
||||
#[error("Unsupported architecture")]
|
||||
Unsupported,
|
||||
#[error("Failure: {0}")]
|
||||
Failed(String),
|
||||
}
|
||||
|
||||
pub enum BumpPortType {
|
||||
UDP,
|
||||
TCP,
|
||||
}
|
||||
|
||||
pub fn tcp_port_available(addr: &SocketAddr) -> bool {
|
||||
cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
true
|
||||
} else {
|
||||
match TcpListener::bind(addr) {
|
||||
Ok(_) => true,
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn udp_port_available(addr: &SocketAddr) -> bool {
|
||||
cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
true
|
||||
} else {
|
||||
match UdpSocket::bind(addr) {
|
||||
Ok(_) => true,
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bump_port(addr: &mut SocketAddr, bpt: BumpPortType) -> Result<bool, BumpPortError> {
|
||||
cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
Err(BumpPortError::Unsupported)
|
||||
}
|
||||
else
|
||||
{
|
||||
let mut bumped = false;
|
||||
let mut port = addr.port();
|
||||
let mut addr_bump = addr.clone();
|
||||
loop {
|
||||
|
||||
if match bpt {
|
||||
BumpPortType::TCP => tcp_port_available(&addr_bump),
|
||||
BumpPortType::UDP => udp_port_available(&addr_bump),
|
||||
} {
|
||||
*addr = addr_bump;
|
||||
return Ok(bumped);
|
||||
}
|
||||
if port == u16::MAX {
|
||||
break;
|
||||
}
|
||||
port += 1;
|
||||
addr_bump.set_port(port);
|
||||
bumped = true;
|
||||
}
|
||||
|
||||
Err(BumpPortError::Failure("no ports remaining".to_owned()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bump_port_string(addr: &mut String, bpt: BumpPortType) -> Result<bool, BumpPortError> {
|
||||
cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
return Err(BumpPortError::Unsupported);
|
||||
}
|
||||
else
|
||||
{
|
||||
let savec: Vec<SocketAddr> = addr
|
||||
.to_socket_addrs()
|
||||
.map_err(|x| BumpPortError::Failure(format!("failed to resolve socket address: {}", x)))?
|
||||
.collect();
|
||||
|
||||
if savec.len() == 0 {
|
||||
return Err(BumpPortError::Failure("No socket addresses resolved".to_owned()));
|
||||
}
|
||||
let mut sa = savec.first().unwrap().clone();
|
||||
|
||||
if !bump_port(&mut sa, bpt)? {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
*addr = sa.to_string();
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
}
|
@ -24,7 +24,6 @@
|
||||
#![allow(clippy::comparison_chain, clippy::upper_case_acronyms)]
|
||||
#![deny(unused_must_use)]
|
||||
|
||||
// pub mod bump_port;
|
||||
pub mod assembly_buffer;
|
||||
pub mod async_peek_stream;
|
||||
pub mod async_tag_lock;
|
||||
@ -49,6 +48,8 @@ pub mod network_result;
|
||||
pub mod random;
|
||||
pub mod single_shot_eventual;
|
||||
pub mod sleep;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub mod socket_tools;
|
||||
pub mod spawn;
|
||||
pub mod split_url;
|
||||
pub mod startup_lock;
|
||||
@ -184,7 +185,6 @@ cfg_if! {
|
||||
}
|
||||
}
|
||||
|
||||
// pub use bump_port::*;
|
||||
#[doc(inline)]
|
||||
pub use assembly_buffer::*;
|
||||
#[doc(inline)]
|
||||
@ -233,6 +233,9 @@ pub use single_shot_eventual::*;
|
||||
#[doc(inline)]
|
||||
pub use sleep::*;
|
||||
#[doc(inline)]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use socket_tools::*;
|
||||
#[doc(inline)]
|
||||
pub use spawn::*;
|
||||
#[doc(inline)]
|
||||
pub use split_url::*;
|
||||
|
284
veilid-tools/src/socket_tools.rs
Normal file
284
veilid-tools/src/socket_tools.rs
Normal file
@ -0,0 +1,284 @@
|
||||
use super::*;
|
||||
use async_io::Async;
|
||||
use std::io;
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
|
||||
pub use tokio_util::compat::*;
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
|
||||
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
pub fn bind_async_udp_socket(local_address: SocketAddr) -> io::Result<Option<UdpSocket>> {
|
||||
let Some(socket) = new_bound_default_socket2_udp(local_address)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_udp_socket.set_nonblocking(true)?;
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket)?;
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
Ok(Some(udp_socket))
|
||||
}
|
||||
|
||||
pub fn bind_async_tcp_listener(local_address: SocketAddr) -> io::Result<Option<TcpListener>> {
|
||||
// Create a default non-shared socket and bind it
|
||||
let Some(socket) = new_bound_default_socket2_tcp(local_address)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Drop the socket so we can make another shared socket in its place
|
||||
drop(socket);
|
||||
|
||||
// Create a shared socket and bind it now we have determined the port is free
|
||||
let Some(socket) = new_bound_shared_socket2_tcp(local_address)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Listen on the socket
|
||||
if socket.listen(128).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Make an async tcplistener from the socket2 socket
|
||||
let std_listener: std::net::TcpListener = socket.into();
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let listener = TcpListener::from(std_listener);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_listener.set_nonblocking(true)?;
|
||||
let listener = TcpListener::from_std(std_listener)?;
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
Ok(Some(listener))
|
||||
}
|
||||
|
||||
pub async fn connect_async_tcp_stream(
|
||||
local_address: Option<SocketAddr>,
|
||||
remote_address: SocketAddr,
|
||||
timeout_ms: u32,
|
||||
) -> io::Result<TimeoutOr<TcpStream>> {
|
||||
let socket = match local_address {
|
||||
Some(a) => {
|
||||
new_bound_shared_socket2_tcp(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
|
||||
}
|
||||
None => new_default_socket2_tcp(domain_for_address(remote_address))?,
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
nonblocking_connect(socket, remote_address, timeout_ms).await
|
||||
}
|
||||
|
||||
pub fn set_tcp_stream_linger(
|
||||
tcp_stream: &TcpStream,
|
||||
linger: Option<core::time::Duration>,
|
||||
) -> io::Result<()> {
|
||||
#[cfg(all(feature = "rt-async-std", unix))]
|
||||
{
|
||||
// async-std does not directly support linger on TcpStream yet
|
||||
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
|
||||
unsafe {
|
||||
let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
|
||||
let res = s.set_linger(linger);
|
||||
s.into_raw_fd();
|
||||
res
|
||||
}
|
||||
}
|
||||
#[cfg(all(feature = "rt-async-std", windows))]
|
||||
{
|
||||
// async-std does not directly support linger on TcpStream yet
|
||||
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
|
||||
unsafe {
|
||||
let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
|
||||
let res = s.set_linger(linger);
|
||||
s.into_raw_socket();
|
||||
res
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "rt-async-std"))]
|
||||
tcp_stream.set_linger(linger)
|
||||
}
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
pub type IncomingStream = Incoming;
|
||||
pub type ReadHalf = futures_util::io::ReadHalf<TcpStream>,
|
||||
pub type WriteHalf = futures_util::io::WriteHalf<TcpStream>,
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
pub type IncomingStream = tokio_stream::wrappers::TcpListenerStream;
|
||||
pub type ReadHalf = tokio::net::tcp::OwnedReadHalf;
|
||||
pub type WriteHalf = tokio::net::tcp::OwnedWriteHalf;
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn async_tcp_listener_incoming(tcp_listener: TcpListener) -> IncomingStream {
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
tcp_listener.incoming()
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
tokio_stream::wrappers::TcpListenerStream::new(tcp_listener)
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn split_async_tcp_stream(tcp_stream: TcpStream) -> (ReadHalf, WriteHalf) {
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
use futures_util::AsyncReadExt;
|
||||
tcp_stream.split()
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
tcp_stream.into_split()
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
fn new_default_udp_socket(domain: core::ffi::c_int) -> io::Result<Socket> {
|
||||
let domain = Domain::from(domain);
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
fn new_bound_default_socket2_udp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
|
||||
let domain = domain_for_address(local_address);
|
||||
let socket = new_default_udp_socket(domain)?;
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
|
||||
if socket.bind(&socket2_addr).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(socket))
|
||||
}
|
||||
|
||||
pub fn new_default_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
|
||||
let domain = Domain::from(domain);
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
|
||||
socket.set_nodelay(true)?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
fn new_shared_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
|
||||
let domain = Domain::from(domain);
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
|
||||
socket.set_nodelay(true)?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
socket.set_reuse_address(true)?;
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
fn new_bound_default_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
|
||||
let domain = domain_for_address(local_address);
|
||||
let socket = new_default_socket2_tcp(domain)?;
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
if socket.bind(&socket2_addr).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(socket))
|
||||
}
|
||||
|
||||
fn new_bound_shared_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
|
||||
// Create the reuseaddr/reuseport socket now that we've asserted the port is free
|
||||
let domain = domain_for_address(local_address);
|
||||
let socket = new_shared_socket2_tcp(domain)?;
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
if socket.bind(&socket2_addr).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(socket))
|
||||
}
|
||||
|
||||
// Non-blocking connect is tricky when you want to start with a prepared socket
|
||||
// Errors should not be logged as they are valid conditions for this function
|
||||
async fn nonblocking_connect(
|
||||
socket: Socket,
|
||||
addr: SocketAddr,
|
||||
timeout_ms: u32,
|
||||
) -> io::Result<TimeoutOr<TcpStream>> {
|
||||
// Set for non blocking connect
|
||||
socket.set_nonblocking(true)?;
|
||||
|
||||
// Make socket2 SockAddr
|
||||
let socket2_addr = socket2::SockAddr::from(addr);
|
||||
|
||||
// Connect to the remote address
|
||||
match socket.connect(&socket2_addr) {
|
||||
Ok(()) => Ok(()),
|
||||
#[cfg(unix)]
|
||||
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
|
||||
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
|
||||
Err(e) => Err(e),
|
||||
}?;
|
||||
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
|
||||
|
||||
// The stream becomes writable when connected
|
||||
timeout_or_try!(timeout(timeout_ms, async_stream.writable())
|
||||
.await
|
||||
.into_timeout_or()
|
||||
.into_result()?);
|
||||
|
||||
// Check low level error
|
||||
let async_stream = match async_stream.get_ref().take_error()? {
|
||||
None => Ok(async_stream),
|
||||
Some(err) => Err(err),
|
||||
}?;
|
||||
|
||||
// Convert back to inner and then return async version
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
|
||||
} else {
|
||||
compile_error!("needs executor implementation");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn domain_for_address(address: SocketAddr) -> core::ffi::c_int {
|
||||
socket2::Domain::for_address(address).into()
|
||||
}
|
@ -8,7 +8,6 @@ cfg_if! {
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::time::sleep;
|
||||
use tokio_util::compat::*;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,7 @@ use super::*;
|
||||
|
||||
pub type MachineId = u64;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Machine {
|
||||
pub router_client: RouterClient,
|
||||
pub id: MachineId,
|
||||
|
@ -25,6 +25,7 @@
|
||||
//!
|
||||
//! [VirtualTcpStream]
|
||||
//! [VirtualUdpSocket]
|
||||
//! [VirtualTcpListener]
|
||||
//! [VirtualTcpListenerStream]
|
||||
//! [VirtualWsMeta]
|
||||
//! [VirtualWsStream]
|
||||
@ -44,6 +45,8 @@ mod router_op_table;
|
||||
mod router_server;
|
||||
mod serde_io_error;
|
||||
mod virtual_network_error;
|
||||
mod virtual_tcp_listener;
|
||||
mod virtual_tcp_listener_stream;
|
||||
mod virtual_tcp_stream;
|
||||
mod virtual_udp_socket;
|
||||
|
||||
@ -53,5 +56,7 @@ pub use machine::*;
|
||||
pub use router_client::*;
|
||||
pub use router_server::*;
|
||||
pub use virtual_network_error::*;
|
||||
pub use virtual_tcp_listener::*;
|
||||
pub use virtual_tcp_listener_stream::*;
|
||||
pub use virtual_tcp_stream::*;
|
||||
pub use virtual_udp_socket::*;
|
||||
|
@ -2,8 +2,7 @@ use super::*;
|
||||
use core::sync::atomic::AtomicU64;
|
||||
use futures_codec::{Bytes, BytesCodec, FramedRead, FramedWrite};
|
||||
use futures_util::{
|
||||
io::BufReader, stream::FuturesUnordered, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt,
|
||||
StreamExt, TryStreamExt,
|
||||
stream::FuturesUnordered, AsyncReadExt, AsyncWriteExt, StreamExt, TryStreamExt,
|
||||
};
|
||||
use postcard::{from_bytes, to_stdvec};
|
||||
use router_op_table::*;
|
||||
@ -16,11 +15,29 @@ struct RouterClientInner {
|
||||
stop_source: Option<StopSource>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for RouterClientInner {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("RouterClientInner")
|
||||
.field("jh_handler", &self.jh_handler)
|
||||
.field("stop_source", &self.stop_source)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
struct RouterClientUnlockedInner {
|
||||
receiver: flume::Receiver<Bytes>,
|
||||
sender: flume::Sender<Bytes>,
|
||||
next_message_id: AtomicU64,
|
||||
router_op_waiter: RouterOpWaiter<ServerProcessorResponseStatus, ()>,
|
||||
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for RouterClientUnlockedInner {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("RouterClientUnlockedInner")
|
||||
.field("sender", &self.sender)
|
||||
.field("next_message_id", &self.next_message_id)
|
||||
.field("router_op_waiter", &self.router_op_waiter)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub type MessageId = u64;
|
||||
@ -49,9 +66,9 @@ enum ServerProcessorRequest {
|
||||
},
|
||||
TcpAccept {
|
||||
machine_id: MachineId,
|
||||
socket_id: SocketId,
|
||||
listen_socket_id: SocketId,
|
||||
},
|
||||
Close {
|
||||
TcpShutdown {
|
||||
machine_id: MachineId,
|
||||
socket_id: SocketId,
|
||||
},
|
||||
@ -84,13 +101,22 @@ enum ServerProcessorRequest {
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
struct ServerProcessorRequestMessage {
|
||||
struct ServerProcessorMessage {
|
||||
message_id: MessageId,
|
||||
request: ServerProcessorRequest,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
enum ServerProcessorResponse {
|
||||
enum ServerProcessorCommand {
|
||||
Message(ServerProcessorMessage),
|
||||
CloseSocket {
|
||||
machine_id: MachineId,
|
||||
socket_id: SocketId,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
enum ServerProcessorReplyValue {
|
||||
AllocateMachine {
|
||||
machine_id: MachineId,
|
||||
},
|
||||
@ -100,17 +126,21 @@ enum ServerProcessorResponse {
|
||||
},
|
||||
TcpConnect {
|
||||
socket_id: SocketId,
|
||||
local_address: SocketAddr,
|
||||
},
|
||||
TcpBind {
|
||||
socket_id: SocketId,
|
||||
local_address: SocketAddr,
|
||||
},
|
||||
TcpAccept {
|
||||
child_socket_id: SocketId,
|
||||
socket_id: SocketId,
|
||||
address: SocketAddr,
|
||||
},
|
||||
TcpShutdown,
|
||||
UdpBind {
|
||||
socket_id: SocketId,
|
||||
local_address: SocketAddr,
|
||||
},
|
||||
Close,
|
||||
Send {
|
||||
len: u32,
|
||||
},
|
||||
@ -127,20 +157,29 @@ enum ServerProcessorResponse {
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
enum ServerProcessorResponseStatus {
|
||||
Success(ServerProcessorResponse),
|
||||
enum ServerProcessorReplyResult {
|
||||
Value(ServerProcessorReplyValue),
|
||||
InvalidMachineId,
|
||||
InvalidSocketId,
|
||||
IoError(#[serde(with = "serde_io_error::SerdeIoErrorKindDef")] io::ErrorKind),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
struct ServerProcessorResponseMessage {
|
||||
struct ServerProcessorReply {
|
||||
message_id: MessageId,
|
||||
status: ServerProcessorResponseStatus,
|
||||
status: ServerProcessorReplyResult,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
enum ServerProcessorEvent {
|
||||
Reply(ServerProcessorReply),
|
||||
// DeadSocket {
|
||||
// machine_id: MachineId,
|
||||
// socket_id: SocketId,
|
||||
// },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RouterClient {
|
||||
unlocked_inner: Arc<RouterClientUnlockedInner>,
|
||||
inner: Arc<Mutex<RouterClientInner>>,
|
||||
@ -151,7 +190,7 @@ impl RouterClient {
|
||||
// Public interface
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub async fn router_connect_tcp<H: ToSocketAddrs>(host: H) -> ::std::io::Result<RouterClient> {
|
||||
pub async fn router_connect_tcp<H: ToSocketAddrs>(host: H) -> io::Result<RouterClient> {
|
||||
let addrs = host.to_socket_addrs()?.collect::<Vec<_>>();
|
||||
|
||||
// Connect to RouterServer
|
||||
@ -173,37 +212,37 @@ impl RouterClient {
|
||||
}
|
||||
}
|
||||
|
||||
let ts_buf_reader = BufReader::new(ts_reader);
|
||||
|
||||
// Create channels
|
||||
let (client_sender, server_receiver) = flume::unbounded::<Bytes>();
|
||||
let (server_sender, client_receiver) = flume::unbounded::<Bytes>();
|
||||
|
||||
// Create stopper
|
||||
let stop_source = StopSource::new();
|
||||
|
||||
// Create router operation waiter
|
||||
let router_op_waiter = RouterOpWaiter::new();
|
||||
|
||||
// Spawn a client connection handler
|
||||
let jh_handler = spawn(
|
||||
"RouterClient server processor",
|
||||
Self::run_server_processor(
|
||||
ts_buf_reader,
|
||||
ts_reader,
|
||||
ts_writer,
|
||||
server_receiver,
|
||||
server_sender,
|
||||
router_op_waiter.clone(),
|
||||
stop_source.token(),
|
||||
),
|
||||
);
|
||||
|
||||
Ok(Self::new(
|
||||
client_receiver,
|
||||
client_sender,
|
||||
router_op_waiter,
|
||||
jh_handler,
|
||||
stop_source,
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub async fn router_connect_ws<H: AsRef<str>>(host: H) -> ::std::io::Result<RouterClient> {
|
||||
pub async fn router_connect_ws<H: AsRef<str>>(host: H) -> io::Result<RouterClient> {
|
||||
let host = host.as_ref();
|
||||
|
||||
Ok(RouterClient {})
|
||||
@ -219,7 +258,7 @@ impl RouterClient {
|
||||
|
||||
pub async fn allocate_machine(self) -> VirtualNetworkResult<MachineId> {
|
||||
let request = ServerProcessorRequest::AllocateMachine;
|
||||
let ServerProcessorResponse::AllocateMachine { machine_id } =
|
||||
let ServerProcessorReplyValue::AllocateMachine { machine_id } =
|
||||
self.perform_request(request).await?
|
||||
else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
@ -229,7 +268,7 @@ impl RouterClient {
|
||||
|
||||
pub async fn release_machine(self, machine_id: MachineId) -> VirtualNetworkResult<()> {
|
||||
let request = ServerProcessorRequest::ReleaseMachine { machine_id };
|
||||
let ServerProcessorResponse::ReleaseMachine = self.perform_request(request).await? else {
|
||||
let ServerProcessorReplyValue::ReleaseMachine = self.perform_request(request).await? else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(())
|
||||
@ -240,7 +279,7 @@ impl RouterClient {
|
||||
machine_id: MachineId,
|
||||
) -> VirtualNetworkResult<BTreeMap<String, NetworkInterface>> {
|
||||
let request = ServerProcessorRequest::GetInterfaces { machine_id };
|
||||
let ServerProcessorResponse::GetInterfaces { interfaces } =
|
||||
let ServerProcessorReplyValue::GetInterfaces { interfaces } =
|
||||
self.perform_request(request).await?
|
||||
else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
@ -252,91 +291,99 @@ impl RouterClient {
|
||||
self,
|
||||
machine_id: MachineId,
|
||||
remote_address: SocketAddr,
|
||||
local_address: Option<SocketAddr>,
|
||||
opt_local_address: Option<SocketAddr>,
|
||||
timeout_ms: u32,
|
||||
options: VirtualTcpOptions,
|
||||
) -> VirtualNetworkResult<SocketId> {
|
||||
) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
|
||||
let request = ServerProcessorRequest::TcpConnect {
|
||||
machine_id,
|
||||
local_address,
|
||||
local_address: opt_local_address,
|
||||
remote_address,
|
||||
timeout_ms,
|
||||
options,
|
||||
};
|
||||
let ServerProcessorResponse::TcpConnect { socket_id } =
|
||||
self.perform_request(request).await?
|
||||
let ServerProcessorReplyValue::TcpConnect {
|
||||
socket_id,
|
||||
local_address,
|
||||
} = self.perform_request(request).await?
|
||||
else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(socket_id)
|
||||
Ok((socket_id, local_address))
|
||||
}
|
||||
|
||||
pub async fn tcp_bind(
|
||||
self,
|
||||
machine_id: MachineId,
|
||||
local_address: Option<SocketAddr>,
|
||||
opt_local_address: Option<SocketAddr>,
|
||||
options: VirtualTcpOptions,
|
||||
) -> VirtualNetworkResult<SocketId> {
|
||||
) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
|
||||
let request = ServerProcessorRequest::TcpBind {
|
||||
machine_id,
|
||||
local_address,
|
||||
local_address: opt_local_address,
|
||||
options,
|
||||
};
|
||||
let ServerProcessorResponse::TcpBind { socket_id } = self.perform_request(request).await?
|
||||
let ServerProcessorReplyValue::TcpBind {
|
||||
socket_id,
|
||||
local_address,
|
||||
} = self.perform_request(request).await?
|
||||
else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(socket_id)
|
||||
Ok((socket_id, local_address))
|
||||
}
|
||||
|
||||
pub async fn tcp_accept(
|
||||
self,
|
||||
machine_id: MachineId,
|
||||
socket_id: SocketId,
|
||||
) -> VirtualNetworkResult<SocketId> {
|
||||
listen_socket_id: SocketId,
|
||||
) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
|
||||
let request = ServerProcessorRequest::TcpAccept {
|
||||
machine_id,
|
||||
socket_id,
|
||||
listen_socket_id,
|
||||
};
|
||||
let ServerProcessorResponse::TcpAccept { child_socket_id } =
|
||||
let ServerProcessorReplyValue::TcpAccept { socket_id, address } =
|
||||
self.perform_request(request).await?
|
||||
else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(child_socket_id)
|
||||
Ok((socket_id, address))
|
||||
}
|
||||
|
||||
pub async fn tcp_shutdown(
|
||||
self,
|
||||
machine_id: MachineId,
|
||||
socket_id: SocketId,
|
||||
) -> VirtualNetworkResult<()> {
|
||||
let request = ServerProcessorRequest::TcpShutdown {
|
||||
machine_id,
|
||||
socket_id,
|
||||
};
|
||||
let ServerProcessorReplyValue::TcpShutdown = self.perform_request(request).await? else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn udp_bind(
|
||||
self,
|
||||
machine_id: MachineId,
|
||||
local_address: Option<SocketAddr>,
|
||||
opt_local_address: Option<SocketAddr>,
|
||||
options: VirtualUdpOptions,
|
||||
) -> VirtualNetworkResult<SocketId> {
|
||||
) -> VirtualNetworkResult<(SocketId, SocketAddr)> {
|
||||
let request = ServerProcessorRequest::UdpBind {
|
||||
machine_id,
|
||||
local_address,
|
||||
local_address: opt_local_address,
|
||||
options,
|
||||
};
|
||||
let ServerProcessorResponse::UdpBind { socket_id } = self.perform_request(request).await?
|
||||
let ServerProcessorReplyValue::UdpBind {
|
||||
socket_id,
|
||||
local_address,
|
||||
} = self.perform_request(request).await?
|
||||
else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(socket_id)
|
||||
}
|
||||
|
||||
pub async fn close(
|
||||
self,
|
||||
machine_id: MachineId,
|
||||
socket_id: SocketId,
|
||||
) -> VirtualNetworkResult<()> {
|
||||
let request = ServerProcessorRequest::Close {
|
||||
machine_id,
|
||||
socket_id,
|
||||
};
|
||||
let ServerProcessorResponse::Close = self.perform_request(request).await? else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(())
|
||||
Ok((socket_id, local_address))
|
||||
}
|
||||
|
||||
pub async fn send(
|
||||
@ -350,7 +397,7 @@ impl RouterClient {
|
||||
socket_id,
|
||||
data,
|
||||
};
|
||||
let ServerProcessorResponse::Send { len } = self.perform_request(request).await? else {
|
||||
let ServerProcessorReplyValue::Send { len } = self.perform_request(request).await? else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(len as usize)
|
||||
@ -369,7 +416,7 @@ impl RouterClient {
|
||||
data,
|
||||
remote_address,
|
||||
};
|
||||
let ServerProcessorResponse::SendTo { len } = self.perform_request(request).await? else {
|
||||
let ServerProcessorReplyValue::SendTo { len } = self.perform_request(request).await? else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(len as usize)
|
||||
@ -379,14 +426,14 @@ impl RouterClient {
|
||||
self,
|
||||
machine_id: MachineId,
|
||||
socket_id: u64,
|
||||
len: u32,
|
||||
len: usize,
|
||||
) -> VirtualNetworkResult<Vec<u8>> {
|
||||
let request = ServerProcessorRequest::Recv {
|
||||
machine_id,
|
||||
socket_id,
|
||||
len,
|
||||
len: len as u32,
|
||||
};
|
||||
let ServerProcessorResponse::Recv { data } = self.perform_request(request).await? else {
|
||||
let ServerProcessorReplyValue::Recv { data } = self.perform_request(request).await? else {
|
||||
return Err(VirtualNetworkError::ResponseMismatch);
|
||||
};
|
||||
Ok(data)
|
||||
@ -396,14 +443,14 @@ impl RouterClient {
|
||||
self,
|
||||
machine_id: MachineId,
|
||||
socket_id: u64,
|
||||
len: u32,
|
||||
len: usize,
|
||||
) -> VirtualNetworkResult<(Vec<u8>, SocketAddr)> {
|
||||
let request = ServerProcessorRequest::RecvFrom {
|
||||
machine_id,
|
||||
socket_id,
|
||||
len,
|
||||
len: len as u32,
|
||||
};
|
||||
let ServerProcessorResponse::RecvFrom {
|
||||
let ServerProcessorReplyValue::RecvFrom {
|
||||
data,
|
||||
remote_address,
|
||||
} = self.perform_request(request).await?
|
||||
@ -417,17 +464,16 @@ impl RouterClient {
|
||||
// Private implementation
|
||||
|
||||
fn new(
|
||||
receiver: flume::Receiver<Bytes>,
|
||||
sender: flume::Sender<Bytes>,
|
||||
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>,
|
||||
jh_handler: MustJoinHandle<()>,
|
||||
stop_source: StopSource,
|
||||
) -> RouterClient {
|
||||
RouterClient {
|
||||
unlocked_inner: Arc::new(RouterClientUnlockedInner {
|
||||
receiver,
|
||||
sender,
|
||||
next_message_id: AtomicU64::new(0),
|
||||
router_op_waiter: RouterOpWaiter::new(),
|
||||
router_op_waiter,
|
||||
}),
|
||||
inner: Arc::new(Mutex::new(RouterClientInner {
|
||||
jh_handler: Some(jh_handler),
|
||||
@ -436,24 +482,60 @@ impl RouterClient {
|
||||
}
|
||||
}
|
||||
|
||||
fn report_closed_socket(&self, machine_id: MachineId, socket_id: SocketId) {
|
||||
let command = ServerProcessorCommand::CloseSocket {
|
||||
machine_id,
|
||||
socket_id,
|
||||
};
|
||||
let command_vec = match to_stdvec(&command).map_err(VirtualNetworkError::SerializationError)
|
||||
{
|
||||
Ok(v) => Bytes::from(v),
|
||||
Err(e) => {
|
||||
error!("{}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = self
|
||||
.unlocked_inner
|
||||
.sender
|
||||
.send(command_vec)
|
||||
.map_err(|_| VirtualNetworkError::IoError(io::ErrorKind::BrokenPipe))
|
||||
{
|
||||
error!("{}", e);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn drop_tcp_stream(&self, machine_id: MachineId, socket_id: SocketId) {
|
||||
self.report_closed_socket(machine_id, socket_id);
|
||||
}
|
||||
|
||||
pub(super) fn drop_tcp_listener(&self, machine_id: MachineId, socket_id: SocketId) {
|
||||
self.report_closed_socket(machine_id, socket_id);
|
||||
}
|
||||
|
||||
pub(super) fn drop_udp_socket(&self, machine_id: MachineId, socket_id: SocketId) {
|
||||
self.report_closed_socket(machine_id, socket_id);
|
||||
}
|
||||
|
||||
async fn perform_request(
|
||||
&self,
|
||||
request: ServerProcessorRequest,
|
||||
) -> VirtualNetworkResult<ServerProcessorResponse> {
|
||||
) -> VirtualNetworkResult<ServerProcessorReplyValue> {
|
||||
let message_id = self
|
||||
.unlocked_inner
|
||||
.next_message_id
|
||||
.fetch_add(1, Ordering::AcqRel);
|
||||
let msg = ServerProcessorRequestMessage {
|
||||
let command = ServerProcessorCommand::Message(ServerProcessorMessage {
|
||||
message_id,
|
||||
request,
|
||||
};
|
||||
let msg_vec =
|
||||
Bytes::from(to_stdvec(&msg).map_err(VirtualNetworkError::SerializationError)?);
|
||||
});
|
||||
let command_vec =
|
||||
Bytes::from(to_stdvec(&command).map_err(VirtualNetworkError::SerializationError)?);
|
||||
|
||||
self.unlocked_inner
|
||||
.sender
|
||||
.send_async(msg_vec)
|
||||
.send_async(command_vec)
|
||||
.await
|
||||
.map_err(|_| VirtualNetworkError::IoError(io::ErrorKind::BrokenPipe))?;
|
||||
let handle = self
|
||||
@ -469,16 +551,16 @@ impl RouterClient {
|
||||
.map_err(|_| VirtualNetworkError::WaitError)?;
|
||||
|
||||
match status {
|
||||
ServerProcessorResponseStatus::Success(server_processor_response) => {
|
||||
ServerProcessorReplyResult::Value(server_processor_response) => {
|
||||
Ok(server_processor_response)
|
||||
}
|
||||
ServerProcessorResponseStatus::InvalidMachineId => {
|
||||
ServerProcessorReplyResult::InvalidMachineId => {
|
||||
Err(VirtualNetworkError::InvalidMachineId)
|
||||
}
|
||||
ServerProcessorResponseStatus::InvalidSocketId => {
|
||||
ServerProcessorReplyResult::InvalidSocketId => {
|
||||
Err(VirtualNetworkError::InvalidSocketId)
|
||||
}
|
||||
ServerProcessorResponseStatus::IoError(k) => Err(VirtualNetworkError::IoError(k)),
|
||||
ServerProcessorReplyResult::IoError(k) => Err(VirtualNetworkError::IoError(k)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -486,7 +568,7 @@ impl RouterClient {
|
||||
reader: R,
|
||||
writer: W,
|
||||
receiver: flume::Receiver<Bytes>,
|
||||
sender: flume::Sender<Bytes>,
|
||||
router_op_waiter: RouterOpWaiter<ServerProcessorReplyResult, ()>,
|
||||
stop_token: StopToken,
|
||||
) where
|
||||
R: AsyncReadExt + Unpin + Send,
|
||||
@ -503,10 +585,27 @@ impl RouterClient {
|
||||
}
|
||||
});
|
||||
let framed_reader_fut = system_boxed(async move {
|
||||
if let Err(e) = framed_reader
|
||||
.forward(sender.into_sink().sink_map_err(::std::io::Error::other))
|
||||
.await
|
||||
{
|
||||
let fut = framed_reader.try_for_each(|x| async {
|
||||
let x = x;
|
||||
let evt = from_bytes::<ServerProcessorEvent>(&x)
|
||||
.map_err(VirtualNetworkError::SerializationError)?;
|
||||
|
||||
match evt {
|
||||
ServerProcessorEvent::Reply(reply) => {
|
||||
router_op_waiter
|
||||
.complete_op_waiter(reply.message_id, reply.status)
|
||||
.map_err(io::Error::other)?;
|
||||
} // ServerProcessorEvent::DeadSocket {
|
||||
// machine_id,
|
||||
// socket_id,
|
||||
// } => {
|
||||
// //
|
||||
// }
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
if let Err(e) = fut.await {
|
||||
error!("{}", e);
|
||||
}
|
||||
});
|
||||
|
@ -25,16 +25,6 @@ where
|
||||
result_receiver: Option<flume::Receiver<T>>,
|
||||
}
|
||||
|
||||
impl<T, C> RouterOpWaitHandle<T, C>
|
||||
where
|
||||
T: Unpin,
|
||||
C: Unpin + Clone,
|
||||
{
|
||||
pub fn id(&self) -> RouterOpId {
|
||||
self.op_id
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, C> Drop for RouterOpWaitHandle<T, C>
|
||||
where
|
||||
T: Unpin,
|
||||
@ -123,6 +113,7 @@ where
|
||||
}
|
||||
|
||||
/// Get operation context
|
||||
#[expect(dead_code)]
|
||||
pub fn get_op_context(&self, op_id: RouterOpId) -> Result<C, RouterOpWaitError<T>> {
|
||||
let inner = self.inner.lock();
|
||||
let Some(waiting_op) = inner.waiting_op_table.get(&op_id) else {
|
||||
@ -132,14 +123,12 @@ where
|
||||
}
|
||||
|
||||
/// Remove wait for op
|
||||
#[instrument(level = "trace", target = "rpc", skip_all)]
|
||||
fn cancel_op_waiter(&self, op_id: RouterOpId) {
|
||||
let mut inner = self.inner.lock();
|
||||
inner.waiting_op_table.remove(&op_id);
|
||||
}
|
||||
|
||||
/// Complete the waiting op
|
||||
#[instrument(level = "trace", target = "rpc", skip_all)]
|
||||
pub fn complete_op_waiter(
|
||||
&self,
|
||||
op_id: RouterOpId,
|
||||
@ -159,7 +148,6 @@ where
|
||||
}
|
||||
|
||||
/// Wait for operation to complete
|
||||
#[instrument(level = "trace", target = "rpc", skip_all)]
|
||||
pub async fn wait_for_op(
|
||||
&self,
|
||||
mut handle: RouterOpWaitHandle<T, C>,
|
||||
|
67
veilid-tools/src/virtual_network/virtual_tcp_listener.rs
Normal file
67
veilid-tools/src/virtual_network/virtual_tcp_listener.rs
Normal file
@ -0,0 +1,67 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct VirtualTcpListener {
|
||||
pub(super) machine: Machine,
|
||||
pub(super) socket_id: SocketId,
|
||||
pub(super) local_address: SocketAddr,
|
||||
}
|
||||
|
||||
impl VirtualTcpListener {
|
||||
/////////////////////////////////////////////////////////////
|
||||
// Public Interface
|
||||
|
||||
pub async fn bind(
|
||||
opt_local_address: Option<SocketAddr>,
|
||||
options: VirtualTcpOptions,
|
||||
) -> VirtualNetworkResult<Self> {
|
||||
let machine = default_machine().unwrap();
|
||||
Self::bind_with_machine(machine, opt_local_address, options).await
|
||||
}
|
||||
|
||||
pub async fn bind_with_machine(
|
||||
machine: Machine,
|
||||
opt_local_address: Option<SocketAddr>,
|
||||
options: VirtualTcpOptions,
|
||||
) -> VirtualNetworkResult<Self> {
|
||||
machine
|
||||
.router_client
|
||||
.clone()
|
||||
.tcp_bind(machine.id, opt_local_address, options)
|
||||
.await
|
||||
.map(|(socket_id, local_address)| Self::new(machine, socket_id, local_address))
|
||||
}
|
||||
|
||||
pub async fn accept(&self) -> VirtualNetworkResult<(VirtualTcpStream, SocketAddr)> {
|
||||
self.machine
|
||||
.router_client
|
||||
.clone()
|
||||
.tcp_accept(self.machine.id, self.socket_id)
|
||||
.await
|
||||
.map(|v| {
|
||||
(
|
||||
VirtualTcpStream::new(self.machine.clone(), v.0, self.local_address, v.1),
|
||||
v.1,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
// Private Implementation
|
||||
|
||||
fn new(machine: Machine, socket_id: SocketId, local_address: SocketAddr) -> Self {
|
||||
Self {
|
||||
machine,
|
||||
socket_id,
|
||||
local_address,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for VirtualTcpListener {
|
||||
fn drop(&mut self) {
|
||||
self.machine
|
||||
.router_client
|
||||
.drop_tcp_listener(self.machine.id, self.socket_id);
|
||||
}
|
||||
}
|
@ -0,0 +1,86 @@
|
||||
use super::*;
|
||||
|
||||
use core::pin::Pin;
|
||||
use core::task::{Context, Poll};
|
||||
use futures_util::{stream::Stream, FutureExt};
|
||||
use std::io;
|
||||
|
||||
/// A wrapper around [`VirtualTcpListener`] that implements [`Stream`].
|
||||
///
|
||||
/// [`VirtualTcpListener`]: struct@crate::VirtualTcpListener
|
||||
/// [`Stream`]: trait@futures_util::stream::Stream
|
||||
pub struct VirtualTcpListenerStream {
|
||||
inner: VirtualTcpListener,
|
||||
current_accept_fut: Option<SendPinBoxFuture<VirtualNetworkResult<(SocketId, SocketAddr)>>>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for VirtualTcpListenerStream {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("VirtualTcpListenerStream")
|
||||
.field("inner", &self.inner)
|
||||
.field(
|
||||
"current_accept_fut",
|
||||
if self.current_accept_fut.is_some() {
|
||||
&"Some(...)"
|
||||
} else {
|
||||
&"None"
|
||||
},
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl VirtualTcpListenerStream {
|
||||
/// Create a new `VirtualTcpListenerStream`.
|
||||
pub fn new(listener: VirtualTcpListener) -> Self {
|
||||
Self {
|
||||
inner: listener,
|
||||
current_accept_fut: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get back the inner `VirtualTcpListener`.
|
||||
pub fn into_inner(self) -> VirtualTcpListener {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for VirtualTcpListenerStream {
|
||||
type Item = io::Result<VirtualTcpStream>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<io::Result<VirtualTcpStream>>> {
|
||||
if self.current_accept_fut.is_none() {
|
||||
let machine_id = self.inner.machine.id;
|
||||
let router_client = self.inner.machine.router_client.clone();
|
||||
let socket_id = self.inner.socket_id;
|
||||
|
||||
self.current_accept_fut =
|
||||
Some(Box::pin(router_client.tcp_accept(machine_id, socket_id)));
|
||||
}
|
||||
let fut = self.current_accept_fut.as_mut().unwrap();
|
||||
fut.poll_unpin(cx).map(|v| match v {
|
||||
Ok(v) => Some(Ok(VirtualTcpStream::new(
|
||||
self.inner.machine.clone(),
|
||||
v.0,
|
||||
self.inner.local_address,
|
||||
v.1,
|
||||
))),
|
||||
Err(e) => Some(Err(e.into())),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<VirtualTcpListener> for VirtualTcpListenerStream {
|
||||
fn as_ref(&self) -> &VirtualTcpListener {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMut<VirtualTcpListener> for VirtualTcpListenerStream {
|
||||
fn as_mut(&mut self) -> &mut VirtualTcpListener {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
@ -13,12 +13,52 @@ pub struct VirtualTcpOptions {
|
||||
pub struct VirtualTcpStream {
|
||||
machine: Machine,
|
||||
socket_id: SocketId,
|
||||
local_address: SocketAddr,
|
||||
remote_address: SocketAddr,
|
||||
current_recv_fut: Option<SendPinBoxFuture<Result<Vec<u8>, VirtualNetworkError>>>,
|
||||
current_send_fut: Option<SendPinBoxFuture<Result<usize, VirtualNetworkError>>>,
|
||||
current_close_fut: Option<SendPinBoxFuture<Result<(), VirtualNetworkError>>>,
|
||||
current_tcp_shutdown_fut: Option<SendPinBoxFuture<Result<(), VirtualNetworkError>>>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for VirtualTcpStream {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("VirtualTcpStream")
|
||||
.field("machine", &self.machine)
|
||||
.field("socket_id", &self.socket_id)
|
||||
.field("local_address", &self.local_address)
|
||||
.field("remote_address", &self.remote_address)
|
||||
.field(
|
||||
"current_recv_fut",
|
||||
if self.current_recv_fut.is_some() {
|
||||
&"Some(...)"
|
||||
} else {
|
||||
&"None"
|
||||
},
|
||||
)
|
||||
.field(
|
||||
"current_send_fut",
|
||||
if self.current_send_fut.is_some() {
|
||||
&"Some(...)"
|
||||
} else {
|
||||
&"None"
|
||||
},
|
||||
)
|
||||
.field(
|
||||
"current_close_fut",
|
||||
if self.current_tcp_shutdown_fut.is_some() {
|
||||
&"Some(...)"
|
||||
} else {
|
||||
&"None"
|
||||
},
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl VirtualTcpStream {
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// Public Interface
|
||||
|
||||
pub async fn connect(
|
||||
remote_address: SocketAddr,
|
||||
local_address: Option<SocketAddr>,
|
||||
@ -48,14 +88,46 @@ impl VirtualTcpStream {
|
||||
options,
|
||||
)
|
||||
.await
|
||||
.map(|socket_id| Self {
|
||||
machine,
|
||||
socket_id,
|
||||
current_recv_fut: None,
|
||||
current_send_fut: None,
|
||||
current_close_fut: None,
|
||||
.map(|(socket_id, local_address)| {
|
||||
Self::new(machine, socket_id, local_address, remote_address)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn local_addr(&self) -> VirtualNetworkResult<SocketAddr> {
|
||||
Ok(self.local_address)
|
||||
}
|
||||
|
||||
pub fn peer_addr(&self) -> VirtualNetworkResult<SocketAddr> {
|
||||
Ok(self.remote_address)
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// Private Implementation
|
||||
|
||||
pub(super) fn new(
|
||||
machine: Machine,
|
||||
socket_id: SocketId,
|
||||
local_address: SocketAddr,
|
||||
remote_address: SocketAddr,
|
||||
) -> Self {
|
||||
Self {
|
||||
machine,
|
||||
socket_id,
|
||||
local_address,
|
||||
remote_address,
|
||||
current_recv_fut: None,
|
||||
current_send_fut: None,
|
||||
current_tcp_shutdown_fut: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for VirtualTcpStream {
|
||||
fn drop(&mut self) {
|
||||
self.machine
|
||||
.router_client
|
||||
.drop_tcp_stream(self.machine.id, self.socket_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl futures_util::AsyncRead for VirtualTcpStream {
|
||||
@ -68,7 +140,7 @@ impl futures_util::AsyncRead for VirtualTcpStream {
|
||||
self.current_recv_fut = Some(Box::pin(self.machine.router_client.clone().recv(
|
||||
self.machine.id,
|
||||
self.socket_id,
|
||||
buf.len() as u32,
|
||||
buf.len(),
|
||||
)));
|
||||
}
|
||||
let fut = self.current_recv_fut.as_mut().unwrap();
|
||||
@ -118,18 +190,18 @@ impl futures_util::AsyncWrite for VirtualTcpStream {
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> task::Poll<std::io::Result<()>> {
|
||||
if self.current_close_fut.is_none() {
|
||||
self.current_close_fut = Some(Box::pin(
|
||||
if self.current_tcp_shutdown_fut.is_none() {
|
||||
self.current_tcp_shutdown_fut = Some(Box::pin(
|
||||
self.machine
|
||||
.router_client
|
||||
.clone()
|
||||
.close(self.machine.id, self.socket_id),
|
||||
.tcp_shutdown(self.machine.id, self.socket_id),
|
||||
));
|
||||
}
|
||||
let fut = self.current_close_fut.as_mut().unwrap();
|
||||
let fut = self.current_tcp_shutdown_fut.as_mut().unwrap();
|
||||
fut.poll_unpin(cx).map(|v| match v {
|
||||
Ok(v) => {
|
||||
self.current_close_fut = None;
|
||||
self.current_tcp_shutdown_fut = None;
|
||||
Ok(v)
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
|
@ -7,50 +7,80 @@ pub struct VirtualUdpOptions {
|
||||
reuse_address_port: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct VirtualUdpSocket {
|
||||
machine: Machine,
|
||||
socket_id: SocketId,
|
||||
local_address: SocketAddr,
|
||||
}
|
||||
|
||||
impl VirtualUdpSocket {
|
||||
// pub async fn connect(
|
||||
// remote_address: SocketAddr,
|
||||
// local_address: Option<SocketAddr>,
|
||||
// timeout_ms: u32,
|
||||
// options: VirtualTcpOptions,
|
||||
// ) -> VirtualNetworkResult<Self> {
|
||||
// let machine = default_machine().unwrap();
|
||||
// Self::connect_with_machine(machine, remote_address, local_address, timeout_ms, options)
|
||||
// .await
|
||||
// }
|
||||
/////////////////////////////////////////////////////////////
|
||||
// Public Interface
|
||||
|
||||
// pub async fn connect_with_machine(
|
||||
// machine: Machine,
|
||||
// remote_address: SocketAddr,
|
||||
// local_address: Option<SocketAddr>,
|
||||
// timeout_ms: u32,
|
||||
// options: VirtualTcpOptions,
|
||||
// ) -> VirtualNetworkResult<Self> {
|
||||
// machine
|
||||
// .router_client
|
||||
// .tcp_connect(
|
||||
// machine.id,
|
||||
// remote_address,
|
||||
// local_address,
|
||||
// timeout_ms,
|
||||
// options,
|
||||
// )
|
||||
// .await
|
||||
// .map(|socket_id| Self { machine, socket_id })
|
||||
// }
|
||||
pub async fn bind(
|
||||
opt_local_address: Option<SocketAddr>,
|
||||
options: VirtualUdpOptions,
|
||||
) -> VirtualNetworkResult<Self> {
|
||||
let machine = default_machine().unwrap();
|
||||
Self::bind_with_machine(machine, opt_local_address, options).await
|
||||
}
|
||||
|
||||
pub async fn bind_with_machine(
|
||||
machine: Machine,
|
||||
opt_local_address: Option<SocketAddr>,
|
||||
options: VirtualUdpOptions,
|
||||
) -> VirtualNetworkResult<Self> {
|
||||
machine
|
||||
.router_client
|
||||
.clone()
|
||||
.udp_bind(machine.id, opt_local_address, options)
|
||||
.await
|
||||
.map(|(socket_id, local_address)| Self::new(machine, socket_id, local_address))
|
||||
}
|
||||
|
||||
pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> VirtualNetworkResult<usize> {
|
||||
self.machine
|
||||
.router_client
|
||||
.clone()
|
||||
.send_to(self.machine.id, self.socket_id, target, buf.to_vec())
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn recv_from(&self, buf: &mut [u8]) -> VirtualNetworkResult<(usize, SocketAddr)> {
|
||||
let (v, addr) = self
|
||||
.machine
|
||||
.router_client
|
||||
.clone()
|
||||
.recv_from(self.machine.id, self.socket_id, buf.len())
|
||||
.await?;
|
||||
|
||||
let len = usize::min(buf.len(), v.len());
|
||||
buf[0..len].copy_from_slice(&v[0..len]);
|
||||
|
||||
Ok((len, addr))
|
||||
}
|
||||
|
||||
pub fn local_addr(&self) -> VirtualNetworkResult<SocketAddr> {
|
||||
Ok(self.local_address)
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
// Private Implementation
|
||||
|
||||
fn new(machine: Machine, socket_id: SocketId, local_address: SocketAddr) -> Self {
|
||||
Self {
|
||||
machine,
|
||||
socket_id,
|
||||
local_address,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// impl futures_util::AsyncRead for VirtualUdpSocket {
|
||||
// fn poll_read(
|
||||
// self: Pin<&mut Self>,
|
||||
// cx: &mut task::Context<'_>,
|
||||
// buf: &mut [u8],
|
||||
// ) -> task::Poll<std::io::Result<usize>> {
|
||||
// todo!()
|
||||
// }
|
||||
// }
|
||||
impl Drop for VirtualUdpSocket {
|
||||
fn drop(&mut self) {
|
||||
self.machine
|
||||
.router_client
|
||||
.drop_udp_socket(self.machine.id, self.socket_id);
|
||||
}
|
||||
}
|
||||
|
@ -66,13 +66,6 @@ pub fn is_ipv6_supported() -> bool {
|
||||
if let Some(supp) = *opt_supp {
|
||||
return supp;
|
||||
}
|
||||
// let supp = match UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 0, 0, 0)) {
|
||||
// Ok(_) => true,
|
||||
// Err(e) => !matches!(
|
||||
// e.kind(),
|
||||
// std::io::ErrorKind::AddrNotAvailable | std::io::ErrorKind::Unsupported
|
||||
// ),
|
||||
// };
|
||||
|
||||
// XXX: See issue #92
|
||||
let supp = false;
|
||||
|
Loading…
Reference in New Issue
Block a user