frag work

This commit is contained in:
John Smith 2023-06-22 16:54:01 -04:00
parent 3d3582e688
commit 7dbd1e8b92
6 changed files with 106 additions and 67 deletions

View File

@ -196,7 +196,6 @@ impl AttachmentManager {
if let Err(err) = netman.startup().await { if let Err(err) = netman.startup().await {
error!("network startup failed: {}", err); error!("network startup failed: {}", err);
netman.shutdown().await; netman.shutdown().await;
restart = true;
break; break;
} }

View File

@ -8,7 +8,7 @@ const VERSION_1: u8 = 1;
type LengthType = u16; type LengthType = u16;
type SequenceType = u16; type SequenceType = u16;
const HEADER_LEN: usize = 8; const HEADER_LEN: usize = 8;
const MAX_MESSAGE_LEN: usize = LengthType::MAX as usize; const MAX_LEN: usize = LengthType::MAX as usize;
// XXX: keep statistics on all drops and why we dropped them // XXX: keep statistics on all drops and why we dropped them
// XXX: move to config // XXX: move to config
@ -16,14 +16,10 @@ const FRAGMENT_LEN: usize = 1280 - HEADER_LEN;
const MAX_CONCURRENT_HOSTS: usize = 256; const MAX_CONCURRENT_HOSTS: usize = 256;
const MAX_ASSEMBLIES_PER_HOST: usize = 256; const MAX_ASSEMBLIES_PER_HOST: usize = 256;
const MAX_BUFFER_PER_HOST: usize = 256 * 1024; const MAX_BUFFER_PER_HOST: usize = 256 * 1024;
const MAX_ASSEMBLY_AGE_US: u64 = 10_000_000;
///////////////////////////////////////////////////////// /////////////////////////////////////////////////////////
pub struct Message {
data: Vec<u8>,
remote_addr: SocketAddr,
}
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct PeerKey { struct PeerKey {
remote_addr: SocketAddr, remote_addr: SocketAddr,
@ -31,6 +27,7 @@ struct PeerKey {
#[derive(Clone, Eq, PartialEq)] #[derive(Clone, Eq, PartialEq)]
struct MessageAssembly { struct MessageAssembly {
timestamp: Timestamp,
seq: SequenceType, seq: SequenceType,
data: Vec<u8>, data: Vec<u8>,
parts: RangeSetBlaze<LengthType>, parts: RangeSetBlaze<LengthType>,
@ -38,15 +35,29 @@ struct MessageAssembly {
#[derive(Clone, Eq, PartialEq)] #[derive(Clone, Eq, PartialEq)]
struct PeerMessages { struct PeerMessages {
assemblies: Vec<MessageAssembly>, assemblies: LinkedList<MessageAssembly>,
} }
impl PeerMessages { impl PeerMessages {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
assemblies: Vec::new(), assemblies: LinkedList::new(),
} }
} }
pub fn insert_fragment(
&mut self,
seq: SequenceType,
off: LengthType,
len: LengthType,
chunk: &[u8],
) -> Option<Vec<u8>> {
// Get the current timestamp
let cur_ts = get_timestamp();
// Get the assembly this belongs to by its sequence number
for a in self.assemblies {}
}
} }
///////////////////////////////////////////////////////// /////////////////////////////////////////////////////////
@ -70,19 +81,19 @@ pub struct AssemblyBuffer {
} }
impl AssemblyBuffer { impl AssemblyBuffer {
pub fn new_unlocked_inner() -> AssemblyBufferUnlockedInner { fn new_unlocked_inner() -> AssemblyBufferUnlockedInner {
AssemblyBufferUnlockedInner { AssemblyBufferUnlockedInner {
outbound_lock_table: AsyncTagLockTable::new(), outbound_lock_table: AsyncTagLockTable::new(),
next_seq: AtomicU16::new(0), next_seq: AtomicU16::new(0),
} }
} }
pub fn new_inner() -> AssemblyBufferInner { fn new_inner() -> AssemblyBufferInner {
AssemblyBufferInner { AssemblyBufferInner {
peer_message_map: HashMap::new(), peer_message_map: HashMap::new(),
} }
} }
pub fn new(frag_len: usize) -> Self { pub fn new() -> Self {
Self { Self {
inner: Arc::new(Mutex::new(Self::new_inner())), inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner()), unlocked_inner: Arc::new(Self::new_unlocked_inner()),
@ -91,18 +102,15 @@ impl AssemblyBuffer {
/// Receive a packet chunk and add to the message assembly /// Receive a packet chunk and add to the message assembly
/// if a message has been completely assembled, return it /// if a message has been completely assembled, return it
pub fn receive_packet(&self, frame: &[u8], remote_addr: SocketAddr) -> Option<Message> { pub fn insert_frame(&self, frame: &[u8], remote_addr: SocketAddr) -> Option<Vec<u8>> {
// If we receive a zero length frame, return it as-is // If we receive a zero length frame, return it as-is
if frame.len() == 0 { if frame.len() == 0 {
return Some(Message { return Some(frame.to_vec());
data: frame.to_vec(),
remote_addr,
});
} }
// If we receive a frame smaller than or equal to the length of the header, drop it // If we receive a frame smaller than or equal to the length of the header, drop it
// or if this frame is larger than our max message length, then drop it // or if this frame is larger than our max message length, then drop it
if frame.len() <= HEADER_LEN || frame.len() > MAX_MESSAGE_LEN { if frame.len() <= HEADER_LEN || frame.len() > MAX_LEN {
return None; return None;
} }
@ -120,10 +128,7 @@ impl AssemblyBuffer {
// See if we have a whole message and not a fragment // See if we have a whole message and not a fragment
if off == 0 && len as usize == chunk.len() { if off == 0 && len as usize == chunk.len() {
return Some(Message { return Some(frame.to_vec());
data: frame.to_vec(),
remote_addr,
});
} }
// Drop fragments with offsets greater than or equal to the message length // Drop fragments with offsets greater than or equal to the message length
@ -139,25 +144,32 @@ impl AssemblyBuffer {
// and drop the packet if we have too many peers // and drop the packet if we have too many peers
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
let peer_key = PeerKey { remote_addr }; let peer_key = PeerKey { remote_addr };
let peer_messages = match inner.peer_message_map.entry(peer_key) { let peer_count = inner.peer_message_map.len();
std::collections::hash_map::Entry::Occupied(e) => e.get_mut(), match inner.peer_message_map.entry(peer_key) {
std::collections::hash_map::Entry::Occupied(mut e) => {
let peer_messages = e.get_mut();
// Insert the fragment and see what comes out
peer_messages.insert_fragment(seq, off, len, chunk)
}
std::collections::hash_map::Entry::Vacant(v) => { std::collections::hash_map::Entry::Vacant(v) => {
// See if we have room for one more // See if we have room for one more
if inner.peer_message_map.len() == MAX_CONCURRENT_HOSTS { if peer_count == MAX_CONCURRENT_HOSTS {
return None; return None;
} }
// Add the peer // Add the peer
v.insert(PeerMessages::new()) let peer_messages = v.insert(PeerMessages::new());
}
};
None // Insert the fragment and see what comes out
peer_messages.insert_fragment(seq, off, len, chunk)
}
}
} }
/// Add framing to chunk to send to the wire /// Add framing to chunk to send to the wire
fn frame_chunk(chunk: &[u8], offset: usize, message_len: usize, seq: SequenceType) -> Vec<u8> { fn frame_chunk(chunk: &[u8], offset: usize, message_len: usize, seq: SequenceType) -> Vec<u8> {
assert!(chunk.len() > 0); assert!(chunk.len() > 0);
assert!(message_len <= MAX_MESSAGE_LEN); assert!(message_len <= MAX_LEN);
assert!(offset + chunk.len() <= message_len); assert!(offset + chunk.len() <= message_len);
let off: LengthType = offset as LengthType; let off: LengthType = offset as LengthType;
@ -175,7 +187,7 @@ impl AssemblyBuffer {
out[6..HEADER_LEN].copy_from_slice(&len.to_be_bytes()); // total length of message out[6..HEADER_LEN].copy_from_slice(&len.to_be_bytes()); // total length of message
// Write out body // Write out body
out[HEADER_LEN..out.len()].copy_from_slice(chunk); out[HEADER_LEN..].copy_from_slice(chunk);
out out
} }
} }
@ -183,25 +195,30 @@ impl AssemblyBuffer {
/// Split a message into packets and send them serially, ensuring /// Split a message into packets and send them serially, ensuring
/// that they are sent consecutively to a particular remote address, /// that they are sent consecutively to a particular remote address,
/// never interleaving packets from one message with those of another, to minimize reassembly problems /// never interleaving packets from one message with those of another, to minimize reassembly problems
pub async fn split_message<F>(&self, message: Message, sender: F) -> std::io::Result<()> pub async fn split_message<S, F>(
&self,
data: Vec<u8>,
remote_addr: SocketAddr,
sender: S,
) -> std::io::Result<NetworkResult<()>>
where where
F: Fn(Vec<u8>, SocketAddr) -> SendPinBoxFuture<std::io::Result<()>>, S: Fn(Vec<u8>, SocketAddr) -> F,
F: Future<Output = std::io::Result<NetworkResult<()>>>,
{ {
if message.data.len() > MAX_MESSAGE_LEN { if data.len() > MAX_LEN {
return Err(Error::from(ErrorKind::InvalidData)); return Err(Error::from(ErrorKind::InvalidData));
} }
// Do not frame or split anything zero bytes long, just send it // Do not frame or split anything zero bytes long, just send it
if message.data.len() == 0 { if data.len() == 0 {
sender(message.data, message.remote_addr).await?; return sender(data, remote_addr).await;
return Ok(());
} }
// Lock per remote addr // Lock per remote addr
let _tag_lock = self let _tag_lock = self
.unlocked_inner .unlocked_inner
.outbound_lock_table .outbound_lock_table
.lock_tag(message.remote_addr) .lock_tag(remote_addr)
.await; .await;
// Get a message seq // Get a message seq
@ -209,16 +226,16 @@ impl AssemblyBuffer {
// Chunk it up // Chunk it up
let mut offset = 0usize; let mut offset = 0usize;
let message_len = message.data.len(); let message_len = data.len();
for chunk in message.data.chunks(FRAGMENT_LEN) { for chunk in data.chunks(FRAGMENT_LEN) {
// Frame chunk // Frame chunk
let framed_chunk = Self::frame_chunk(chunk, offset, message_len, seq); let framed_chunk = Self::frame_chunk(chunk, offset, message_len, seq);
// Send chunk // Send chunk
sender(framed_chunk, message.remote_addr).await?; network_result_try!(sender(framed_chunk, remote_addr).await?);
// Go to next chunk // Go to next chunk
offset += chunk.len() offset += chunk.len()
} }
Ok(()) Ok(NetworkResult::value(()))
} }
} }

View File

@ -56,11 +56,11 @@ impl RawTcpNetworkConnection {
stream.flush().await.into_network_result() stream.flush().await.into_network_result()
} }
#[instrument(level="trace", err, skip(self, message), fields(network_result, message.len = message.len()))] //#[instrument(level="trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> { pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
let mut stream = self.stream.clone(); let mut stream = self.stream.clone();
let out = Self::send_internal(&mut stream, message).await?; let out = Self::send_internal(&mut stream, message).await?;
tracing::Span::current().record("network_result", &tracing::field::display(&out)); //tracing::Span::current().record("network_result", &tracing::field::display(&out));
Ok(out) Ok(out)
} }

View File

@ -18,13 +18,25 @@ impl RawUdpProtocolHandler {
// #[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.descriptor))] // #[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.descriptor))]
pub async fn recv_message(&self, data: &mut [u8]) -> io::Result<(usize, ConnectionDescriptor)> { pub async fn recv_message(&self, data: &mut [u8]) -> io::Result<(usize, ConnectionDescriptor)> {
let (size, descriptor) = loop { let (message_len, descriptor) = loop {
// Get a packet
let (size, remote_addr) = network_result_value_or_log!(self.socket.recv_from(data).await.into_network_result()? => continue); let (size, remote_addr) = network_result_value_or_log!(self.socket.recv_from(data).await.into_network_result()? => continue);
if size > MAX_MESSAGE_SIZE {
// Insert into assembly buffer
let Some(message) = self.assembly_buffer.insert_frame(&data[0..size], remote_addr) else {
continue;
};
// Check length of reassembled message (same for all protocols)
if message.len() > MAX_MESSAGE_SIZE {
log_net!(debug "{}({}) at {}@{}:{}", "Invalid message".green(), "received too large UDP message", file!(), line!(), column!()); log_net!(debug "{}({}) at {}@{}:{}", "Invalid message".green(), "received too large UDP message", file!(), line!(), column!());
continue; continue;
} }
// Copy assemble message out if we got one
data[0..message.len()].copy_from_slice(&message);
// Return a connection descriptor and the amount of data in the message
let peer_addr = PeerAddress::new( let peer_addr = PeerAddress::new(
SocketAddress::from_socket_addr(remote_addr), SocketAddress::from_socket_addr(remote_addr),
ProtocolType::UDP, ProtocolType::UDP,
@ -35,25 +47,46 @@ impl RawUdpProtocolHandler {
SocketAddress::from_socket_addr(local_socket_addr), SocketAddress::from_socket_addr(local_socket_addr),
); );
break (size, descriptor); break (message.len(), descriptor);
}; };
// tracing::Span::current().record("ret.len", &size); // tracing::Span::current().record("ret.len", &message_len);
// tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str()); // tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str());
Ok((size, descriptor)) Ok((message_len, descriptor))
} }
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.descriptor))] //#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.descriptor))]
pub async fn send_message( pub async fn send_message(
&self, &self,
data: Vec<u8>, data: Vec<u8>,
socket_addr: SocketAddr, remote_addr: SocketAddr,
) -> io::Result<NetworkResult<ConnectionDescriptor>> { ) -> io::Result<NetworkResult<ConnectionDescriptor>> {
if data.len() > MAX_MESSAGE_SIZE { if data.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large UDP message"); bail_io_error_other!("sending too large UDP message");
} }
// Fragment and send
let sender = |framed_chunk: Vec<u8>, remote_addr: SocketAddr| async move {
let len = network_result_try!(self
.socket
.send_to(&framed_chunk, remote_addr)
.await
.into_network_result()?);
if len != framed_chunk.len() {
bail_io_error_other!("UDP partial send")
}
Ok(NetworkResult::value(()))
};
network_result_try!(
self.assembly_buffer
.split_message(data, remote_addr, sender)
.await?
);
// Return a connection descriptor for the sent message
let peer_addr = PeerAddress::new( let peer_addr = PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr), SocketAddress::from_socket_addr(remote_addr),
ProtocolType::UDP, ProtocolType::UDP,
); );
let local_socket_addr = self.socket.local_addr()?; let local_socket_addr = self.socket.local_addr()?;
@ -63,17 +96,7 @@ impl RawUdpProtocolHandler {
SocketAddress::from_socket_addr(local_socket_addr), SocketAddress::from_socket_addr(local_socket_addr),
); );
let len = network_result_try!(self // tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str());
.socket
.send_to(&data, socket_addr)
.await
.into_network_result()?);
if len != data.len() {
bail_io_error_other!("UDP partial send")
}
tracing::Span::current().record("ret.len", &len);
tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str());
Ok(NetworkResult::value(descriptor)) Ok(NetworkResult::value(descriptor))
} }

View File

@ -72,7 +72,7 @@ where
// .map_err(to_io_error_other) // .map_err(to_io_error_other)
// } // }
#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))] //#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> { pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
if message.len() > MAX_MESSAGE_SIZE { if message.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("received too large WS message"); bail_io_error_other!("received too large WS message");
@ -89,7 +89,7 @@ where
Ok(v) => NetworkResult::value(v), Ok(v) => NetworkResult::value(v),
Err(e) => err_to_network_result(e), Err(e) => err_to_network_result(e),
}; };
tracing::Span::current().record("network_result", &tracing::field::display(&out)); //tracing::Span::current().record("network_result", &tracing::field::display(&out));
Ok(out) Ok(out)
} }

View File

@ -64,7 +64,7 @@ impl WebsocketNetworkConnection {
// self.inner.ws_meta.close().await.map_err(to_io).map(drop) // self.inner.ws_meta.close().await.map_err(to_io).map(drop)
// } // }
#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))] //#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> { pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
if message.len() > MAX_MESSAGE_SIZE { if message.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large WS message"); bail_io_error_other!("sending too large WS message");
@ -79,7 +79,7 @@ impl WebsocketNetworkConnection {
.map_err(to_io) .map_err(to_io)
.into_network_result()?; .into_network_result()?;
tracing::Span::current().record("network_result", &tracing::field::display(&out)); //tracing::Span::current().record("network_result", &tracing::field::display(&out));
Ok(out) Ok(out)
} }