From 4b71739dc9a3a16c908542158d433750e653719a Mon Sep 17 00:00:00 2001 From: Franck Royer Date: Tue, 19 Jan 2021 10:25:29 +1100 Subject: [PATCH] Move protocol states in an enum --- src/lib.rs | 158 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 110 insertions(+), 48 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d38d9e09..dff194ba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ use libp2p::{InboundUpgrade, OutboundUpgrade, PeerId}; use std::collections::{HashMap, VecDeque}; use std::convert::Infallible; use std::future::{Future, Ready}; -use std::iter; +use std::{iter, mem}; #[cfg(test)] mod swarm_harness; @@ -23,17 +23,30 @@ type InboundProtocolFn = Box Protocol = Box Protocol + Send + 'static>; +enum InboundProtocolState { + None, + PendingSubstream(InboundProtocolFn), + PendingProtocolFn(InboundSubstream), + ReadyToPoll(Protocol), + Done, + Poisoned, +} + +enum OutboundProtocolState { + None, + PendingSubstream(OutboundProtocolFn), + PendingProtocolFn(OutboundSubstream), + ReadyToPoll(Protocol), + Done, + Poisoned, +} + pub struct NMessageHandler { - inbound_substream: Option, - outbound_substream: Option, - - inbound_future: Option>, - inbound_future_fn: Option>, - - // todo: make enum for state - outbound_future: Option>, - outbound_future_fn: Option>, + inbound_state: InboundProtocolState, + outbound_state: OutboundProtocolState, + // TODO: See if it can be included in OutboundProtocolState. + // Or it can be inferred from OutboundProtocolState current variant. substream_request: Option>, info: &'static [u8], @@ -42,14 +55,10 @@ pub struct NMessageHandler { impl NMessageHandler { pub fn new(info: &'static [u8]) -> Self { Self { - inbound_substream: None, - inbound_future: None, - outbound_substream: None, - outbound_future: None, - outbound_future_fn: None, + inbound_state: InboundProtocolState::None, + outbound_state: OutboundProtocolState::None, substream_request: None, info, - inbound_future_fn: None, } } } @@ -102,6 +111,7 @@ pub enum ProtocolInEvent { ExecuteOutbound(OutboundProtocolFn), } +// TODO: Remove Finished/Failed and just wrap a Result pub enum ProtocolOutEvent { InboundFinished(I), OutboundFinished(O), @@ -130,46 +140,82 @@ where fn inject_fully_negotiated_inbound( &mut self, - protocol: InboundSubstream, + substream: InboundSubstream, _: Self::InboundOpenInfo, ) { - if let Some(future_fn) = self.inbound_future_fn.take() { - self.inbound_future = Some(future_fn(protocol)) - } else { - self.inbound_substream = Some(protocol) + match mem::replace(&mut self.inbound_state, InboundProtocolState::Poisoned) { + InboundProtocolState::None => { + self.inbound_state = InboundProtocolState::PendingProtocolFn(substream); + } + InboundProtocolState::PendingSubstream(protocol_fn) => { + self.inbound_state = InboundProtocolState::ReadyToPoll(protocol_fn(substream)); + } + InboundProtocolState::PendingProtocolFn(_) + | InboundProtocolState::ReadyToPoll(_) + | InboundProtocolState::Done + | InboundProtocolState::Poisoned => { + panic!("Failed to inject inbound substream due to unexpected state."); + } } } fn inject_fully_negotiated_outbound( &mut self, - protocol: OutboundSubstream, + substream: OutboundSubstream, _: Self::OutboundOpenInfo, ) { - if let Some(future_fn) = self.outbound_future_fn.take() { - self.outbound_future = Some(future_fn(protocol)) - } else { - self.outbound_substream = Some(protocol) + match mem::replace(&mut self.outbound_state, OutboundProtocolState::Poisoned) { + OutboundProtocolState::None => { + self.outbound_state = OutboundProtocolState::PendingProtocolFn(substream); + } + OutboundProtocolState::PendingSubstream(protocol_fn) => { + self.outbound_state = OutboundProtocolState::ReadyToPoll(protocol_fn(substream)); + } + OutboundProtocolState::PendingProtocolFn(_) + | OutboundProtocolState::ReadyToPoll(_) + | OutboundProtocolState::Done + | OutboundProtocolState::Poisoned => { + panic!("Failed to inject outbound substream due to unexpected state."); + } } } fn inject_event(&mut self, event: Self::InEvent) { match event { - ProtocolInEvent::ExecuteInbound(protocol_fn) => match self.inbound_substream.take() { - Some(substream) => self.inbound_future = Some(protocol_fn(substream)), - None => { - self.inbound_future_fn = Some(protocol_fn); + ProtocolInEvent::ExecuteInbound(protocol_fn) => { + match mem::replace(&mut self.inbound_state, InboundProtocolState::Poisoned) { + InboundProtocolState::None => { + self.inbound_state = InboundProtocolState::PendingSubstream(protocol_fn); + } + InboundProtocolState::PendingProtocolFn(substream) => { + self.inbound_state = + InboundProtocolState::ReadyToPoll(protocol_fn(substream)); + } + InboundProtocolState::PendingSubstream(_) + | InboundProtocolState::ReadyToPoll(_) + | InboundProtocolState::Done + | InboundProtocolState::Poisoned => { + panic!("Failed to inject inbound protocol fn due to unexpected state."); + } } - }, + } ProtocolInEvent::ExecuteOutbound(protocol_fn) => { self.substream_request = Some(SubstreamProtocol::new(NMessageProtocol::new(self.info), ())); - match self.outbound_substream.take() { - Some(substream) => { - self.outbound_future = Some(protocol_fn(substream)); + match mem::replace(&mut self.outbound_state, OutboundProtocolState::Poisoned) { + OutboundProtocolState::None => { + self.outbound_state = OutboundProtocolState::PendingSubstream(protocol_fn); } - None => { - self.outbound_future_fn = Some(protocol_fn); + OutboundProtocolState::PendingProtocolFn(substream) => { + self.outbound_state = + OutboundProtocolState::ReadyToPoll(protocol_fn(substream)); + } + OutboundProtocolState::PendingSubstream(_) + | OutboundProtocolState::ReadyToPoll(_) + | OutboundProtocolState::Done + | OutboundProtocolState::Poisoned => { + panic!("Failed to inject outbound protocol fn due to unexpected state."); } } } @@ -203,43 +249,59 @@ where return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }); } - if let Some(mut future) = self.inbound_future.take() { - match future.poll_unpin(cx) { + match mem::replace(&mut self.inbound_state, InboundProtocolState::Poisoned) { + InboundProtocolState::ReadyToPoll(mut protocol) => match protocol.poll_unpin(cx) { Poll::Ready(Ok(value)) => { + self.inbound_state = InboundProtocolState::Done; return Poll::Ready(ProtocolsHandlerEvent::Custom( ProtocolOutEvent::InboundFinished(value), - )) + )); } Poll::Ready(Err(e)) => { + self.inbound_state = InboundProtocolState::Done; return Poll::Ready(ProtocolsHandlerEvent::Custom( ProtocolOutEvent::InboundFailed(e), - )) + )); } Poll::Pending => { - self.inbound_future = Some(future); + self.inbound_state = InboundProtocolState::ReadyToPoll(protocol); return Poll::Pending; } + }, + InboundProtocolState::Poisoned => { + unreachable!("Inbound protocol is poisoned (transient state)") } - } + other => { + self.inbound_state = other; + } + }; - if let Some(mut future) = self.outbound_future.take() { - match future.poll_unpin(cx) { + match mem::replace(&mut self.outbound_state, OutboundProtocolState::Poisoned) { + OutboundProtocolState::ReadyToPoll(mut protocol) => match protocol.poll_unpin(cx) { Poll::Ready(Ok(value)) => { + self.outbound_state = OutboundProtocolState::Done; return Poll::Ready(ProtocolsHandlerEvent::Custom( ProtocolOutEvent::OutboundFinished(value), - )) + )); } Poll::Ready(Err(e)) => { + self.outbound_state = OutboundProtocolState::Done; return Poll::Ready(ProtocolsHandlerEvent::Custom( ProtocolOutEvent::OutboundFailed(e), - )) + )); } Poll::Pending => { - self.outbound_future = Some(future); + self.outbound_state = OutboundProtocolState::ReadyToPoll(protocol); return Poll::Pending; } + }, + OutboundProtocolState::Poisoned => { + unreachable!("Outbound protocol is poisoned (transient state)") } - } + other => { + self.outbound_state = other; + } + }; Poll::Pending }