This commit is contained in:
John Smith 2022-07-13 09:51:56 -04:00
parent 007150c818
commit b90d453cef
13 changed files with 340 additions and 374 deletions

View File

@ -256,7 +256,7 @@ impl ConnectionManager {
// Attempt new connection // Attempt new connection
let conn = loop { let conn = loop {
match ProtocolNetworkConnection::connect(local_addr, dial_info.clone()).await { match ProtocolNetworkConnection::connect(local_addr, &dial_info).await {
Ok(v) => break Ok(v), Ok(v) => break Ok(v),
Err(e) => { Err(e) => {
if retry_count == 0 { if retry_count == 0 {

View File

@ -1079,10 +1079,14 @@ impl NetworkManager {
}; };
// Send boot magic to requested peer address // Send boot magic to requested peer address
let data = BOOT_MAGIC.to_vec(); let data = BOOT_MAGIC.to_vec();
let out_data: Vec<u8> = self let out_data: Vec<u8> = match self
.net() .net()
.send_recv_data_unbound_to_dial_info(dial_info, data, timeout_ms) .send_recv_data_unbound_to_dial_info(dial_info, data, timeout_ms)
.await?; .await?
{
TimeoutOr::Timeout => return Ok(Vec::new()),
TimeoutOr::Value(v) => v,
};
let bootstrap_peerinfo: Vec<PeerInfo> = let bootstrap_peerinfo: Vec<PeerInfo> =
deserialize_json(std::str::from_utf8(&out_data).wrap_err("bad utf8 in boot peerinfo")?) deserialize_json(std::str::from_utf8(&out_data).wrap_err("bad utf8 in boot peerinfo")?)

View File

@ -288,26 +288,35 @@ impl Network {
data: Vec<u8>, data: Vec<u8>,
) -> EyreResult<()> { ) -> EyreResult<()> {
let data_len = data.len(); let data_len = data.len();
let res = match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await let h = RawUdpProtocolHandler::new_unspecified_bound_handler(&peer_socket_addr)
.await
.wrap_err("create socket failure")?;
h.send_message(data, peer_socket_addr)
.await
.wrap_err("send message failure")?;
} }
ProtocolType::TCP => { ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await let pnc = RawTcpProtocolHandler::connect(None, peer_socket_addr)
.await
.wrap_err("connect failure")?;
pnc.send(data).await.wrap_err("send failure")?;
} }
ProtocolType::WS | ProtocolType::WSS => { ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data).await let pnc = WebsocketProtocolHandler::connect(None, &dial_info)
.await
.wrap_err("connect failure")?;
pnc.send(data).await.wrap_err("send failure")?;
} }
} }
.wrap_err("low level network error"); // Network accounting
if res.is_ok() { self.network_manager()
// Network accounting .stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64); Ok(())
}
res
} }
// Send data to a dial info, unbound, using a new connection from a random port // Send data to a dial info, unbound, using a new connection from a random port
@ -315,43 +324,94 @@ impl Network {
// This creates a short-lived connection in the case of connection-oriented protocols // This creates a short-lived connection in the case of connection-oriented protocols
// for the purpose of sending this one message. // for the purpose of sending this one message.
// This bypasses the connection table as it is not a 'node to node' connection. // This bypasses the connection table as it is not a 'node to node' connection.
#[instrument(level="trace", err, skip(self, data), fields(data.len = data.len(), ret.len))] #[instrument(level="trace", err, skip(self, data), fields(ret.timeout_or, data.len = data.len()))]
pub async fn send_recv_data_unbound_to_dial_info( pub async fn send_recv_data_unbound_to_dial_info(
&self, &self,
dial_info: DialInfo, dial_info: DialInfo,
data: Vec<u8>, data: Vec<u8>,
timeout_ms: u32, timeout_ms: u32,
) -> EyreResult<Vec<u8>> { ) -> EyreResult<TimeoutOr<Vec<u8>>> {
let data_len = data.len(); let data_len = data.len();
let out = match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
RawUdpProtocolHandler::send_recv_unbound_message(peer_socket_addr, data, timeout_ms) let h = RawUdpProtocolHandler::new_unspecified_bound_handler(&peer_socket_addr)
.await? .await
} .wrap_err("create socket failure")?;
ProtocolType::TCP => { h.send_message(data, peer_socket_addr)
let peer_socket_addr = dial_info.to_socket_addr(); .await
RawTcpProtocolHandler::send_recv_unbound_message(peer_socket_addr, data, timeout_ms) .wrap_err("send message failure")?;
.await? self.network_manager()
} .stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_recv_unbound_message(
dial_info.clone(),
data,
timeout_ms,
)
.await?
}
};
// Network accounting // receive single response
self.network_manager() let mut out = vec![0u8; MAX_MESSAGE_SIZE];
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64); let timeout_or_ret = timeout(timeout_ms, h.recv_message(&mut out))
self.network_manager() .await
.stats_packet_rcvd(dial_info.to_ip_addr(), out.len() as u64); .into_timeout_or()
.into_result()
.wrap_err("recv_message failure")?;
let (recv_len, recv_addr) = match timeout_or_ret {
TimeoutOr::Value(v) => v,
TimeoutOr::Timeout => {
tracing::Span::current().record("ret.timeout_or", &"Timeout".to_owned());
return Ok(TimeoutOr::Timeout);
}
};
tracing::Span::current().record("ret.len", &out.len()); let recv_socket_addr = recv_addr.remote_address().to_socket_addr();
Ok(out) self.network_manager()
.stats_packet_rcvd(recv_socket_addr.ip(), recv_len as u64);
// if the from address is not the same as the one we sent to, then drop this
if recv_socket_addr != peer_socket_addr {
bail!("wrong address");
}
out.resize(recv_len, 0u8);
Ok(TimeoutOr::Value(out))
}
ProtocolType::TCP | ProtocolType::WS | ProtocolType::WSS => {
let pnc = match dial_info.protocol_type() {
ProtocolType::UDP => unreachable!(),
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
RawTcpProtocolHandler::connect(None, peer_socket_addr)
.await
.wrap_err("connect failure")?
}
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::connect(None, &dial_info)
.await
.wrap_err("connect failure")?
}
};
pnc.send(data).await.wrap_err("send failure")?;
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
let out = timeout(timeout_ms, pnc.recv())
.await
.into_timeout_or()
.into_result()
.wrap_err("recv failure")?;
tracing::Span::current().record(
"ret.timeout_or",
&match out {
TimeoutOr::<Vec<u8>>::Value(ref v) => format!("Value(len={})", v.len()),
TimeoutOr::<Vec<u8>>::Timeout => "Timeout".to_owned(),
},
);
if let TimeoutOr::Value(out) = &out {
self.network_manager()
.stats_packet_rcvd(dial_info.to_ip_addr(), out.len() as u64);
}
Ok(out)
}
}
} }
#[instrument(level="trace", err, skip(self, data), fields(data.len = data.len()))] #[instrument(level="trace", err, skip(self, data), fields(data.len = data.len()))]

View File

@ -21,14 +21,14 @@ pub enum ProtocolNetworkConnection {
impl ProtocolNetworkConnection { impl ProtocolNetworkConnection {
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: &DialInfo,
) -> io::Result<ProtocolNetworkConnection> { ) -> io::Result<ProtocolNetworkConnection> {
match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
panic!("Should not connect to UDP dialinfo"); panic!("Should not connect to UDP dialinfo");
} }
ProtocolType::TCP => { ProtocolType::TCP => {
tcp::RawTcpProtocolHandler::connect(local_address, dial_info).await tcp::RawTcpProtocolHandler::connect(local_address, dial_info.to_socket_addr()).await
} }
ProtocolType::WS | ProtocolType::WSS => { ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
@ -36,53 +36,6 @@ impl ProtocolNetworkConnection {
} }
} }
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> io::Result<()> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await
}
}
}
pub async fn send_recv_unbound_message(
dial_info: DialInfo,
data: Vec<u8>,
timeout_ms: u32,
) -> io::Result<Vec<u8>> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
udp::RawUdpProtocolHandler::send_recv_unbound_message(
peer_socket_addr,
data,
timeout_ms,
)
.await
}
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
tcp::RawTcpProtocolHandler::send_recv_unbound_message(
peer_socket_addr,
data,
timeout_ms,
)
.await
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::send_recv_unbound_message(dial_info, data, timeout_ms)
.await
}
}
}
pub fn descriptor(&self) -> ConnectionDescriptor { pub fn descriptor(&self) -> ConnectionDescriptor {
match self { match self {
Self::Dummy(d) => d.descriptor(), Self::Dummy(d) => d.descriptor(),

View File

@ -58,7 +58,7 @@ impl RawTcpNetworkConnection {
Self::send_internal(&mut stream, message).await Self::send_internal(&mut stream, message).await
} }
pub async fn recv_internal(stream: &mut AsyncPeekStream) -> io::Result<Vec<u8>> { async fn recv_internal(stream: &mut AsyncPeekStream) -> io::Result<Vec<u8>> {
let mut header = [0u8; 4]; let mut header = [0u8; 4];
stream.read_exact(&mut header).await?; stream.read_exact(&mut header).await?;
@ -141,21 +141,16 @@ impl RawTcpProtocolHandler {
#[instrument(level = "trace", err)] #[instrument(level = "trace", err)]
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, socket_addr: SocketAddr,
) -> io::Result<ProtocolNetworkConnection> { ) -> io::Result<ProtocolNetworkConnection> {
// Get remote socket address to connect to
let remote_socket_addr = dial_info.to_socket_addr();
// Make a shared socket // Make a shared socket
let socket = match local_address { let socket = match local_address {
Some(a) => new_bound_shared_tcp_socket(a)?, Some(a) => new_bound_shared_tcp_socket(a)?,
None => { None => new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?,
new_unbound_shared_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?
}
}; };
// Non-blocking connect to remote address // Non-blocking connect to remote address
let ts = nonblocking_connect(socket, remote_socket_addr).await?; let ts = nonblocking_connect(socket, socket_addr).await?;
// See what local address we ended up with and turn this into a stream // See what local address we ended up with and turn this into a stream
let actual_local_address = ts.local_addr()?; let actual_local_address = ts.local_addr()?;
@ -166,7 +161,10 @@ impl RawTcpProtocolHandler {
// Wrap the stream in a network connection and return it // Wrap the stream in a network connection and return it
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new( let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
ConnectionDescriptor::new( ConnectionDescriptor::new(
dial_info.to_peer_address(), PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
ProtocolType::TCP,
),
SocketAddress::from_socket_addr(actual_local_address), SocketAddress::from_socket_addr(actual_local_address),
), ),
ps, ps,
@ -175,79 +173,74 @@ impl RawTcpProtocolHandler {
Ok(conn) Ok(conn)
} }
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))] // #[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> { // pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> {
if data.len() > MAX_MESSAGE_SIZE { // if data.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large unbound TCP message"); // bail_io_error_other!("sending too large unbound TCP message");
} // }
trace!( // // Make a shared socket
"sending unbound message of length {} to {}", // let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
data.len(),
socket_addr
);
// Make a shared socket // // Non-blocking connect to remote address
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?; // let ts = nonblocking_connect(socket, socket_addr).await?;
// Non-blocking connect to remote address // // See what local address we ended up with and turn this into a stream
let ts = nonblocking_connect(socket, socket_addr).await?; // // let actual_local_address = ts
// // .local_addr()
// // .map_err(map_to_string)
// // .map_err(logthru_net!("could not get local address from TCP stream"))?;
// See what local address we ended up with and turn this into a stream // #[cfg(feature = "rt-tokio")]
// let actual_local_address = ts // let ts = ts.compat();
// .local_addr() // let mut ps = AsyncPeekStream::new(ts);
// .map_err(map_to_string)
// .map_err(logthru_net!("could not get local address from TCP stream"))?;
#[cfg(feature = "rt-tokio")] // // Send directly from the raw network connection
let ts = ts.compat(); // // this builds the connection and tears it down immediately after the send
let mut ps = AsyncPeekStream::new(ts); // RawTcpNetworkConnection::send_internal(&mut ps, data).await
// }
// Send directly from the raw network connection // #[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.timeout_or))]
// this builds the connection and tears it down immediately after the send // pub async fn send_recv_unbound_message(
RawTcpNetworkConnection::send_internal(&mut ps, data).await // socket_addr: SocketAddr,
} // data: Vec<u8>,
// timeout_ms: u32,
// ) -> io::Result<TimeoutOr<Vec<u8>>> {
// if data.len() > MAX_MESSAGE_SIZE {
// bail_io_error_other!("sending too large unbound TCP message");
// }
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.len))] // // Make a shared socket
pub async fn send_recv_unbound_message( // let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
socket_addr: SocketAddr,
data: Vec<u8>,
timeout_ms: u32,
) -> io::Result<Vec<u8>> {
if data.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large unbound TCP message");
}
trace!(
"sending unbound message of length {} to {}",
data.len(),
socket_addr
);
// Make a shared socket // // Non-blocking connect to remote address
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?; // let ts = nonblocking_connect(socket, socket_addr).await?;
// Non-blocking connect to remote address // // See what local address we ended up with and turn this into a stream
let ts = nonblocking_connect(socket, socket_addr).await?; // // let actual_local_address = ts
// // .local_addr()
// // .map_err(map_to_string)
// // .map_err(logthru_net!("could not get local address from TCP stream"))?;
// #[cfg(feature = "rt-tokio")]
// let ts = ts.compat();
// let mut ps = AsyncPeekStream::new(ts);
// See what local address we ended up with and turn this into a stream // // Send directly from the raw network connection
// let actual_local_address = ts // // this builds the connection and tears it down immediately after the send
// .local_addr() // RawTcpNetworkConnection::send_internal(&mut ps, data).await?;
// .map_err(map_to_string) // let out = timeout(timeout_ms, RawTcpNetworkConnection::recv_internal(&mut ps))
// .map_err(logthru_net!("could not get local address from TCP stream"))?; // .await
#[cfg(feature = "rt-tokio")] // .into_timeout_or()
let ts = ts.compat(); // .into_result()?;
let mut ps = AsyncPeekStream::new(ts);
// Send directly from the raw network connection // tracing::Span::current().record(
// this builds the connection and tears it down immediately after the send // "ret.timeout_or",
RawTcpNetworkConnection::send_internal(&mut ps, data).await?; // &match out {
// TimeoutOr::<Vec<u8>>::Value(ref v) => format!("Value(len={})", v.len()),
let out = timeout(timeout_ms, RawTcpNetworkConnection::recv_internal(&mut ps)) // TimeoutOr::<Vec<u8>>::Timeout => "Timeout".to_owned(),
.await // },
.map_err(|e| e.to_io())??; // );
// Ok(out)
tracing::Span::current().record("ret.len", &out.len()); // }
Ok(out)
}
} }
impl ProtocolAcceptHandler for RawTcpProtocolHandler { impl ProtocolAcceptHandler for RawTcpProtocolHandler {

View File

@ -60,68 +60,65 @@ impl RawUdpProtocolHandler {
Ok(()) Ok(())
} }
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))] #[instrument(level = "trace", err)]
pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> { pub async fn new_unspecified_bound_handler(
if data.len() > MAX_MESSAGE_SIZE { socket_addr: &SocketAddr,
bail_io_error_other!("sending too large unbound UDP message"); ) -> io::Result<RawUdpProtocolHandler> {
}
// get local wildcard address for bind // get local wildcard address for bind
let local_socket_addr = match socket_addr { let local_socket_addr = compatible_unspecified_socket_addr(&socket_addr);
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
SocketAddr::V6(_) => {
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
}
};
let socket = UdpSocket::bind(local_socket_addr).await?; let socket = UdpSocket::bind(local_socket_addr).await?;
let len = socket.send_to(&data, socket_addr).await?; Ok(RawUdpProtocolHandler::new(Arc::new(socket)))
if len != data.len() {
bail_io_error_other!("UDP partial unbound send")
}
Ok(())
} }
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.len))] // #[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.timeout_or))]
pub async fn send_recv_unbound_message( // pub async fn send_recv_unbound_message(
socket_addr: SocketAddr, // socket_addr: SocketAddr,
data: Vec<u8>, // data: Vec<u8>,
timeout_ms: u32, // timeout_ms: u32,
) -> io::Result<Vec<u8>> { // ) -> io::Result<TimeoutOr<Vec<u8>>> {
if data.len() > MAX_MESSAGE_SIZE { // if data.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large unbound UDP message"); // bail_io_error_other!("sending too large unbound UDP message");
} // }
// get local wildcard address for bind // // get local wildcard address for bind
let local_socket_addr = match socket_addr { // let local_socket_addr = match socket_addr {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0), // SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
SocketAddr::V6(_) => { // SocketAddr::V6(_) => {
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0) // SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
} // }
}; // };
// get unspecified bound socket // // get unspecified bound socket
let socket = UdpSocket::bind(local_socket_addr).await?; // let socket = UdpSocket::bind(local_socket_addr).await?;
let len = socket.send_to(&data, socket_addr).await?; // let len = socket.send_to(&data, socket_addr).await?;
if len != data.len() { // if len != data.len() {
bail_io_error_other!("UDP partial unbound send"); // bail_io_error_other!("UDP partial unbound send");
} // }
// receive single response // // receive single response
let mut out = vec![0u8; MAX_MESSAGE_SIZE]; // let mut out = vec![0u8; MAX_MESSAGE_SIZE];
let (len, from_addr) = timeout(timeout_ms, socket.recv_from(&mut out)) // let timeout_or_ret = timeout(timeout_ms, socket.recv_from(&mut out))
.await // .await
.map_err(|e| e.to_io())??; // .into_timeout_or()
// .into_result()?;
// let (len, from_addr) = match timeout_or_ret {
// TimeoutOr::Value(v) => v,
// TimeoutOr::Timeout => {
// tracing::Span::current().record("ret.timeout_or", &"Timeout".to_owned());
// return Ok(TimeoutOr::Timeout);
// }
// };
// if the from address is not the same as the one we sent to, then drop this // // if the from address is not the same as the one we sent to, then drop this
if from_addr != socket_addr { // if from_addr != socket_addr {
bail_io_error_other!(format!( // bail_io_error_other!(format!(
"Unbound response received from wrong address: addr={}", // "Unbound response received from wrong address: addr={}",
from_addr, // from_addr,
)); // ));
} // }
out.resize(len, 0u8); // out.resize(len, 0u8);
tracing::Span::current().record("ret.len", &len);
Ok(out) // tracing::Span::current().record("ret.timeout_or", &format!("Value(len={})", out.len()));
} // Ok(TimeoutOr::Value(out))
// }
} }

View File

@ -223,12 +223,13 @@ impl WebsocketProtocolHandler {
Ok(Some(conn)) Ok(Some(conn))
} }
async fn connect_internal( #[instrument(level = "trace", err)]
pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: &DialInfo,
) -> io::Result<ProtocolNetworkConnection> { ) -> io::Result<ProtocolNetworkConnection> {
// Split dial info up // Split dial info up
let (tls, scheme) = match &dial_info { let (tls, scheme) = match dial_info {
DialInfo::WS(_) => (false, "ws"), DialInfo::WS(_) => (false, "ws"),
DialInfo::WSS(_) => (true, "wss"), DialInfo::WSS(_) => (true, "wss"),
_ => panic!("invalid dialinfo for WS/WSS protocol"), _ => panic!("invalid dialinfo for WS/WSS protocol"),
@ -285,46 +286,6 @@ impl WebsocketProtocolHandler {
)) ))
} }
} }
#[instrument(level = "trace", err)]
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> io::Result<ProtocolNetworkConnection> {
Self::connect_internal(local_address, dial_info).await
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> io::Result<()> {
if data.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large unbound WS message");
}
let protconn = Self::connect_internal(None, dial_info.clone()).await?;
protconn.send(data).await
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.len))]
pub async fn send_recv_unbound_message(
dial_info: DialInfo,
data: Vec<u8>,
timeout_ms: u32,
) -> io::Result<Vec<u8>> {
if data.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large unbound WS message");
}
let protconn = Self::connect_internal(None, dial_info.clone()).await?;
protconn.send(data).await?;
let out = timeout(timeout_ms, protconn.recv())
.await
.map_err(|e| e.to_io())??;
tracing::Span::current().record("ret.len", &out.len());
Ok(out)
}
} }
impl ProtocolAcceptHandler for WebsocketProtocolHandler { impl ProtocolAcceptHandler for WebsocketProtocolHandler {

View File

@ -59,7 +59,7 @@ impl Network {
) -> EyreResult<()> { ) -> EyreResult<()> {
let data_len = data.len(); let data_len = data.len();
let res = match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
bail!("no support for UDP protocol") bail!("no support for UDP protocol")
} }
@ -67,17 +67,18 @@ impl Network {
bail!("no support for TCP protocol") bail!("no support for TCP protocol")
} }
ProtocolType::WS | ProtocolType::WSS => { ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data) let pnc = WebsocketProtocolHandler::connect(None, &dial_info)
.await .await
.wrap_err("failed to send unbound message") .wrap_err("connect failure")?;
pnc.send(data).await.wrap_err("send failure")?;
} }
}; };
if res.is_ok() {
// Network accounting // Network accounting
self.network_manager() self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64); .stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
}
res Ok(())
} }
// Send data to a dial info, unbound, using a new connection from a random port // Send data to a dial info, unbound, using a new connection from a random port
@ -91,9 +92,9 @@ impl Network {
dial_info: DialInfo, dial_info: DialInfo,
data: Vec<u8>, data: Vec<u8>,
timeout_ms: u32, timeout_ms: u32,
) -> EyreResult<Vec<u8>> { ) -> EyreResult<TimeoutOr<Vec<u8>>> {
let data_len = data.len(); let data_len = data.len();
let out = match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
bail!("no support for UDP protocol") bail!("no support for UDP protocol")
} }
@ -101,23 +102,47 @@ impl Network {
bail!("no support for TCP protocol") bail!("no support for TCP protocol")
} }
ProtocolType::WS | ProtocolType::WSS => { ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_recv_unbound_message( let pnc = match dial_info.protocol_type() {
dial_info.clone(), ProtocolType::UDP => unreachable!(),
data, ProtocolType::TCP => {
timeout_ms, let peer_socket_addr = dial_info.to_socket_addr();
) RawTcpProtocolHandler::connect(None, peer_socket_addr)
.await? .await
.wrap_err("connect failure")?
}
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::connect(None, &dial_info)
.await
.wrap_err("connect failure")?
}
};
pnc.send(data).await.wrap_err("send failure")?;
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
let out = timeout(timeout_ms, pnc.recv())
.await
.into_timeout_or()
.into_result()
.wrap_err("recv failure")?;
tracing::Span::current().record(
"ret.timeout_or",
&match out {
TimeoutOr::<Vec<u8>>::Value(ref v) => format!("Value(len={})", v.len()),
TimeoutOr::<Vec<u8>>::Timeout => "Timeout".to_owned(),
},
);
if let TimeoutOr::Value(out) = &out {
self.network_manager()
.stats_packet_rcvd(dial_info.to_ip_addr(), out.len() as u64);
}
Ok(out)
} }
}; }
// Network accounting
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
self.network_manager()
.stats_packet_rcvd(dial_info.to_ip_addr(), out.len() as u64);
tracing::Span::current().record("ret.len", &out.len());
Ok(out)
} }
#[instrument(level="trace", err, skip(self, data), fields(data.len = data.len()))] #[instrument(level="trace", err, skip(self, data), fields(data.len = data.len()))]

View File

@ -16,7 +16,7 @@ pub enum ProtocolNetworkConnection {
impl ProtocolNetworkConnection { impl ProtocolNetworkConnection {
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: &DialInfo,
) -> io::Result<ProtocolNetworkConnection> { ) -> io::Result<ProtocolNetworkConnection> {
match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
@ -31,42 +31,6 @@ impl ProtocolNetworkConnection {
} }
} }
pub async fn send_unbound_message(
dial_info: DialInfo,
data: Vec<u8>,
) -> io::Result<()> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("UDP dial info is not supported on WASM targets");
}
ProtocolType::TCP => {
panic!("TCP dial info is not supported on WASM targets");
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await
}
}
}
pub async fn send_recv_unbound_message(
dial_info: DialInfo,
data: Vec<u8>,
timeout_ms: u32,
) -> io::Result<Vec<u8>> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("UDP dial info is not supported on WASM targets");
}
ProtocolType::TCP => {
panic!("TCP dial info is not supported on WASM targets");
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::send_recv_unbound_message(dial_info, data, timeout_ms)
.await
}
}
}
pub fn descriptor(&self) -> ConnectionDescriptor { pub fn descriptor(&self) -> ConnectionDescriptor {
match self { match self {
Self::Dummy(d) => d.descriptor(), Self::Dummy(d) => d.descriptor(),

View File

@ -85,12 +85,12 @@ impl WebsocketProtocolHandler {
#[instrument(level = "trace", err)] #[instrument(level = "trace", err)]
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: &DialInfo,
) -> io::Result<ProtocolNetworkConnection> { ) -> io::Result<ProtocolNetworkConnection> {
assert!(local_address.is_none()); assert!(local_address.is_none());
// Split dial info up // Split dial info up
let (_tls, scheme) = match &dial_info { let (_tls, scheme) = match dial_info {
DialInfo::WS(_) => (false, "ws"), DialInfo::WS(_) => (false, "ws"),
DialInfo::WSS(_) => (true, "wss"), DialInfo::WSS(_) => (true, "wss"),
_ => panic!("invalid dialinfo for WS/WSS protocol"), _ => panic!("invalid dialinfo for WS/WSS protocol"),
@ -105,45 +105,10 @@ impl WebsocketProtocolHandler {
let (wsmeta, wsio) = fut.await.map_err(to_io)?; let (wsmeta, wsio) = fut.await.map_err(to_io)?;
// Make our connection descriptor // Make our connection descriptor
Ok(ProtocolNetworkConnection::Ws( Ok(WebsocketNetworkConnection::new(
WebsocketNetworkConnection::new( ConnectionDescriptor::new_no_local(dial_info.to_peer_address()),
ConnectionDescriptor::new_no_local(dial_info.to_peer_address()), wsmeta,
wsmeta, wsio,
wsio,
),
)) ))
} }
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> io::Result<()> {
if data.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large unbound WS message");
}
// Make the real connection
let conn = Self::connect(None, dial_info).await?;
conn.send(data).await
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.len))]
pub async fn send_recv_unbound_message(
dial_info: DialInfo,
data: Vec<u8>,
timeout_ms: u32,
) -> io::Result<Vec<u8>> {
if data.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large unbound WS message");
}
let conn = Self::connect(None, dial_info.clone()).await?;
conn.send(data).await?;
let out = timeout(timeout_ms, conn.recv())
.await
.map_err(|e| e.to_io())??;
tracing::Span::current().record("ret.len", &out.len());
Ok(out)
}
} }

View File

@ -207,13 +207,7 @@ impl RoutingTable {
let mut unord = FuturesUnordered::new(); let mut unord = FuturesUnordered::new();
for bootstrap_di in bootstrap_dialinfos { for bootstrap_di in bootstrap_dialinfos {
let peer_info = match network_manager.boot_request(bootstrap_di).await { let peer_info = network_manager.boot_request(bootstrap_di).await?;
Ok(v) => v,
Err(e) => {
error!("BOOT request failed: {}", e);
continue;
}
};
// Got peer info, let's add it to the routing table // Got peer info, let's add it to the routing table
for pi in peer_info { for pi in peer_info {

View File

@ -32,6 +32,49 @@ cfg_if! {
} }
////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////
// Non-fallible timeout conversions
pub trait TimeoutOrExt<T> {
fn into_timeout_or(self) -> TimeoutOr<T>;
}
impl<T> TimeoutOrExt<T> for Result<T, TimeoutError> {
fn into_timeout_or(self) -> TimeoutOr<T> {
self.ok().map(|v| TimeoutOr::<T>::Value(v)).unwrap_or(TimeoutOr::<T>::Timeout)
}
}
pub trait IoTimeoutOrExt<T> {
fn into_timeout_or(self) -> io::Result<TimeoutOr<T>>;
}
impl<T> IoTimeoutOrExt<T> for io::Result<T> {
fn into_timeout_or(self) -> io::Result<TimeoutOr<T>> {
match self {
Ok(v) => Ok(TimeoutOr::<T>::Value(v)),
Err(e) if e.kind() == io::ErrorKind::TimedOut => Ok(TimeoutOr::<T>::Timeout),
Err(e) => Err(e),
}
}
}
pub trait TimeoutOrResultExt<T, E> {
fn into_result(self) -> Result<TimeoutOr<T>, E>;
}
impl<T,E> TimeoutOrResultExt<T, E> for TimeoutOr<Result<T,E>> {
fn into_result(self) -> Result<TimeoutOr<T>, E> {
match self {
TimeoutOr::<Result::<T,E>>::Timeout => Ok(TimeoutOr::<T>::Timeout),
TimeoutOr::<Result::<T,E>>::Value(Ok(v)) => Ok(TimeoutOr::<T>::Value(v)),
TimeoutOr::<Result::<T,E>>::Value(Err(e)) => Err(e),
}
}
}
//////////////////////////////////////////////////////////////////
// Non-fallible timeout
pub enum TimeoutOr<T> { pub enum TimeoutOr<T> {
Timeout, Timeout,

View File

@ -146,6 +146,13 @@ where
} }
} }
pub fn compatible_unspecified_socket_addr(socket_addr: &SocketAddr) -> SocketAddr {
match socket_addr {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0),
}
}
pub fn listen_address_to_socket_addrs(listen_address: &str) -> EyreResult<Vec<SocketAddr>> { pub fn listen_address_to_socket_addrs(listen_address: &str) -> EyreResult<Vec<SocketAddr>> {
// If no address is specified, but the port is, use ipv4 and ipv6 unspecified // If no address is specified, but the port is, use ipv4 and ipv6 unspecified
// If the address is specified, only use the specified port and fail otherwise // If the address is specified, only use the specified port and fail otherwise