Move protocol states in an enum

This commit is contained in:
Franck Royer 2021-01-19 10:25:29 +11:00
parent 19f11697fb
commit 4b71739dc9
No known key found for this signature in database
GPG Key ID: A82ED75A8DFC50A4

View File

@ -13,7 +13,7 @@ use libp2p::{InboundUpgrade, OutboundUpgrade, PeerId};
use std::collections::{HashMap, VecDeque};
use std::convert::Infallible;
use std::future::{Future, Ready};
use std::iter;
use std::{iter, mem};
#[cfg(test)]
mod swarm_harness;
@ -23,17 +23,30 @@ type InboundProtocolFn<I, E> = Box<dyn FnOnce(InboundSubstream) -> Protocol<I, E
type OutboundProtocolFn<O, E> =
Box<dyn FnOnce(OutboundSubstream) -> Protocol<O, E> + Send + 'static>;
enum InboundProtocolState<T, E> {
None,
PendingSubstream(InboundProtocolFn<T, E>),
PendingProtocolFn(InboundSubstream),
ReadyToPoll(Protocol<T, E>),
Done,
Poisoned,
}
enum OutboundProtocolState<T, E> {
None,
PendingSubstream(OutboundProtocolFn<T, E>),
PendingProtocolFn(OutboundSubstream),
ReadyToPoll(Protocol<T, E>),
Done,
Poisoned,
}
pub struct NMessageHandler<TInboundOut, TOutboundOut, TErr> {
inbound_substream: Option<InboundSubstream>,
outbound_substream: Option<OutboundSubstream>,
inbound_future: Option<Protocol<TInboundOut, TErr>>,
inbound_future_fn: Option<InboundProtocolFn<TInboundOut, TErr>>,
// todo: make enum for state
outbound_future: Option<Protocol<TOutboundOut, TErr>>,
outbound_future_fn: Option<OutboundProtocolFn<TOutboundOut, TErr>>,
inbound_state: InboundProtocolState<TInboundOut, TErr>,
outbound_state: OutboundProtocolState<TOutboundOut, TErr>,
// TODO: See if it can be included in OutboundProtocolState.
// Or it can be inferred from OutboundProtocolState current variant.
substream_request: Option<SubstreamProtocol<NMessageProtocol, ()>>,
info: &'static [u8],
@ -42,14 +55,10 @@ pub struct NMessageHandler<TInboundOut, TOutboundOut, TErr> {
impl<TInboundOut, TOutboundOut, TErr> NMessageHandler<TInboundOut, TOutboundOut, TErr> {
pub fn new(info: &'static [u8]) -> Self {
Self {
inbound_substream: None,
inbound_future: None,
outbound_substream: None,
outbound_future: None,
outbound_future_fn: None,
inbound_state: InboundProtocolState::None,
outbound_state: OutboundProtocolState::None,
substream_request: None,
info,
inbound_future_fn: None,
}
}
}
@ -102,6 +111,7 @@ pub enum ProtocolInEvent<I, O, E> {
ExecuteOutbound(OutboundProtocolFn<O, E>),
}
// TODO: Remove Finished/Failed and just wrap a Result
pub enum ProtocolOutEvent<I, O, E> {
InboundFinished(I),
OutboundFinished(O),
@ -130,46 +140,82 @@ where
fn inject_fully_negotiated_inbound(
&mut self,
protocol: InboundSubstream,
substream: InboundSubstream,
_: Self::InboundOpenInfo,
) {
if let Some(future_fn) = self.inbound_future_fn.take() {
self.inbound_future = Some(future_fn(protocol))
} else {
self.inbound_substream = Some(protocol)
match mem::replace(&mut self.inbound_state, InboundProtocolState::Poisoned) {
InboundProtocolState::None => {
self.inbound_state = InboundProtocolState::PendingProtocolFn(substream);
}
InboundProtocolState::PendingSubstream(protocol_fn) => {
self.inbound_state = InboundProtocolState::ReadyToPoll(protocol_fn(substream));
}
InboundProtocolState::PendingProtocolFn(_)
| InboundProtocolState::ReadyToPoll(_)
| InboundProtocolState::Done
| InboundProtocolState::Poisoned => {
panic!("Failed to inject inbound substream due to unexpected state.");
}
}
}
fn inject_fully_negotiated_outbound(
&mut self,
protocol: OutboundSubstream,
substream: OutboundSubstream,
_: Self::OutboundOpenInfo,
) {
if let Some(future_fn) = self.outbound_future_fn.take() {
self.outbound_future = Some(future_fn(protocol))
} else {
self.outbound_substream = Some(protocol)
match mem::replace(&mut self.outbound_state, OutboundProtocolState::Poisoned) {
OutboundProtocolState::None => {
self.outbound_state = OutboundProtocolState::PendingProtocolFn(substream);
}
OutboundProtocolState::PendingSubstream(protocol_fn) => {
self.outbound_state = OutboundProtocolState::ReadyToPoll(protocol_fn(substream));
}
OutboundProtocolState::PendingProtocolFn(_)
| OutboundProtocolState::ReadyToPoll(_)
| OutboundProtocolState::Done
| OutboundProtocolState::Poisoned => {
panic!("Failed to inject outbound substream due to unexpected state.");
}
}
}
fn inject_event(&mut self, event: Self::InEvent) {
match event {
ProtocolInEvent::ExecuteInbound(protocol_fn) => match self.inbound_substream.take() {
Some(substream) => self.inbound_future = Some(protocol_fn(substream)),
None => {
self.inbound_future_fn = Some(protocol_fn);
ProtocolInEvent::ExecuteInbound(protocol_fn) => {
match mem::replace(&mut self.inbound_state, InboundProtocolState::Poisoned) {
InboundProtocolState::None => {
self.inbound_state = InboundProtocolState::PendingSubstream(protocol_fn);
}
InboundProtocolState::PendingProtocolFn(substream) => {
self.inbound_state =
InboundProtocolState::ReadyToPoll(protocol_fn(substream));
}
InboundProtocolState::PendingSubstream(_)
| InboundProtocolState::ReadyToPoll(_)
| InboundProtocolState::Done
| InboundProtocolState::Poisoned => {
panic!("Failed to inject inbound protocol fn due to unexpected state.");
}
}
},
}
ProtocolInEvent::ExecuteOutbound(protocol_fn) => {
self.substream_request =
Some(SubstreamProtocol::new(NMessageProtocol::new(self.info), ()));
match self.outbound_substream.take() {
Some(substream) => {
self.outbound_future = Some(protocol_fn(substream));
match mem::replace(&mut self.outbound_state, OutboundProtocolState::Poisoned) {
OutboundProtocolState::None => {
self.outbound_state = OutboundProtocolState::PendingSubstream(protocol_fn);
}
None => {
self.outbound_future_fn = Some(protocol_fn);
OutboundProtocolState::PendingProtocolFn(substream) => {
self.outbound_state =
OutboundProtocolState::ReadyToPoll(protocol_fn(substream));
}
OutboundProtocolState::PendingSubstream(_)
| OutboundProtocolState::ReadyToPoll(_)
| OutboundProtocolState::Done
| OutboundProtocolState::Poisoned => {
panic!("Failed to inject outbound protocol fn due to unexpected state.");
}
}
}
@ -203,43 +249,59 @@ where
return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol });
}
if let Some(mut future) = self.inbound_future.take() {
match future.poll_unpin(cx) {
match mem::replace(&mut self.inbound_state, InboundProtocolState::Poisoned) {
InboundProtocolState::ReadyToPoll(mut protocol) => match protocol.poll_unpin(cx) {
Poll::Ready(Ok(value)) => {
self.inbound_state = InboundProtocolState::Done;
return Poll::Ready(ProtocolsHandlerEvent::Custom(
ProtocolOutEvent::InboundFinished(value),
))
));
}
Poll::Ready(Err(e)) => {
self.inbound_state = InboundProtocolState::Done;
return Poll::Ready(ProtocolsHandlerEvent::Custom(
ProtocolOutEvent::InboundFailed(e),
))
));
}
Poll::Pending => {
self.inbound_future = Some(future);
self.inbound_state = InboundProtocolState::ReadyToPoll(protocol);
return Poll::Pending;
}
},
InboundProtocolState::Poisoned => {
unreachable!("Inbound protocol is poisoned (transient state)")
}
}
other => {
self.inbound_state = other;
}
};
if let Some(mut future) = self.outbound_future.take() {
match future.poll_unpin(cx) {
match mem::replace(&mut self.outbound_state, OutboundProtocolState::Poisoned) {
OutboundProtocolState::ReadyToPoll(mut protocol) => match protocol.poll_unpin(cx) {
Poll::Ready(Ok(value)) => {
self.outbound_state = OutboundProtocolState::Done;
return Poll::Ready(ProtocolsHandlerEvent::Custom(
ProtocolOutEvent::OutboundFinished(value),
))
));
}
Poll::Ready(Err(e)) => {
self.outbound_state = OutboundProtocolState::Done;
return Poll::Ready(ProtocolsHandlerEvent::Custom(
ProtocolOutEvent::OutboundFailed(e),
))
));
}
Poll::Pending => {
self.outbound_future = Some(future);
self.outbound_state = OutboundProtocolState::ReadyToPoll(protocol);
return Poll::Pending;
}
},
OutboundProtocolState::Poisoned => {
unreachable!("Outbound protocol is poisoned (transient state)")
}
}
other => {
self.outbound_state = other;
}
};
Poll::Pending
}