From 6877ae4d6f895bdf454c5945200eca70ff7f9da7 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 15 Jan 2021 19:37:36 +1100 Subject: [PATCH] Initial draft --- .gitignore | 2 + Cargo.toml | 20 ++ src/lib.rs | 480 +++++++++++++++++++++++++++++++++++++++++++ src/swarm_harness.rs | 150 ++++++++++++++ 4 files changed, 652 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/lib.rs create mode 100644 src/swarm_harness.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..96ef6c0b --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..fb038fb1 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "libp2p-nmessage" +version = "0.1.0" +authors = ["Thomas Eizinger "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +libp2p = { version = "0.34", default-features = false } +log = "0.4" + +[dev-dependencies] +anyhow = "1" +serde_cbor = "0.11" +tokio = { version = "1", features = ["macros", "rt", "time"] } +libp2p = { version = "0.34", default-features = false, features = ["noise", "yamux"] } +rand = "0.8" +serde = { version = "1", features = ["derive"] } +env_logger = "0.8" diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..4ec82be3 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,480 @@ +use libp2p::swarm::{ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, KeepAlive, SubstreamProtocol, NegotiatedSubstream, NetworkBehaviour, NetworkBehaviourAction, PollParameters, NotifyHandler}; +use libp2p::futures::task::{Context, Poll}; +use libp2p::{InboundUpgrade, PeerId, OutboundUpgrade}; +use libp2p::core::{UpgradeInfo, Multiaddr}; +use libp2p::futures::{FutureExt}; +use std::future::{Ready, Future}; +use std::convert::Infallible; +use libp2p::futures::future::BoxFuture; +use std::collections::VecDeque; +use libp2p::swarm::protocols_handler::OutboundUpgradeSend; +use libp2p::core::connection::ConnectionId; +use std::iter; +use std::task::Waker; + +#[cfg(test)] +mod swarm_harness; + +pub struct NMessageHandler { + inbound_substream: Option, + outbound_substream: Option, + + inbound_future: Option>>, + inbound_future_fn: Option BoxFuture<'static, Result> + Send + 'static>>, + + // todo: make enum for state + outbound_future: Option>>, + outbound_future_fn: Option BoxFuture<'static, Result> + Send + 'static>>, + + substream_request: Option>, + + info: &'static [u8] +} + +impl NMessageHandler { + pub fn new(info: &'static [u8]) -> Self { + Self { + inbound_substream: None, + inbound_future: None, + outbound_substream: None, + outbound_future: None, + outbound_future_fn: None, + substream_request: None, + info, + inbound_future_fn: None + } + } +} + +pub struct NMessageProtocol { + info: &'static [u8] +} + +impl NMessageProtocol { + fn new(info: &'static [u8]) -> Self { + Self { + info + } + } +} + +impl UpgradeInfo for NMessageProtocol { + type Info = &'static [u8]; + type InfoIter = iter::Once<&'static [u8]>; + + fn protocol_info(&self) -> Self::InfoIter { + iter::once(self.info) + } +} + +impl InboundUpgrade for NMessageProtocol { + type Output = NegotiatedSubstream; + type Error = Infallible; + type Future = Ready>; + + fn upgrade_inbound(self, socket: NegotiatedSubstream, _: Self::Info) -> Self::Future { + std::future::ready(Ok(socket)) + } +} + +impl OutboundUpgrade for NMessageProtocol { + type Output = NegotiatedSubstream; + type Error = Infallible; + type Future = Ready>; + + fn upgrade_outbound(self, socket: NegotiatedSubstream, _: Self::Info) -> Self::Future { + std::future::ready(Ok(socket)) + } +} + +pub enum ProtocolInEvent { + ExecuteInbound(Box BoxFuture<'static, Result> + Send + 'static>), + ExecuteOutbound(Box BoxFuture<'static, Result> + Send + 'static>), +} + +pub enum ProtocolOutEvent { + InboundFinished(I), + OutboundFinished(O), + InboundFailed(E), + OutboundFailed(E), +} + +impl ProtocolsHandler for NMessageHandler where TInboundOut: Send + 'static, TOutboundOut: Send + 'static, TErr: Send + 'static { + type InEvent = ProtocolInEvent; + type OutEvent = ProtocolOutEvent; + type Error = std::io::Error; + type InboundProtocol = NMessageProtocol; + type OutboundProtocol = NMessageProtocol; + type InboundOpenInfo = (); + type OutboundOpenInfo = (); + + fn listen_protocol(&self) -> SubstreamProtocol { + SubstreamProtocol::new(NMessageProtocol::new(self.info), ()) + } + + fn inject_fully_negotiated_inbound(&mut self, protocol: NegotiatedSubstream, _: Self::InboundOpenInfo) { + log::info!("inject_fully_negotiated_inbound"); + + if let Some(future_fn) = self.inbound_future_fn.take() { + self.inbound_future = Some(future_fn(protocol)) + } else { + self.inbound_substream = Some(protocol) + } + } + + fn inject_fully_negotiated_outbound(&mut self, protocol: NegotiatedSubstream, _: Self::OutboundOpenInfo) { + log::info!("inject_fully_negotiated_outbound"); + + if let Some(future_fn) = self.outbound_future_fn.take() { + self.outbound_future = Some(future_fn(protocol)) + } else { + self.outbound_substream = Some(protocol) + } + } + + fn inject_event(&mut self, event: Self::InEvent) { + match event { + ProtocolInEvent::ExecuteInbound(protocol_fn) => { + log::trace!("got execute inbound event"); + + match self.inbound_substream.take() { + Some(substream) => { + log::trace!("got inbound substream, upgrading with custom protocol"); + + self.inbound_future = Some(protocol_fn(substream)) + } + None => { + self.inbound_future_fn = Some(protocol_fn); + } + } + } + ProtocolInEvent::ExecuteOutbound(protocol_fn) => { + log::trace!("got execute outbound event"); + + self.substream_request = Some(SubstreamProtocol::new(NMessageProtocol::new(self.info), ())); + + match self.outbound_substream.take() { + Some(substream) => { + log::trace!("got outbound substream, upgrading with custom protocol"); + + self.outbound_future = Some(protocol_fn(substream)); + } + None => { + self.outbound_future_fn = Some(protocol_fn); + } + } + } + } + } + + fn inject_dial_upgrade_error(&mut self, _: Self::OutboundOpenInfo, _: ProtocolsHandlerUpgrErr< + ::Error + > + ) { + unimplemented!("TODO: handle this") + } + + fn connection_keep_alive(&self) -> KeepAlive { + KeepAlive::Yes + } + + fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(protocol) = self.substream_request.take() { + return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }) + } + + if let Some(future) = self.inbound_future.as_mut() { + match future.poll_unpin(cx) { + Poll::Ready(Ok(value)) => { + return Poll::Ready(ProtocolsHandlerEvent::Custom(ProtocolOutEvent::InboundFinished(value))) + } + Poll::Ready(Err(e)) => { + return Poll::Ready(ProtocolsHandlerEvent::Custom(ProtocolOutEvent::InboundFailed(e))) + } + Poll::Pending => {} + } + } + + if let Some(future) = self.outbound_future.as_mut() { + match future.poll_unpin(cx) { + Poll::Ready(Ok(value)) => { + return Poll::Ready(ProtocolsHandlerEvent::Custom(ProtocolOutEvent::OutboundFinished(value))) + } + Poll::Ready(Err(e)) => { + return Poll::Ready(ProtocolsHandlerEvent::Custom(ProtocolOutEvent::OutboundFailed(e))) + } + Poll::Pending => {} + } + } + + Poll::Pending + } +} + +pub struct NMessageBehaviour { + protocol_in_events: VecDeque<(PeerId, ProtocolInEvent)>, + protocol_out_events: VecDeque<(PeerId, ProtocolOutEvent)>, + + waker: Option, + + connected_peers: Vec, + + info: &'static [u8] +} + +impl NMessageBehaviour { + /// Constructs a new [`NMessageBehaviour`] with the given protocol info. + /// + /// # Example + /// + /// ``` + /// # use libp2p_nmessage::NMessageBehaviour; + /// + /// let _ = NMessageBehaviour::new(b"/foo/bar/1.0.0"); + /// ``` + pub fn new(info: &'static [u8]) -> Self { + Self { + protocol_in_events: Default::default(), + protocol_out_events: Default::default(), + waker: None, + connected_peers: vec![], + info + } + } +} + +impl NMessageBehaviour { + pub fn do_protocol_listener(&mut self, peer: PeerId, protocol: impl FnOnce(NegotiatedSubstream) -> F + Send + 'static ) where F: Future> + Send + 'static { + self.protocol_in_events.push_back((peer, ProtocolInEvent::ExecuteInbound(Box::new(move |substream| protocol(substream).boxed())))); + + log::info!("pushing ExecuteInbound event"); + + if let Some(waker) = self.waker.take() { + log::trace!("waking task"); + + waker.wake(); + } + } + + pub fn do_protocol_dialer(&mut self, peer: PeerId, protocol: impl FnOnce(NegotiatedSubstream) -> F + Send + 'static ) where F: Future> + Send + 'static { + self.protocol_in_events.push_back((peer, ProtocolInEvent::ExecuteOutbound(Box::new(move |substream| protocol(substream).boxed())))); + + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } +} + +#[derive(Clone)] +pub enum BehaviourOutEvent { + InboundFinished(PeerId, I), + OutboundFinished(PeerId, O), + InboundFailed(PeerId, E), + OutboundFailed(PeerId, E), +} + +impl NetworkBehaviour for NMessageBehaviour where I: Send + 'static, O: Send + 'static, E: Send + 'static { + type ProtocolsHandler = NMessageHandler; + type OutEvent = BehaviourOutEvent; + + fn new_handler(&mut self) -> Self::ProtocolsHandler { + NMessageHandler::new(self.info) + } + + fn addresses_of_peer(&mut self, _: &PeerId) -> Vec { + Vec::new() + } + + fn inject_connected(&mut self, peer: &PeerId) { + self.connected_peers.push(peer.clone()); + } + + fn inject_disconnected(&mut self, peer: &PeerId) { + self.connected_peers.retain(|p| p != peer) + } + + fn inject_event(&mut self, peer: PeerId, _: ConnectionId, event: ProtocolOutEvent) { + self.protocol_out_events.push_back((peer, event)); + } + + fn poll(&mut self, cx: &mut Context<'_>, params: &mut impl PollParameters) -> Poll, Self::OutEvent>> { + log::debug!("peer {}, no. events {}", params.local_peer_id(), self.protocol_in_events.len()); + + if let Some((peer, event)) = self.protocol_in_events.pop_front() { + log::debug!("notifying handler"); + + if !self.connected_peers.contains(&peer) { + log::info!("not connected to peer {}, waiting ...", peer); + self.protocol_in_events.push_back((peer, event)); + } else { + return Poll::Ready(NetworkBehaviourAction::NotifyHandler { peer_id: peer, handler: NotifyHandler::Any, event}) + } + } + + if let Some((peer, event)) = self.protocol_out_events.pop_front() { + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(match event { + ProtocolOutEvent::InboundFinished(event) => BehaviourOutEvent::InboundFinished(peer, event), + ProtocolOutEvent::OutboundFinished(event) => BehaviourOutEvent::OutboundFinished(peer, event), + ProtocolOutEvent::InboundFailed(e) => BehaviourOutEvent::InboundFailed(peer, e), + ProtocolOutEvent::OutboundFailed(e) => BehaviourOutEvent::OutboundFailed(peer, e) + })) + } + + self.waker = Some(cx.waker().clone()); + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::*; + use libp2p::core::upgrade; + use anyhow::{Context, Error}; + use swarm_harness::new_connected_swarm_pair; + use libp2p::swarm::SwarmEvent; + use libp2p::futures::future::join; + + #[derive(serde::Serialize, serde::Deserialize)] + #[derive(Debug)] + struct Message0 { + foo: u32 + } + #[derive(serde::Serialize, serde::Deserialize)] + #[derive(Debug)] + struct Message1 { + bar: u32 + } + #[derive(serde::Serialize, serde::Deserialize)] + #[derive(Debug)] + struct Message2 { + baz: u32 + } + + #[derive(Debug)] + struct AliceResult { + bar: u32 + } + #[derive(Debug)] + struct BobResult { + foo: u32, + baz: u32 + } + + #[derive(Debug)] + enum MyOutEvent { + Alice(AliceResult), + Bob(BobResult), + Failed(anyhow::Error), + } + + impl From> for MyOutEvent { + fn from(event: BehaviourOutEvent) -> Self { + match event { + BehaviourOutEvent::InboundFinished(_, bob) => MyOutEvent::Bob(bob), + BehaviourOutEvent::OutboundFinished(_, alice) => MyOutEvent::Alice(alice), + BehaviourOutEvent::InboundFailed(_, e) | BehaviourOutEvent::OutboundFailed(_, e) => MyOutEvent::Failed(e) + } + } + } + + #[derive(libp2p::NetworkBehaviour)] + #[behaviour(out_event = "MyOutEvent", event_process = false)] + struct MyBehaviour { + inner: NMessageBehaviour + } + + impl MyBehaviour { + pub fn new() -> Self { + Self { + inner: NMessageBehaviour::new(b"/foo/bar/1.0.0") + } + } + } + + impl MyBehaviour { + fn alice_do_protocol(&mut self, bob: PeerId, foo: u32, baz: u32) { + self.inner.do_protocol_dialer(bob, move |mut substream| async move { + log::trace!("alice starting protocol"); + + upgrade::write_one(&mut substream, serde_cbor::to_vec(&Message0 { + foo + }).context("failed to serialize Message0")?).await?; + + log::trace!("alice sent message0"); + + let bytes = upgrade::read_one(&mut substream, 1024).await?; + let message1 = serde_cbor::from_slice::(&bytes)?; + + log::trace!("alice read message1"); + + upgrade::write_one(&mut substream, serde_cbor::to_vec(&Message2 { + baz + }).context("failed to serialize Message2")?).await?; + + log::trace!("alice sent message2"); + + log::trace!("alice finished"); + + Ok(AliceResult { + bar: message1.bar + }) + }) + } + + fn bob_do_protocol(&mut self, alice: PeerId, bar: u32) { + self.inner.do_protocol_listener(alice, move |mut substream| async move { + log::trace!("bob start protocol"); + + let bytes = upgrade::read_one(&mut substream, 1024).await?; + let message0 = serde_cbor::from_slice::(&bytes)?; + + log::trace!("bob read message0"); + + upgrade::write_one(&mut substream, serde_cbor::to_vec(&Message1 { + bar + }).context("failed to serialize Message1")?).await?; + + log::trace!("bob sent message1"); + + let bytes = upgrade::read_one(&mut substream, 1024).await?; + let message2 = serde_cbor::from_slice::(&bytes)?; + + log::trace!("bob read message2"); + + log::trace!("bob finished"); + + Ok(BobResult { + foo: message0.foo, + baz: message2.baz + }) + }) + } + } + + #[tokio::test] + async fn it_works() { + let _ = env_logger::try_init(); + + let (mut alice, mut bob) = new_connected_swarm_pair(|_, _| MyBehaviour::new()).await; + + log::info!("alice = {}", alice.peer_id); + log::info!("bob = {}", bob.peer_id); + + alice.swarm.alice_do_protocol(bob.peer_id, 10, 42); + bob.swarm.bob_do_protocol(alice.peer_id, 1337); + + let alice_handle = tokio::spawn(async move { alice.swarm.next_event().await }); + let bob_handle = tokio::spawn(async move { bob.swarm.next_event().await }); + + let (alice_event, bob_event) = join(alice_handle, bob_handle).await; + + assert!(matches!(dbg!(alice_event.unwrap()), SwarmEvent::Behaviour(MyOutEvent::Alice(AliceResult { + bar: 1337 + })))); + assert!(matches!(dbg!(bob_event.unwrap()), SwarmEvent::Behaviour(MyOutEvent::Bob(BobResult { + foo: 10, + baz: 42 + })))); + } +} diff --git a/src/swarm_harness.rs b/src/swarm_harness.rs new file mode 100644 index 00000000..ca2edd01 --- /dev/null +++ b/src/swarm_harness.rs @@ -0,0 +1,150 @@ +use libp2p::futures::future; +use libp2p::{ + core::{ + muxing::StreamMuxerBox, transport::memory::MemoryTransport, upgrade::Version, Executor, + }, + identity, + noise::{self, NoiseConfig, X25519Spec}, + swarm::{NetworkBehaviour, SwarmBuilder, SwarmEvent}, + yamux::YamuxConfig, + Multiaddr, PeerId, Swarm, Transport, +}; +use std::{fmt::Debug, future::Future, pin::Pin, time::Duration}; +use tokio::time; + +/// An adaptor struct for libp2p that spawns futures into the current +/// thread-local runtime. +struct GlobalSpawnTokioExecutor; + +impl Executor for GlobalSpawnTokioExecutor { + fn exec(&self, future: Pin + Send>>) { + let _ = tokio::spawn(future); + } +} + +#[allow(missing_debug_implementations)] +pub struct Actor { + pub swarm: Swarm, + pub addr: Multiaddr, + pub peer_id: PeerId, +} + +pub async fn new_connected_swarm_pair(behaviour_fn: F) -> (Actor, Actor) +where + B: NetworkBehaviour, + F: Fn(PeerId, identity::Keypair) -> B + Clone, +::OutEvent: Debug{ + let (swarm, addr, peer_id) = new_swarm(behaviour_fn.clone()); + let mut alice = Actor { + swarm, + addr, + peer_id, + }; + + let (swarm, addr, peer_id) = new_swarm(behaviour_fn); + let mut bob = Actor { + swarm, + addr, + peer_id, + }; + + connect(&mut alice.swarm, &mut bob.swarm).await; + + (alice, bob) +} + +pub fn new_swarm B>(behaviour_fn: F) -> (Swarm, Multiaddr, PeerId) { + let id_keys = identity::Keypair::generate_ed25519(); + let peer_id = PeerId::from(id_keys.public()); + + let dh_keys = noise::Keypair::::new() + .into_authentic(&id_keys) + .expect("failed to create dh_keys"); + let noise = NoiseConfig::xx(dh_keys).into_authenticated(); + + let transport = MemoryTransport::default() + .upgrade(Version::V1) + .authenticate(noise) + .multiplex(YamuxConfig::default()) + .map(|(peer, muxer), _| (peer, StreamMuxerBox::new(muxer))) + .boxed(); + + let mut swarm: Swarm = SwarmBuilder::new( + transport, + behaviour_fn(peer_id.clone(), id_keys), + peer_id.clone(), + ) + .executor(Box::new(GlobalSpawnTokioExecutor)) + .build(); + + let address_port = rand::random::(); + let addr = format!("/memory/{}", address_port) + .parse::() + .unwrap(); + + Swarm::listen_on(&mut swarm, addr.clone()).unwrap(); + + (swarm, addr, peer_id) +} + +pub async fn await_events_or_timeout( + alice_event: impl Future, + bob_event: impl Future, +) -> (A, B) { + time::timeout( + Duration::from_secs(10), + future::join(alice_event, bob_event), + ) + .await + .expect("network behaviours to emit an event within 10 seconds") +} + +/// Connects two swarms with each other. +/// +/// This assumes the transport that is in use can be used by Alice to connect to +/// the listen address that is emitted by Bob. In other words, they have to be +/// on the same network. The memory transport used by the above `new_swarm` +/// function fulfills this. +/// +/// We also assume that the swarms don't emit any behaviour events during the +/// connection phase. Any event emitted is considered a bug from this functions +/// PoV because they would be lost. +pub async fn connect(alice: &mut Swarm, bob: &mut Swarm) + where + B: NetworkBehaviour, +::OutEvent: Debug{ + let mut alice_connected = false; + let mut bob_connected = false; + + while !alice_connected && !bob_connected { + let (alice_event, bob_event) = future::join(alice.next_event(), bob.next_event()).await; + + match alice_event { + SwarmEvent::ConnectionEstablished { .. } => { + alice_connected = true; + } + SwarmEvent::Behaviour(event) => { + panic!( + "alice unexpectedly emitted a behaviour event during connection: {:?}", + event + ); + } + _ => {} + } + match bob_event { + SwarmEvent::ConnectionEstablished { .. } => { + bob_connected = true; + } + SwarmEvent::NewListenAddr(addr) => { + Swarm::dial_addr(alice, addr).unwrap(); + } + SwarmEvent::Behaviour(event) => { + panic!( + "bob unexpectedly emitted a behaviour event during connection: {:?}", + event + ); + } + _ => {} + } + } +}