frag work

This commit is contained in:
John Smith 2023-06-22 16:54:01 -04:00
parent 3d3582e688
commit 7dbd1e8b92
6 changed files with 106 additions and 67 deletions

View File

@ -196,7 +196,6 @@ impl AttachmentManager {
if let Err(err) = netman.startup().await {
error!("network startup failed: {}", err);
netman.shutdown().await;
restart = true;
break;
}

View File

@ -8,7 +8,7 @@ const VERSION_1: u8 = 1;
type LengthType = u16;
type SequenceType = u16;
const HEADER_LEN: usize = 8;
const MAX_MESSAGE_LEN: usize = LengthType::MAX as usize;
const MAX_LEN: usize = LengthType::MAX as usize;
// XXX: keep statistics on all drops and why we dropped them
// XXX: move to config
@ -16,14 +16,10 @@ const FRAGMENT_LEN: usize = 1280 - HEADER_LEN;
const MAX_CONCURRENT_HOSTS: usize = 256;
const MAX_ASSEMBLIES_PER_HOST: usize = 256;
const MAX_BUFFER_PER_HOST: usize = 256 * 1024;
const MAX_ASSEMBLY_AGE_US: u64 = 10_000_000;
/////////////////////////////////////////////////////////
pub struct Message {
data: Vec<u8>,
remote_addr: SocketAddr,
}
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct PeerKey {
remote_addr: SocketAddr,
@ -31,6 +27,7 @@ struct PeerKey {
#[derive(Clone, Eq, PartialEq)]
struct MessageAssembly {
timestamp: Timestamp,
seq: SequenceType,
data: Vec<u8>,
parts: RangeSetBlaze<LengthType>,
@ -38,15 +35,29 @@ struct MessageAssembly {
#[derive(Clone, Eq, PartialEq)]
struct PeerMessages {
assemblies: Vec<MessageAssembly>,
assemblies: LinkedList<MessageAssembly>,
}
impl PeerMessages {
pub fn new() -> Self {
Self {
assemblies: Vec::new(),
assemblies: LinkedList::new(),
}
}
pub fn insert_fragment(
&mut self,
seq: SequenceType,
off: LengthType,
len: LengthType,
chunk: &[u8],
) -> Option<Vec<u8>> {
// Get the current timestamp
let cur_ts = get_timestamp();
// Get the assembly this belongs to by its sequence number
for a in self.assemblies {}
}
}
/////////////////////////////////////////////////////////
@ -70,19 +81,19 @@ pub struct AssemblyBuffer {
}
impl AssemblyBuffer {
pub fn new_unlocked_inner() -> AssemblyBufferUnlockedInner {
fn new_unlocked_inner() -> AssemblyBufferUnlockedInner {
AssemblyBufferUnlockedInner {
outbound_lock_table: AsyncTagLockTable::new(),
next_seq: AtomicU16::new(0),
}
}
pub fn new_inner() -> AssemblyBufferInner {
fn new_inner() -> AssemblyBufferInner {
AssemblyBufferInner {
peer_message_map: HashMap::new(),
}
}
pub fn new(frag_len: usize) -> Self {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner()),
@ -91,18 +102,15 @@ impl AssemblyBuffer {
/// Receive a packet chunk and add to the message assembly
/// if a message has been completely, return it
pub fn receive_packet(&self, frame: &[u8], remote_addr: SocketAddr) -> Option<Message> {
pub fn insert_frame(&self, frame: &[u8], remote_addr: SocketAddr) -> Option<Vec<u8>> {
// If we receive a zero length frame, send it
if frame.len() == 0 {
return Some(Message {
data: frame.to_vec(),
remote_addr,
});
return Some(frame.to_vec());
}
// If we receive a frame smaller than or equal to the length of the header, drop it
// or if this frame is larger than our max message length, then drop it
if frame.len() <= HEADER_LEN || frame.len() > MAX_MESSAGE_LEN {
if frame.len() <= HEADER_LEN || frame.len() > MAX_LEN {
return None;
}
@ -120,10 +128,7 @@ impl AssemblyBuffer {
// See if we have a whole message and not a fragment
if off == 0 && len as usize == chunk.len() {
return Some(Message {
data: frame.to_vec(),
remote_addr,
});
return Some(frame.to_vec());
}
// Drop fragments with offsets greater than or equal to the message length
@ -139,25 +144,32 @@ impl AssemblyBuffer {
// and drop the packet if we have too many peers
let mut inner = self.inner.lock();
let peer_key = PeerKey { remote_addr };
let peer_messages = match inner.peer_message_map.entry(peer_key) {
std::collections::hash_map::Entry::Occupied(e) => e.get_mut(),
let peer_count = inner.peer_message_map.len();
match inner.peer_message_map.entry(peer_key) {
std::collections::hash_map::Entry::Occupied(mut e) => {
let peer_messages = e.get_mut();
// Insert the fragment and see what comes out
peer_messages.insert_fragment(seq, off, len, chunk)
}
std::collections::hash_map::Entry::Vacant(v) => {
// See if we have room for one more
if inner.peer_message_map.len() == MAX_CONCURRENT_HOSTS {
if peer_count == MAX_CONCURRENT_HOSTS {
return None;
}
// Add the peer
v.insert(PeerMessages::new())
}
};
let peer_messages = v.insert(PeerMessages::new());
None
// Insert the fragment and see what comes out
peer_messages.insert_fragment(seq, off, len, chunk)
}
}
}
/// Add framing to chunk to send to the wire
fn frame_chunk(chunk: &[u8], offset: usize, message_len: usize, seq: SequenceType) -> Vec<u8> {
assert!(chunk.len() > 0);
assert!(message_len <= MAX_MESSAGE_LEN);
assert!(message_len <= MAX_LEN);
assert!(offset + chunk.len() <= message_len);
let off: LengthType = offset as LengthType;
@ -175,7 +187,7 @@ impl AssemblyBuffer {
out[6..HEADER_LEN].copy_from_slice(&len.to_be_bytes()); // total length of message
// Write out body
out[HEADER_LEN..out.len()].copy_from_slice(chunk);
out[HEADER_LEN..].copy_from_slice(chunk);
out
}
}
@ -183,25 +195,30 @@ impl AssemblyBuffer {
/// Split a message into packets and send them serially, ensuring
/// that they are sent consecutively to a particular remote address,
/// never interleaving packets from one message and other to minimize reassembly problems
pub async fn split_message<F>(&self, message: Message, sender: F) -> std::io::Result<()>
pub async fn split_message<S, F>(
&self,
data: Vec<u8>,
remote_addr: SocketAddr,
sender: S,
) -> std::io::Result<NetworkResult<()>>
where
F: Fn(Vec<u8>, SocketAddr) -> SendPinBoxFuture<std::io::Result<()>>,
S: Fn(Vec<u8>, SocketAddr) -> F,
F: Future<Output = std::io::Result<NetworkResult<()>>>,
{
if message.data.len() > MAX_MESSAGE_LEN {
if data.len() > MAX_LEN {
return Err(Error::from(ErrorKind::InvalidData));
}
// Do not frame or split anything zero bytes long, just send it
if message.data.len() == 0 {
sender(message.data, message.remote_addr).await?;
return Ok(());
if data.len() == 0 {
return sender(data, remote_addr).await;
}
// Lock per remote addr
let _tag_lock = self
.unlocked_inner
.outbound_lock_table
.lock_tag(message.remote_addr)
.lock_tag(remote_addr)
.await;
// Get a message seq
@ -209,16 +226,16 @@ impl AssemblyBuffer {
// Chunk it up
let mut offset = 0usize;
let message_len = message.data.len();
for chunk in message.data.chunks(FRAGMENT_LEN) {
let message_len = data.len();
for chunk in data.chunks(FRAGMENT_LEN) {
// Frame chunk
let framed_chunk = Self::frame_chunk(chunk, offset, message_len, seq);
// Send chunk
sender(framed_chunk, message.remote_addr).await?;
network_result_try!(sender(framed_chunk, remote_addr).await?);
// Go to next chunk
offset += chunk.len()
}
Ok(())
Ok(NetworkResult::value(()))
}
}

View File

@ -56,11 +56,11 @@ impl RawTcpNetworkConnection {
stream.flush().await.into_network_result()
}
#[instrument(level="trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
//#[instrument(level="trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
let mut stream = self.stream.clone();
let out = Self::send_internal(&mut stream, message).await?;
tracing::Span::current().record("network_result", &tracing::field::display(&out));
//tracing::Span::current().record("network_result", &tracing::field::display(&out));
Ok(out)
}

View File

@ -18,13 +18,25 @@ impl RawUdpProtocolHandler {
// #[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.descriptor))]
pub async fn recv_message(&self, data: &mut [u8]) -> io::Result<(usize, ConnectionDescriptor)> {
let (size, descriptor) = loop {
let (message_len, descriptor) = loop {
// Get a packet
let (size, remote_addr) = network_result_value_or_log!(self.socket.recv_from(data).await.into_network_result()? => continue);
if size > MAX_MESSAGE_SIZE {
// Insert into assembly buffer
let Some(message) = self.assembly_buffer.insert_frame(&data[0..size], remote_addr) else {
continue;
};
// Check length of reassembled message (same for all protocols)
if message.len() > MAX_MESSAGE_SIZE {
log_net!(debug "{}({}) at {}@{}:{}", "Invalid message".green(), "received too large UDP message", file!(), line!(), column!());
continue;
}
// Copy assemble message out if we got one
data[0..message.len()].copy_from_slice(&message);
// Return a connection descriptor and the amount of data in the message
let peer_addr = PeerAddress::new(
SocketAddress::from_socket_addr(remote_addr),
ProtocolType::UDP,
@ -35,25 +47,46 @@ impl RawUdpProtocolHandler {
SocketAddress::from_socket_addr(local_socket_addr),
);
break (size, descriptor);
break (message.len(), descriptor);
};
// tracing::Span::current().record("ret.len", &size);
// tracing::Span::current().record("ret.len", &message_len);
// tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str());
Ok((size, descriptor))
Ok((message_len, descriptor))
}
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.descriptor))]
//#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.descriptor))]
pub async fn send_message(
&self,
data: Vec<u8>,
socket_addr: SocketAddr,
remote_addr: SocketAddr,
) -> io::Result<NetworkResult<ConnectionDescriptor>> {
if data.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large UDP message");
}
// Fragment and send
let sender = |framed_chunk: Vec<u8>, remote_addr: SocketAddr| async move {
let len = network_result_try!(self
.socket
.send_to(&framed_chunk, remote_addr)
.await
.into_network_result()?);
if len != framed_chunk.len() {
bail_io_error_other!("UDP partial send")
}
Ok(NetworkResult::value(()))
};
network_result_try!(
self.assembly_buffer
.split_message(data, remote_addr, sender)
.await?
);
// Return a connection descriptor for the sent message
let peer_addr = PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
SocketAddress::from_socket_addr(remote_addr),
ProtocolType::UDP,
);
let local_socket_addr = self.socket.local_addr()?;
@ -63,17 +96,7 @@ impl RawUdpProtocolHandler {
SocketAddress::from_socket_addr(local_socket_addr),
);
let len = network_result_try!(self
.socket
.send_to(&data, socket_addr)
.await
.into_network_result()?);
if len != data.len() {
bail_io_error_other!("UDP partial send")
}
tracing::Span::current().record("ret.len", &len);
tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str());
// tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str());
Ok(NetworkResult::value(descriptor))
}

View File

@ -72,7 +72,7 @@ where
// .map_err(to_io_error_other)
// }
#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
//#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
if message.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("received too large WS message");
@ -89,7 +89,7 @@ where
Ok(v) => NetworkResult::value(v),
Err(e) => err_to_network_result(e),
};
tracing::Span::current().record("network_result", &tracing::field::display(&out));
//tracing::Span::current().record("network_result", &tracing::field::display(&out));
Ok(out)
}

View File

@ -64,7 +64,7 @@ impl WebsocketNetworkConnection {
// self.inner.ws_meta.close().await.map_err(to_io).map(drop)
// }
#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
//#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
if message.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large WS message");
@ -79,7 +79,7 @@ impl WebsocketNetworkConnection {
.map_err(to_io)
.into_network_result()?;
tracing::Span::current().record("network_result", &tracing::field::display(&out));
//tracing::Span::current().record("network_result", &tracing::field::display(&out));
Ok(out)
}