diff --git a/xmr-btc/Cargo.toml b/xmr-btc/Cargo.toml index c49df15d..dfb9299b 100644 --- a/xmr-btc/Cargo.toml +++ b/xmr-btc/Cargo.toml @@ -15,13 +15,16 @@ ed25519-dalek = "1.0.0-pre.4" # Cannot be 1 because they depend on curve25519-da miniscript = "1" monero = "0.9" rand = "0.7" +serde = { version = "1", features = ["derive"], optional = true } sha2 = "0.9" thiserror = "1" [dev-dependencies] base64 = "0.12" bitcoin-harness = { git = "https://github.com/coblox/bitcoin-harness-rs", rev = "d402b36d3d6406150e3bfb71492ff4a0a7cb290e" } +futures = "0.3" monero-harness = { path = "../monero-harness" } reqwest = { version = "0.10", default-features = false } +serde_json = "1" testcontainers = "0.10" tokio = { version = "0.2", default-features = false, features = ["blocking", "macros", "rt-core", "time", "rt-threaded"] } diff --git a/xmr-btc/src/happy_path.rs b/xmr-btc/src/happy_path.rs index 28ce7620..3760931b 100644 --- a/xmr-btc/src/happy_path.rs +++ b/xmr-btc/src/happy_path.rs @@ -1,9 +1,19 @@ //! This module shows how a BTC/XMR atomic swap proceeds along the happy path. -use crate::{alice, bitcoin, bob, monero}; +use crate::{alice, bitcoin, bob, monero, Message, ReceiveMessage, SendMessage}; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; use bitcoin_harness::Bitcoind; +use futures::{ + channel::{ + mpsc, + mpsc::{Receiver, Sender}, + }, + SinkExt, StreamExt, +}; use monero_harness::Monero; use rand::rngs::OsRng; +use std::convert::TryInto; use testcontainers::clients::Cli; const TEN_XMR: u64 = 10_000_000_000_000; @@ -15,13 +25,64 @@ pub async fn init_bitcoind(tc_client: &Cli) -> Bitcoind<'_> { bitcoind } +/// Create two mock `Transport`s which mimic a peer to peer connection between +/// two parties, allowing them to send and receive `thor::Message`s. +pub fn make_transports() -> (Transport, Transport) { + let (a_sender, b_receiver) = mpsc::channel(5); + let (b_sender, a_receiver) = mpsc::channel(5); + + let a_transport = Transport { + sender: a_sender, + receiver: a_receiver, + }; + + let b_transport = Transport { + sender: b_sender, + receiver: b_receiver, + }; + + (a_transport, b_transport) +} + +#[derive(Debug)] +pub struct Transport { + sender: Sender, + receiver: Receiver, +} + +#[async_trait] +impl SendMessage for Transport { + async fn send_message(&mut self, msg: Message) -> Result<()> { + self.sender + .send(msg) + .await + .map_err(|_| anyhow!("failed to send message")) + } +} + +#[async_trait] +impl ReceiveMessage for Transport { + async fn receive_message(&mut self) -> Result { + let msg = self + .receiver + .next() + .await + .ok_or_else(|| anyhow!("failed to receive message"))?; + + Ok(msg) + } +} + #[tokio::test] async fn happy_path() { let cli = Cli::default(); let monero = Monero::new(&cli); let bitcoind = init_bitcoind(&cli).await; - // must be bigger than our hardcoded fee of 10_000 + // Mocks send/receive message for Alice and Bob. + let (mut a_trans, mut b_trans) = make_transports(); + + // Must be bigger than our hardcoded fee of 10_000 let btc_amount = bitcoin::Amount::from_sat(10_000_000); let xmr_amount = monero::Amount::from_piconero(1_000_000_000_000); @@ -70,19 +131,32 @@ async fn happy_path() { refund_address.clone(), ); - let a_msg = a_state.next_message(&mut OsRng); - let b_msg = b_state.next_message(&mut OsRng); + let a_msg = Message::Alice0(a_state.next_message(&mut OsRng)); + let b_msg = Message::Bob0(b_state.next_message(&mut OsRng)); + // Calls to send/receive must be ordered otherwise we will block + // waiting for the message. + a_trans.send_message(a_msg).await.unwrap(); + let b_recv_msg = b_trans.receive_message().await.unwrap().try_into().unwrap(); + b_trans.send_message(b_msg).await.unwrap(); + let a_recv_msg = a_trans.receive_message().await.unwrap().try_into().unwrap(); - let a_state = a_state.receive(b_msg).unwrap(); - let b_state = b_state.receive(&b_btc_wallet, a_msg).await.unwrap(); + let a_state = a_state.receive(a_recv_msg).unwrap(); + let b_state = b_state.receive(&b_btc_wallet, b_recv_msg).await.unwrap(); - let b_msg = b_state.next_message(); - let a_state = a_state.receive(b_msg); - let a_msg = a_state.next_message(); - let b_state = b_state.receive(a_msg).unwrap(); + let msg = Message::Bob1(b_state.next_message()); + b_trans.send_message(msg).await.unwrap(); + let a_recv_msg = a_trans.receive_message().await.unwrap().try_into().unwrap(); + let a_state = a_state.receive(a_recv_msg); - let b_msg = b_state.next_message(); - let a_state = a_state.receive(b_msg).unwrap(); + let msg = Message::Alice1(a_state.next_message()); + a_trans.send_message(msg).await.unwrap(); + let b_recv_msg = b_trans.receive_message().await.unwrap().try_into().unwrap(); + let b_state = b_state.receive(b_recv_msg).unwrap(); + + let msg = Message::Bob2(b_state.next_message()); + b_trans.send_message(msg).await.unwrap(); + let a_recv_msg = a_trans.receive_message().await.unwrap().try_into().unwrap(); + let a_state = a_state.receive(a_recv_msg).unwrap(); let b_state = b_state.lock_btc(&b_btc_wallet).await.unwrap(); let lock_txid = b_state.tx_lock_id(); @@ -91,15 +165,18 @@ async fn happy_path() { let (a_state, lock_tx_monero_fee) = a_state.lock_xmr(&a_xmr_wallet).await.unwrap(); - let a_msg = a_state.next_message(); - + let msg = Message::Alice2(a_state.next_message()); + a_trans.send_message(msg).await.unwrap(); + let b_recv_msg = b_trans.receive_message().await.unwrap().try_into().unwrap(); let b_state = b_state - .watch_for_lock_xmr(&b_xmr_wallet, a_msg) + .watch_for_lock_xmr(&b_xmr_wallet, b_recv_msg) .await .unwrap(); - let b_msg = b_state.next_message(); - let a_state = a_state.receive(b_msg); + let msg = Message::Bob3(b_state.next_message()); + b_trans.send_message(msg).await.unwrap(); + let a_recv_msg = a_trans.receive_message().await.unwrap().try_into().unwrap(); + let a_state = a_state.receive(a_recv_msg); a_state.redeem_btc(&a_btc_wallet).await.unwrap(); let b_state = b_state.watch_for_redeem_btc(&b_btc_wallet).await.unwrap(); diff --git a/xmr-btc/src/lib.rs b/xmr-btc/src/lib.rs index 3526f1d4..11166d59 100644 --- a/xmr-btc/src/lib.rs +++ b/xmr-btc/src/lib.rs @@ -14,6 +14,10 @@ #![forbid(unsafe_code)] #![allow(non_snake_case)] +use anyhow::Result; +use async_trait::async_trait; +use std::convert::TryFrom; + pub mod alice; pub mod bitcoin; pub mod bob; @@ -22,6 +26,186 @@ pub mod monero; #[cfg(test)] mod happy_path; +#[async_trait] +pub trait SendMessage { + async fn send_message(&mut self, message: Message) -> Result<()>; +} + +#[async_trait] +pub trait ReceiveMessage { + async fn receive_message(&mut self) -> Result; +} + +/// All possible messages that are sent between two parties. +#[derive(Debug)] +pub enum Message { + Alice0(alice::Message0), + Alice1(alice::Message1), + Alice2(alice::Message2), + Bob0(bob::Message0), + Bob1(bob::Message1), + Bob2(bob::Message2), + Bob3(bob::Message3), +} + +#[derive(Debug, thiserror::Error)] +#[error("expected message of type {expected_type}, got {received:?}")] +pub struct UnexpectedMessage { + expected_type: String, + received: Message, +} + +impl UnexpectedMessage { + pub fn new(received: Message) -> Self { + let expected_type = std::any::type_name::(); + + Self { + expected_type: expected_type.to_string(), + received, + } + } +} + +impl From for Message { + fn from(msg: alice::Message0) -> Self { + Message::Alice0(msg) + } +} + +impl TryFrom for alice::Message0 { + type Error = UnexpectedMessage; + + fn try_from(msg: Message) -> Result { + match msg { + Message::Alice0(msg) => Ok(msg), + _ => Err(UnexpectedMessage { + expected_type: "alice::Message0".to_string(), + received: msg, + }), + } + } +} + +impl From for Message { + fn from(msg: alice::Message1) -> Self { + Message::Alice1(msg) + } +} + +impl TryFrom for alice::Message1 { + type Error = UnexpectedMessage; + + fn try_from(msg: Message) -> Result { + match msg { + Message::Alice1(msg) => Ok(msg), + _ => Err(UnexpectedMessage { + expected_type: "alice::Message1".to_string(), + received: msg, + }), + } + } +} + +impl From for Message { + fn from(msg: alice::Message2) -> Self { + Message::Alice2(msg) + } +} + +impl TryFrom for alice::Message2 { + type Error = UnexpectedMessage; + + fn try_from(msg: Message) -> Result { + match msg { + Message::Alice2(msg) => Ok(msg), + _ => Err(UnexpectedMessage { + expected_type: "alice::Message2".to_string(), + received: msg, + }), + } + } +} + +impl From for Message { + fn from(msg: bob::Message0) -> Self { + Message::Bob0(msg) + } +} + +impl TryFrom for bob::Message0 { + type Error = UnexpectedMessage; + + fn try_from(msg: Message) -> Result { + match msg { + Message::Bob0(msg) => Ok(msg), + _ => Err(UnexpectedMessage { + expected_type: "bob::Message0".to_string(), + received: msg, + }), + } + } +} + +impl From for Message { + fn from(msg: bob::Message1) -> Self { + Message::Bob1(msg) + } +} + +impl TryFrom for bob::Message1 { + type Error = UnexpectedMessage; + + fn try_from(msg: Message) -> Result { + match msg { + Message::Bob1(msg) => Ok(msg), + _ => Err(UnexpectedMessage { + expected_type: "bob::Message1".to_string(), + received: msg, + }), + } + } +} + +impl From for Message { + fn from(msg: bob::Message2) -> Self { + Message::Bob2(msg) + } +} + +impl TryFrom for bob::Message2 { + type Error = UnexpectedMessage; + + fn try_from(msg: Message) -> Result { + match msg { + Message::Bob2(msg) => Ok(msg), + _ => Err(UnexpectedMessage { + expected_type: "bob::Message2".to_string(), + received: msg, + }), + } + } +} + +impl From for Message { + fn from(msg: bob::Message3) -> Self { + Message::Bob3(msg) + } +} + +impl TryFrom for bob::Message3 { + type Error = UnexpectedMessage; + + fn try_from(msg: Message) -> Result { + match msg { + Message::Bob3(msg) => Ok(msg), + _ => Err(UnexpectedMessage { + expected_type: "bob::Message3".to_string(), + received: msg, + }), + } + } +} + #[cfg(test)] mod tests { use crate::{