From 3be223da28bc82b8feb08b3a62dfc3d961c13e1c Mon Sep 17 00:00:00 2001 From: rishflab Date: Tue, 28 Sep 2021 10:15:31 +1000 Subject: [PATCH] Create Database trait Use domain types in database API to prevent leaking of database types. This trait will allow us to smoothly introduce the sqlite database. --- swap/src/asb/event_loop.rs | 17 +- swap/src/asb/recovery/cancel.rs | 10 +- swap/src/asb/recovery/punish.rs | 10 +- swap/src/asb/recovery/redeem.rs | 16 +- swap/src/asb/recovery/refund.rs | 12 +- swap/src/asb/recovery/safely_abort.rs | 12 +- swap/src/bin/asb.rs | 25 +- swap/src/bin/swap.rs | 36 +-- swap/src/cli/cancel.rs | 10 +- swap/src/cli/refund.rs | 11 +- swap/src/database.rs | 269 +++++++----------- swap/src/database/alice.rs | 42 +-- swap/src/database/bob.rs | 7 +- swap/src/protocol.rs | 86 ++++++ swap/src/protocol/alice.rs | 4 +- swap/src/protocol/alice/state.rs | 2 +- swap/src/protocol/alice/swap.rs | 8 +- swap/src/protocol/bob.rs | 13 +- swap/src/protocol/bob/state.rs | 2 +- swap/src/protocol/bob/swap.rs | 6 +- ...ncurrent_bobs_after_xmr_lock_proof_sent.rs | 3 +- swap/tests/harness/mod.rs | 10 +- 22 files changed, 324 insertions(+), 287 deletions(-) diff --git a/swap/src/asb/event_loop.rs b/swap/src/asb/event_loop.rs index 1cc268bb..916a89d5 100644 --- a/swap/src/asb/event_loop.rs +++ b/swap/src/asb/event_loop.rs @@ -1,5 +1,5 @@ use crate::asb::{Behaviour, OutEvent, Rate}; -use crate::database::Database; +use crate::protocol::{Database}; use crate::network::quote::BidQuote; use crate::network::swap_setup::alice::WalletSnapshot; use crate::network::transfer_proof; @@ -14,7 +14,7 @@ use libp2p::swarm::SwarmEvent; use libp2p::{PeerId, Swarm}; use rust_decimal::Decimal; use std::collections::HashMap; -use std::convert::Infallible; +use std::convert::{Infallible, TryInto}; use std::fmt::Debug; use std::sync::Arc; use tokio::sync::mpsc; @@ -39,7 +39,7 @@ where env_config: env::Config, bitcoin_wallet: Arc, monero_wallet: Arc, - db: Arc, + db: Arc, latest_rate: LR, min_buy: bitcoin::Amount, max_buy: bitcoin::Amount, @@ -71,7 +71,7 @@ where env_config: env::Config, bitcoin_wallet: Arc, monero_wallet: Arc, - db: Arc, + db: Arc, latest_rate: LR, min_buy: bitcoin::Amount, max_buy: bitcoin::Amount, @@ -108,7 +108,8 @@ where self.inflight_encrypted_signatures .push(future::pending().boxed()); - let unfinished_swaps = match self.db.unfinished_alice() { + + let unfinished_swaps = match self.db.unfinished(|state | !state.swap_finished()).await { Ok(unfinished_swaps) => unfinished_swaps, Err(_) => { tracing::error!("Failed to load unfinished swaps"); @@ -117,7 +118,7 @@ where }; for (swap_id, state) in unfinished_swaps { - let peer_id = match self.db.get_peer_id(swap_id) { + let peer_id = match self.db.get_peer_id(swap_id).await { Ok(peer_id) => peer_id, Err(_) => { tracing::warn!(%swap_id, "Resuming swap skipped because no peer-id found for swap in database"); @@ -133,7 +134,7 @@ where monero_wallet: self.monero_wallet.clone(), env_config: self.env_config, db: self.db.clone(), - state: state.into(), + state: state.try_into().expect("Alice state loaded from db"), swap_id, }; @@ -197,7 +198,7 @@ where } SwarmEvent::Behaviour(OutEvent::EncryptedSignatureReceived{ msg, channel, peer }) => { let swap_id = msg.swap_id; - let swap_peer = self.db.get_peer_id(swap_id); + let swap_peer = self.db.get_peer_id(swap_id).await; // Ensure that an incoming encrypted signature is sent by the peer-id associated with the swap let swap_peer = match swap_peer { diff --git a/swap/src/asb/recovery/cancel.rs b/swap/src/asb/recovery/cancel.rs index 32af014f..f2463f54 100644 --- a/swap/src/asb/recovery/cancel.rs +++ b/swap/src/asb/recovery/cancel.rs @@ -1,16 +1,17 @@ use crate::bitcoin::{parse_rpc_error_code, RpcErrorCode, Txid, Wallet}; -use crate::database::{Database, Swap}; +use crate::protocol::Database; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; use std::sync::Arc; use uuid::Uuid; +use std::convert::TryInto; pub async fn cancel( swap_id: Uuid, bitcoin_wallet: Arc, - db: Arc, + db: Arc, ) -> Result<(Txid, AliceState)> { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; let (monero_wallet_restore_blockheight, transfer_proof, state3) = match state { @@ -58,8 +59,7 @@ pub async fn cancel( transfer_proof, state3, }; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok((txid, state)) diff --git a/swap/src/asb/recovery/punish.rs b/swap/src/asb/recovery/punish.rs index ddd27c8f..d1f817fb 100644 --- a/swap/src/asb/recovery/punish.rs +++ b/swap/src/asb/recovery/punish.rs @@ -1,9 +1,10 @@ use crate::bitcoin::{self, Txid}; -use crate::database::{Database, Swap}; +use crate::protocol::Database; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; use std::sync::Arc; use uuid::Uuid; +use std::convert::TryInto; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -14,9 +15,9 @@ pub enum Error { pub async fn punish( swap_id: Uuid, bitcoin_wallet: Arc, - db: Arc, + db: Arc, ) -> Result<(Txid, AliceState)> { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; let state3 = match state { // Punish potentially possible (no knowledge of cancel transaction) @@ -46,8 +47,7 @@ pub async fn punish( let txid = state3.punish_btc(&bitcoin_wallet).await?; let state = AliceState::BtcPunished; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok((txid, state)) diff --git a/swap/src/asb/recovery/redeem.rs b/swap/src/asb/recovery/redeem.rs index dd8daa54..f1ee5453 100644 --- a/swap/src/asb/recovery/redeem.rs +++ b/swap/src/asb/recovery/redeem.rs @@ -1,9 +1,10 @@ use crate::bitcoin::{Txid, Wallet}; -use crate::database::{Database, Swap}; +use crate::protocol::Database; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; use std::sync::Arc; use uuid::Uuid; +use std::convert::TryInto; pub enum Finality { Await, @@ -23,10 +24,10 @@ impl Finality { pub async fn redeem( swap_id: Uuid, bitcoin_wallet: Arc, - db: Arc, + db: Arc, finality: Finality, ) -> Result<(Txid, AliceState)> { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; match state { AliceState::EncSigLearned { @@ -42,8 +43,7 @@ pub async fn redeem( subscription.wait_until_seen().await?; let state = AliceState::BtcRedeemTransactionPublished { state3 }; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.into()) .await?; if let Finality::Await = finality { @@ -51,8 +51,7 @@ pub async fn redeem( } let state = AliceState::BtcRedeemed; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok((txid, state)) @@ -64,8 +63,7 @@ pub async fn redeem( } let state = AliceState::BtcRedeemed; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; let txid = state3.tx_redeem().txid(); diff --git a/swap/src/asb/recovery/refund.rs b/swap/src/asb/recovery/refund.rs index 1e91b49a..d940934a 100644 --- a/swap/src/asb/recovery/refund.rs +++ b/swap/src/asb/recovery/refund.rs @@ -1,11 +1,12 @@ use crate::bitcoin::{self}; -use crate::database::{Database, Swap}; +use crate::protocol::Database; use crate::monero; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; use libp2p::PeerId; use std::sync::Arc; use uuid::Uuid; +use std::convert::TryInto; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -26,9 +27,9 @@ pub async fn refund( swap_id: Uuid, bitcoin_wallet: Arc, monero_wallet: Arc, - db: Arc, + db: Arc, ) -> Result { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; let (monero_wallet_restore_blockheight, transfer_proof, state3) = match state { // In case no XMR has been locked, move to Safely Aborted @@ -66,7 +67,7 @@ pub async fn refund( tracing::debug!(%swap_id, "Bitcoin refund transaction found, extracting key to refund Monero"); state3.extract_monero_private_key(published_refund_tx)? } else { - let bob_peer_id = db.get_peer_id(swap_id)?; + let bob_peer_id = db.get_peer_id(swap_id).await?; bail!(Error::RefundTransactionNotPublishedYet(bob_peer_id),); }; @@ -81,8 +82,7 @@ pub async fn refund( .await?; let state = AliceState::XmrRefunded; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok(state) diff --git a/swap/src/asb/recovery/safely_abort.rs b/swap/src/asb/recovery/safely_abort.rs index 8105f068..227c6a0b 100644 --- a/swap/src/asb/recovery/safely_abort.rs +++ b/swap/src/asb/recovery/safely_abort.rs @@ -1,11 +1,12 @@ -use crate::database::{Database, Swap}; +use crate::protocol::Database; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; -use std::sync::Arc; use uuid::Uuid; +use std::convert::TryInto; +use std::sync::Arc; -pub async fn safely_abort(swap_id: Uuid, db: Arc) -> Result { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); +pub async fn safely_abort(swap_id: Uuid, db: Arc) -> Result { + let state = db.get_state(swap_id).await?.try_into()?; match state { AliceState::Started { .. } @@ -13,8 +14,7 @@ pub async fn safely_abort(swap_id: Uuid, db: Arc) -> Result { let state = AliceState::SafelyAborted; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok(state) diff --git a/swap/src/bin/asb.rs b/swap/src/bin/asb.rs index ae3ead70..0e13784f 100644 --- a/swap/src/bin/asb.rs +++ b/swap/src/bin/asb.rs @@ -28,15 +28,17 @@ use swap::asb::config::{ initial_setup, query_user_for_initial_config, read_config, Config, ConfigNotInitialized, }; use swap::asb::{cancel, punish, redeem, refund, safely_abort, EventLoop, Finality, KrakenRate}; -use swap::database::Database; +use swap::database::SledDatabase; use swap::monero::Amount; use swap::network::rendezvous::XmrBtcNamespace; use swap::network::swarm; -use swap::protocol::alice::run; +use swap::protocol::alice::{run, AliceState}; use swap::seed::Seed; use swap::tor::AuthenticatedClient; use swap::{asb, bitcoin, kraken, monero, tor}; use tracing_subscriber::filter::LevelFilter; +use swap::protocol::Database; +use std::convert::TryInto; const DEFAULT_WALLET_NAME: &str = "asb-wallet"; @@ -92,8 +94,8 @@ async fn main() -> Result<()> { let db_path = config.data.dir.join("database"); - let db = Database::open(config.data.dir.join(db_path).as_path()) - .context("Could not open database")?; + let db = Arc::new(SledDatabase::open(config.data.dir.join(db_path).as_path()).await + .context("Could not open database")?); let seed = Seed::from_file_or_generate(&config.data.dir).expect("Could not retrieve/initialize seed"); @@ -177,7 +179,7 @@ async fn main() -> Result<()> { env_config, Arc::new(bitcoin_wallet), Arc::new(monero_wallet), - Arc::new(db), + db, kraken_rate.clone(), config.maker.min_buy_btc, config.maker.max_buy_btc, @@ -208,7 +210,8 @@ async fn main() -> Result<()> { table.set_header(vec!["SWAP ID", "STATE"]); - for (swap_id, state) in db.all_alice()? { + for (swap_id, state) in db.all().await? { + let state: AliceState = state.try_into()?; table.add_row(vec![swap_id.to_string(), state.to_string()]); } @@ -248,7 +251,7 @@ async fn main() -> Result<()> { Command::Cancel { swap_id } => { let bitcoin_wallet = init_bitcoin_wallet(&config, &seed, env_config).await?; - let (txid, _) = cancel(swap_id, Arc::new(bitcoin_wallet), Arc::new(db)).await?; + let (txid, _) = cancel(swap_id, Arc::new(bitcoin_wallet), db).await?; tracing::info!("Cancel transaction successfully published with id {}", txid); } @@ -260,7 +263,7 @@ async fn main() -> Result<()> { swap_id, Arc::new(bitcoin_wallet), Arc::new(monero_wallet), - Arc::new(db), + db, ) .await?; @@ -269,12 +272,12 @@ async fn main() -> Result<()> { Command::Punish { swap_id } => { let bitcoin_wallet = init_bitcoin_wallet(&config, &seed, env_config).await?; - let (txid, _) = punish(swap_id, Arc::new(bitcoin_wallet), Arc::new(db)).await?; + let (txid, _) = punish(swap_id, Arc::new(bitcoin_wallet), db).await?; tracing::info!("Punish transaction successfully published with id {}", txid); } Command::SafelyAbort { swap_id } => { - safely_abort(swap_id, Arc::new(db)).await?; + safely_abort(swap_id, db).await?; tracing::info!("Swap safely aborted"); } @@ -287,7 +290,7 @@ async fn main() -> Result<()> { let (txid, _) = redeem( swap_id, Arc::new(bitcoin_wallet), - Arc::new(db), + db, Finality::from_bool(do_not_await_finality), ) .await?; diff --git a/swap/src/bin/swap.rs b/swap/src/bin/swap.rs index cd7c1e6a..afbf4f17 100644 --- a/swap/src/bin/swap.rs +++ b/swap/src/bin/swap.rs @@ -25,17 +25,18 @@ use std::time::Duration; use swap::bitcoin::TxLock; use swap::cli::command::{parse_args_and_apply_defaults, Arguments, Command, ParseResult}; use swap::cli::{list_sellers, EventLoop, SellerStatus}; -use swap::database::Database; +use swap::database::SledDatabase; use swap::env::Config; use swap::libp2p_ext::MultiAddrExt; use swap::network::quote::BidQuote; use swap::network::swarm; -use swap::protocol::bob; -use swap::protocol::bob::Swap; +use swap::protocol::{bob, Database}; +use swap::protocol::bob::{Swap, BobState}; use swap::seed::Seed; use swap::{bitcoin, cli, monero}; use url::Url; use uuid::Uuid; +use std::convert::TryInto; #[tokio::main] async fn main() -> Result<()> { @@ -66,8 +67,8 @@ async fn main() -> Result<()> { let swap_id = Uuid::new_v4(); cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = Database::open(data_dir.join("database").as_path()) - .context("Failed to open database")?; + let db = Arc::new(SledDatabase::open(data_dir.join("database").as_path()).await + .context("Failed to open database")?); let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; @@ -139,14 +140,15 @@ async fn main() -> Result<()> { } } Command::History => { - let db = Database::open(data_dir.join("database").as_path()) + let db = SledDatabase::open(data_dir.join("database").as_path()).await .context("Failed to open database")?; let mut table = Table::new(); table.set_header(vec!["SWAP ID", "STATE"]); - for (swap_id, state) in db.all_bob()? { + for (swap_id, state) in db.all().await? { + let state: BobState = state.try_into()?; table.add_row(vec![swap_id.to_string(), state.to_string()]); } @@ -215,8 +217,8 @@ async fn main() -> Result<()> { tor_socks5_port, } => { cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = Database::open(data_dir.join("database").as_path()) - .context("Failed to open database")?; + let db = Arc::new(SledDatabase::open(data_dir.join("database").as_path()).await + .context("Failed to open database")?); let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; @@ -232,8 +234,8 @@ async fn main() -> Result<()> { init_monero_wallet(data_dir, monero_daemon_address, env_config).await?; let bitcoin_wallet = Arc::new(bitcoin_wallet); - let seller_peer_id = db.get_peer_id(swap_id)?; - let seller_addresses = db.get_addresses(seller_peer_id)?; + let seller_peer_id = db.get_peer_id(swap_id).await?; + let seller_addresses = db.get_addresses(seller_peer_id).await?; let behaviour = cli::Behaviour::new(seller_peer_id, env_config, bitcoin_wallet.clone()); let mut swarm = @@ -251,7 +253,7 @@ async fn main() -> Result<()> { EventLoop::new(swap_id, swarm, seller_peer_id, env_config)?; let handle = tokio::spawn(event_loop.run()); - let monero_receive_address = db.get_monero_address(swap_id)?; + let monero_receive_address = db.get_monero_address(swap_id).await?; let swap = Swap::from_db( db, swap_id, @@ -260,7 +262,7 @@ async fn main() -> Result<()> { env_config, event_loop_handle, monero_receive_address, - )?; + ).await?; tokio::select! { event_loop_result = handle => { @@ -277,8 +279,8 @@ async fn main() -> Result<()> { bitcoin_target_block, } => { cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = Database::open(data_dir.join("database").as_path()) - .context("Failed to open database")?; + let db = Arc::new(SledDatabase::open(data_dir.join("database").as_path()).await + .context("Failed to open database")?); let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; @@ -300,8 +302,8 @@ async fn main() -> Result<()> { bitcoin_target_block, } => { cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = Database::open(data_dir.join("database").as_path()) - .context("Failed to open database")?; + let db = Arc::new(SledDatabase::open(data_dir.join("database").as_path()).await + .context("Failed to open database")?); let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; diff --git a/swap/src/cli/cancel.rs b/swap/src/cli/cancel.rs index 5828b078..d3d794bb 100644 --- a/swap/src/cli/cancel.rs +++ b/swap/src/cli/cancel.rs @@ -1,16 +1,17 @@ use crate::bitcoin::{parse_rpc_error_code, RpcErrorCode, Txid, Wallet}; -use crate::database::{Database, Swap}; +use crate::protocol::Database; use crate::protocol::bob::BobState; use anyhow::{bail, Result}; use std::sync::Arc; use uuid::Uuid; +use std::convert::TryInto; pub async fn cancel( swap_id: Uuid, bitcoin_wallet: Arc, - db: Database, + db: Arc, ) -> Result<(Txid, BobState)> { - let state = db.get_state(swap_id)?.try_into_bob()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; let state6 = match state { BobState::BtcLocked(state3) => state3.cancel(), @@ -48,8 +49,7 @@ pub async fn cancel( }; let state = BobState::BtcCancelled(state6); - let db_state = state.clone().into(); - db.insert_latest_state(swap_id, Swap::Bob(db_state)).await?; + db.insert_latest_state(swap_id, state.clone().into()).await?; Ok((txid, state)) } diff --git a/swap/src/cli/refund.rs b/swap/src/cli/refund.rs index f94b9430..439d53c1 100644 --- a/swap/src/cli/refund.rs +++ b/swap/src/cli/refund.rs @@ -1,12 +1,13 @@ use crate::bitcoin::Wallet; -use crate::database::{Database, Swap}; +use crate::protocol::Database; use crate::protocol::bob::BobState; use anyhow::{bail, Result}; use std::sync::Arc; use uuid::Uuid; +use std::convert::TryInto; -pub async fn refund(swap_id: Uuid, bitcoin_wallet: Arc, db: Database) -> Result { - let state = db.get_state(swap_id)?.try_into_bob()?.into(); +pub async fn refund(swap_id: Uuid, bitcoin_wallet: Arc, db: Arc) -> Result { + let state = db.get_state(swap_id).await?.try_into()?; let state6 = match state { BobState::BtcLocked(state3) => state3.cancel(), @@ -31,9 +32,7 @@ pub async fn refund(swap_id: Uuid, bitcoin_wallet: Arc, db: Database) -> state6.publish_refund_btc(bitcoin_wallet.as_ref()).await?; let state = BobState::BtcRefunded(state6); - let db_state = state.clone().into(); - - db.insert_latest_state(swap_id, Swap::Bob(db_state)).await?; + db.insert_latest_state(swap_id, state.clone().into()).await?; Ok(state) } diff --git a/swap/src/database.rs b/swap/src/database.rs index 339a1e35..a22cd7d3 100644 --- a/swap/src/database.rs +++ b/swap/src/database.rs @@ -1,15 +1,21 @@ pub use alice::Alice; pub use bob::Bob; -use anyhow::{anyhow, bail, Context, Result}; +use async_trait::async_trait; +use anyhow::{anyhow, Context, Result}; use itertools::Itertools; use libp2p::{Multiaddr, PeerId}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use std::fmt::Display; +use std::fmt::{Display, Debug}; use std::path::Path; use std::str::FromStr; use uuid::Uuid; +use crate::protocol::alice::AliceState; +use crate::protocol::bob::BobState; +use crate::protocol::{Database, State}; +use std::collections::HashMap; + mod alice; mod bob; @@ -20,6 +26,15 @@ pub enum Swap { Bob(Bob), } +impl From for Swap { + fn from(state: State) -> Self { + match state { + State::Alice(state) => Swap::Alice(state.into()), + State::Bob(state) => Swap::Bob(state.into()), + } + } +} + impl From for Swap { fn from(from: Alice) -> Self { Swap::Alice(from) @@ -41,58 +56,38 @@ impl Display for Swap { } } -#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)] -#[error("Not in the role of Alice")] -struct NotAlice; - -#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)] -#[error("Not in the role of Bob")] -struct NotBob; - -impl Swap { - pub fn try_into_alice(self) -> Result { - match self { - Swap::Alice(alice) => Ok(alice), - Swap::Bob(_) => bail!(NotAlice), - } - } - - pub fn try_into_bob(self) -> Result { - match self { - Swap::Bob(bob) => Ok(bob), - Swap::Alice(_) => bail!(NotBob), +impl From for State { + fn from(value: Swap) -> Self { + match value { + Swap::Alice(alice) => State::Alice(alice.into()), + Swap::Bob(bob) => State::Bob(bob.into()), } } } -pub struct Database { +impl From for Swap { + fn from(state: BobState) -> Self { + Self::Bob(Bob::from(state)) + } +} + +impl From for Swap { + fn from(state: AliceState) -> Self { + Self::Alice(Alice::from(state)) + } +} + +#[derive(Clone)] +pub struct SledDatabase { swaps: sled::Tree, peers: sled::Tree, addresses: sled::Tree, monero_addresses: sled::Tree, } -impl Database { - pub fn open(path: &Path) -> Result { - tracing::debug!("Opening database at {}", path.display()); - - let db = - sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; - - let swaps = db.open_tree("swaps")?; - let peers = db.open_tree("peers")?; - let addresses = db.open_tree("addresses")?; - let monero_addresses = db.open_tree("monero_addresses")?; - - Ok(Database { - swaps, - peers, - addresses, - monero_addresses, - }) - } - - pub async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> { +#[async_trait] +impl Database for SledDatabase { + async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> { let peer_id_str = peer_id.to_string(); let key = serialize(&swap_id)?; @@ -107,7 +102,7 @@ impl Database { .context("Could not flush db") } - pub fn get_peer_id(&self, swap_id: Uuid) -> Result { + async fn get_peer_id(&self, swap_id: Uuid) -> Result { let key = serialize(&swap_id)?; let encoded = self @@ -119,7 +114,7 @@ impl Database { Ok(PeerId::from_str(peer_id.as_str())?) } - pub async fn insert_monero_address( + async fn insert_monero_address( &self, swap_id: Uuid, address: monero::Address, @@ -136,7 +131,7 @@ impl Database { .context("Could not flush db") } - pub fn get_monero_address(&self, swap_id: Uuid) -> Result { + async fn get_monero_address(&self, swap_id: Uuid) -> Result { let encoded = self .monero_addresses .get(swap_id.as_bytes())? @@ -152,7 +147,7 @@ impl Database { Ok(monero_address) } - pub async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> { + async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> { let key = peer_id.to_bytes(); let existing_addresses = self.addresses.get(&key)?; @@ -181,7 +176,7 @@ impl Database { .context("Could not flush db") } - pub fn get_addresses(&self, peer_id: PeerId) -> Result> { + async fn get_addresses(&self, peer_id: PeerId) -> Result> { let key = peer_id.to_bytes(); let addresses = match self.addresses.get(&key)? { @@ -192,9 +187,10 @@ impl Database { Ok(addresses) } - pub async fn insert_latest_state(&self, swap_id: Uuid, state: Swap) -> Result<()> { + async fn insert_latest_state(&self, swap_id: Uuid, state: State) -> Result<()> { let key = serialize(&swap_id)?; - let new_value = serialize(&state).context("Could not serialize new state value")?; + let swap = Swap::from(state); + let new_value = serialize(&swap).context("Could not serialize new state value")?; let old_value = self.swaps.get(&key)?; @@ -210,7 +206,7 @@ impl Database { .context("Could not flush db") } - pub fn get_state(&self, swap_id: Uuid) -> Result { + async fn get_state(&self, swap_id: Uuid) -> Result { let key = serialize(&swap_id)?; let encoded = self @@ -218,47 +214,55 @@ impl Database { .get(&key)? .ok_or_else(|| anyhow!("Swap with id {} not found in database", swap_id))?; - let state = deserialize(&encoded).context("Could not deserialize state")?; + let swap = deserialize::(&encoded).context("Could not deserialize state")?; + + let state = State::from(swap); + Ok(state) } - pub fn all_alice(&self) -> Result> { - self.all_alice_iter().collect() + async fn all(&self) -> Result> { + self.all_iter().collect() } - fn all_alice_iter(&self) -> impl Iterator> { - self.all_swaps_iter().map(|item| { - let (swap_id, swap) = item?; - Ok((swap_id, swap.try_into_alice()?)) + async fn unfinished(&self, unfinished: fn(State) -> bool) -> Result> { + self.all_iter().into_iter() + .filter_ok(|(_swap_id, state)| unfinished(state.clone())) + .collect() + } +} + +impl SledDatabase { + pub async fn open(path: &Path) -> Result { + tracing::debug!("Opening database at {}", path.display()); + + let db = + sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; + + let swaps = db.open_tree("swaps")?; + let peers = db.open_tree("peers")?; + let addresses = db.open_tree("addresses")?; + let monero_addresses = db.open_tree("monero_addresses")?; + + Ok(SledDatabase { + swaps, + peers, + addresses, + monero_addresses, }) } - pub fn all_bob(&self) -> Result> { - self.all_bob_iter().collect() - } - - fn all_bob_iter(&self) -> impl Iterator> { - self.all_swaps_iter().map(|item| { - let (swap_id, swap) = item?; - Ok((swap_id, swap.try_into_bob()?)) - }) - } - - fn all_swaps_iter(&self) -> impl Iterator> { + fn all_iter(&self) -> impl Iterator> { self.swaps.iter().map(|item| { let (key, value) = item.context("Failed to retrieve swap from DB")?; let swap_id = deserialize::(&key)?; let swap = deserialize::(&value).context("Failed to deserialize swap")?; - Ok((swap_id, swap)) - }) - } + let state = State::from(swap); - pub fn unfinished_alice(&self) -> Result> { - self.all_alice_iter() - .filter_ok(|(_swap_id, alice)| !matches!(alice, Alice::Done(_))) - .collect() + Ok((swap_id, state)) + }) } } @@ -279,21 +283,20 @@ where #[cfg(test)] mod tests { use super::*; - use crate::database::alice::{Alice, AliceEndState}; - use crate::database::bob::{Bob, BobEndState}; + use crate::protocol::alice::AliceState; #[tokio::test] async fn can_write_and_read_to_multiple_keys() { let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); - let state_1 = Swap::Alice(Alice::Done(AliceEndState::BtcRedeemed)); + let state_1 = State::from(AliceState::BtcRedeemed); let swap_id_1 = Uuid::new_v4(); db.insert_latest_state(swap_id_1, state_1.clone()) .await .expect("Failed to save second state"); - let state_2 = Swap::Bob(Bob::Done(BobEndState::SafelyAborted)); + let state_2 = State::from(AliceState::BtcPunished); let swap_id_2 = Uuid::new_v4(); db.insert_latest_state(swap_id_2, state_2.clone()) .await @@ -301,10 +304,12 @@ mod tests { let recovered_1 = db .get_state(swap_id_1) + .await .expect("Failed to recover first state"); let recovered_2 = db .get_state(swap_id_2) + .await .expect("Failed to recover second state"); assert_eq!(recovered_1, state_1); @@ -314,9 +319,9 @@ mod tests { #[tokio::test] async fn can_write_twice_to_one_key() { let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); - let state = Swap::Alice(Alice::Done(AliceEndState::SafelyAborted)); + let state = State::from(AliceState::SafelyAborted); let swap_id = Uuid::new_v4(); db.insert_latest_state(swap_id, state.clone()) @@ -324,6 +329,7 @@ mod tests { .expect("Failed to save state the first time"); let recovered = db .get_state(swap_id) + .await .expect("Failed to recover state the first time"); // We insert and recover twice to ensure database implementation allows the @@ -333,84 +339,28 @@ mod tests { .expect("Failed to save state the second time"); let recovered = db .get_state(swap_id) + .await .expect("Failed to recover state the second time"); assert_eq!(recovered, state); } - #[tokio::test] - async fn all_swaps_as_alice() { - let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); - - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state.clone()); - let alice_swap_id = Uuid::new_v4(); - db.insert_latest_state(alice_swap_id, alice_swap) - .await - .expect("Failed to save alice state 1"); - - let alice_swaps = db.all_alice().unwrap(); - assert_eq!(alice_swaps.len(), 1); - assert!(alice_swaps.contains(&(alice_swap_id, alice_state))); - - let bob_state = Bob::Done(BobEndState::SafelyAborted); - let bob_swap = Swap::Bob(bob_state); - let bob_swap_id = Uuid::new_v4(); - db.insert_latest_state(bob_swap_id, bob_swap) - .await - .expect("Failed to save bob state 1"); - - let err = db.all_alice().unwrap_err(); - - assert_eq!(err.downcast_ref::().unwrap(), &NotAlice); - } - - #[tokio::test] - async fn all_swaps_as_bob() { - let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); - - let bob_state = Bob::Done(BobEndState::SafelyAborted); - let bob_swap = Swap::Bob(bob_state.clone()); - let bob_swap_id = Uuid::new_v4(); - db.insert_latest_state(bob_swap_id, bob_swap) - .await - .expect("Failed to save bob state 1"); - - let bob_swaps = db.all_bob().unwrap(); - assert_eq!(bob_swaps.len(), 1); - assert!(bob_swaps.contains(&(bob_swap_id, bob_state))); - - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state); - let alice_swap_id = Uuid::new_v4(); - db.insert_latest_state(alice_swap_id, alice_swap) - .await - .expect("Failed to save alice state 1"); - - let err = db.all_bob().unwrap_err(); - - assert_eq!(err.downcast_ref::().unwrap(), &NotBob); - } - #[tokio::test] async fn can_save_swap_state_and_peer_id_with_same_swap_id() -> Result<()> { let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); let alice_id = Uuid::new_v4(); - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state); + let alice_state = State::from(AliceState::BtcPunished); let peer_id = PeerId::random(); - db.insert_latest_state(alice_id, alice_swap.clone()).await?; + db.insert_latest_state(alice_id, alice_state.clone()).await?; db.insert_peer_id(alice_id, peer_id).await?; - let loaded_swap = db.get_state(alice_id)?; - let loaded_peer_id = db.get_peer_id(alice_id)?; + let loaded_swap = db.get_state(alice_id).await?; + let loaded_peer_id = db.get_peer_id(alice_id).await?; - assert_eq!(alice_swap, loaded_swap); + assert_eq!(alice_state, loaded_swap); assert_eq!(peer_id, loaded_peer_id); Ok(()) @@ -420,23 +370,22 @@ mod tests { async fn test_reopen_db() -> Result<()> { let db_dir = tempfile::tempdir().unwrap(); let alice_id = Uuid::new_v4(); - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state); + let alice_state = State::from(AliceState::BtcPunished); let peer_id = PeerId::random(); { - let db = Database::open(db_dir.path()).unwrap(); - db.insert_latest_state(alice_id, alice_swap.clone()).await?; + let db = SledDatabase::open(db_dir.path()).await.unwrap(); + db.insert_latest_state(alice_id,alice_state.clone()).await?; db.insert_peer_id(alice_id, peer_id).await?; } - let db = Database::open(db_dir.path()).unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); - let loaded_swap = db.get_state(alice_id)?; - let loaded_peer_id = db.get_peer_id(alice_id)?; + let loaded_swap = db.get_state(alice_id).await?; + let loaded_peer_id = db.get_peer_id(alice_id).await?; - assert_eq!(alice_swap, loaded_swap); + assert_eq!(alice_state, loaded_swap); assert_eq!(peer_id, loaded_peer_id); Ok(()) @@ -450,12 +399,12 @@ mod tests { let home2 = "/ip4/127.0.0.1/tcp/2".parse::()?; { - let db = Database::open(db_dir.path())?; + let db = SledDatabase::open(db_dir.path()).await?; db.insert_address(peer_id, home1.clone()).await?; db.insert_address(peer_id, home2.clone()).await?; } - let addresses = Database::open(db_dir.path())?.get_addresses(peer_id)?; + let addresses = SledDatabase::open(db_dir.path()).await?.get_addresses(peer_id).await?; assert_eq!(addresses, vec![home1, home2]); @@ -467,8 +416,8 @@ mod tests { let db_dir = tempfile::tempdir()?; let swap_id = Uuid::new_v4(); - Database::open(db_dir.path())?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?; - let loaded_monero_address = Database::open(db_dir.path())?.get_monero_address(swap_id)?; + SledDatabase::open(db_dir.path()).await?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?; + let loaded_monero_address = SledDatabase::open(db_dir.path()).await?.get_monero_address(swap_id).await?; assert_eq!(loaded_monero_address.to_string(), "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a"); diff --git a/swap/src/database/alice.rs b/swap/src/database/alice.rs index 3358238f..26a0b7a8 100644 --- a/swap/src/database/alice.rs +++ b/swap/src/database/alice.rs @@ -1,7 +1,7 @@ use crate::bitcoin::EncryptedSignature; use crate::monero; use crate::monero::{monero_private_key, TransferProof}; -use crate::protocol::alice; +use crate::protocol::{alice}; use crate::protocol::alice::AliceState; use monero_rpc::wallet::BlockHeight; use serde::{Deserialize, Serialize}; @@ -78,8 +78,8 @@ pub enum AliceEndState { BtcPunished, } -impl From<&AliceState> for Alice { - fn from(alice_state: &AliceState) -> Self { +impl From for Alice { + fn from(alice_state: AliceState) -> Self { match alice_state { AliceState::Started { state3 } => Alice::Started { state3: state3.as_ref().clone(), @@ -95,8 +95,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::XmrLockTransactionSent { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::XmrLocked { @@ -104,8 +104,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::XmrLocked { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::XmrLockTransferProofSent { @@ -113,8 +113,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::XmrLockTransferProofSent { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::EncSigLearned { @@ -123,10 +123,10 @@ impl From<&AliceState> for Alice { state3, encrypted_signature, } => Alice::EncSigLearned { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), - encrypted_signature: *encrypted_signature.clone(), + encrypted_signature: encrypted_signature.as_ref().clone(), }, AliceState::BtcRedeemTransactionPublished { state3 } => { Alice::BtcRedeemTransactionPublished { @@ -139,8 +139,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::BtcCancelled { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::BtcRefunded { @@ -149,9 +149,9 @@ impl From<&AliceState> for Alice { spend_key, state3, } => Alice::BtcRefunded { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), - spend_key: *spend_key, + monero_wallet_restore_blockheight, + transfer_proof, + spend_key, state3: state3.as_ref().clone(), }, AliceState::BtcPunishable { @@ -159,8 +159,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::BtcPunishable { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::XmrRefunded => Alice::Done(AliceEndState::XmrRefunded), @@ -169,8 +169,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::CancelTimelockExpired { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::BtcPunished => Alice::Done(AliceEndState::BtcPunished), diff --git a/swap/src/database/bob.rs b/swap/src/database/bob.rs index 82c6b848..608b76c8 100644 --- a/swap/src/database/bob.rs +++ b/swap/src/database/bob.rs @@ -38,11 +38,12 @@ pub enum Bob { Done(BobEndState), } +#[allow(clippy::large_enum_variant)] #[derive(Clone, strum::Display, Debug, Deserialize, Serialize, PartialEq)] pub enum BobEndState { SafelyAborted, XmrRedeemed { tx_lock_id: bitcoin::Txid }, - BtcRefunded(Box), + BtcRefunded(bob::State6), BtcPunished { tx_lock_id: bitcoin::Txid }, } @@ -72,7 +73,7 @@ impl From for Bob { BobState::BtcRedeemed(state5) => Bob::BtcRedeemed(state5), BobState::CancelTimelockExpired(state6) => Bob::CancelTimelockExpired(state6), BobState::BtcCancelled(state6) => Bob::BtcCancelled(state6), - BobState::BtcRefunded(state6) => Bob::Done(BobEndState::BtcRefunded(Box::new(state6))), + BobState::BtcRefunded(state6) => Bob::Done(BobEndState::BtcRefunded(state6)), BobState::XmrRedeemed { tx_lock_id } => { Bob::Done(BobEndState::XmrRedeemed { tx_lock_id }) } @@ -113,7 +114,7 @@ impl From for BobState { Bob::Done(end_state) => match end_state { BobEndState::SafelyAborted => BobState::SafelyAborted, BobEndState::XmrRedeemed { tx_lock_id } => BobState::XmrRedeemed { tx_lock_id }, - BobEndState::BtcRefunded(state6) => BobState::BtcRefunded(*state6), + BobEndState::BtcRefunded(state6) => BobState::BtcRefunded(state6), BobEndState::BtcPunished { tx_lock_id } => BobState::BtcPunished { tx_lock_id }, }, } diff --git a/swap/src/protocol.rs b/swap/src/protocol.rs index 59b100c0..129b4959 100644 --- a/swap/src/protocol.rs +++ b/swap/src/protocol.rs @@ -1,3 +1,5 @@ +use anyhow::Result; +use async_trait::async_trait; use crate::{bitcoin, monero}; use conquer_once::Lazy; use ecdsa_fun::fun::marker::Mark; @@ -6,6 +8,13 @@ use sha2::Sha256; use sigma_fun::ext::dl_secp256k1_ed25519_eq::{CrossCurveDLEQ, CrossCurveDLEQProof}; use sigma_fun::HashTranscript; use uuid::Uuid; +use crate::protocol::alice::AliceState; +use crate::protocol::bob::BobState; +use libp2p::{PeerId, Multiaddr}; +use std::convert::TryInto; +use crate::protocol::bob::swap::is_complete as bob_is_complete; +use crate::protocol::alice::swap::is_complete as alice_is_complete; +use std::collections::HashMap; pub mod alice; pub mod bob; @@ -65,3 +74,80 @@ pub struct Message4 { tx_punish_sig: bitcoin::Signature, tx_cancel_sig: bitcoin::Signature, } + +#[allow(clippy::large_enum_variant)] +#[derive(Clone, Debug, PartialEq)] +pub enum State { + Alice(AliceState), + Bob(BobState), +} + +impl State { + pub fn swap_finished(&self) -> bool { + match self { + State::Alice(state) => { + alice_is_complete(state) + }, + State::Bob(state) => { + bob_is_complete(state) + }, + } + + } +} + +impl From for State { + fn from(alice: AliceState) -> Self { + Self::Alice(alice) + } +} + +impl From for State { + fn from(bob: BobState) -> Self { + Self::Bob(bob) + } +} + +#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)] +#[error("Not in the role of Alice")] +pub struct NotAlice; + +#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)] +#[error("Not in the role of Bob")] +pub struct NotBob; + +impl TryInto for State { + type Error = NotBob; + + fn try_into(self) -> std::result::Result { + match self { + State::Alice(_) => Err(NotBob), + State::Bob(state) => Ok(state), + } + } +} + +impl TryInto for State { + type Error = NotAlice; + + fn try_into(self) -> std::result::Result { + match self { + State::Alice(state) => Ok(state), + State::Bob(_) => Err(NotAlice), + } + } +} + +#[async_trait] +pub trait Database { + async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()>; + async fn get_peer_id(&self, swap_id: Uuid) -> Result; + async fn insert_monero_address(&self, swap_id: Uuid, address: monero::Address) -> Result<()>; + async fn get_monero_address(&self, swap_id: Uuid) -> Result; + async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()>; + async fn get_addresses(&self, peer_id: PeerId) -> Result>; + async fn insert_latest_state(&self, swap_id: Uuid, state: State) -> Result<()>; + async fn get_state(&self, swap_id: Uuid) -> Result; + async fn all(&self) -> Result>; + async fn unfinished(&self, unfinished: fn(State) -> bool) -> Result>; +} diff --git a/swap/src/protocol/alice.rs b/swap/src/protocol/alice.rs index f8e80ca7..f742fde8 100644 --- a/swap/src/protocol/alice.rs +++ b/swap/src/protocol/alice.rs @@ -1,6 +1,5 @@ //! Run an XMR/BTC swap in the role of Alice. //! Alice holds XMR and wishes receive BTC. -use crate::database::Database; use crate::env::Config; use crate::{asb, bitcoin, monero}; use std::sync::Arc; @@ -8,6 +7,7 @@ use uuid::Uuid; pub use self::state::*; pub use self::swap::{run, run_until}; +use crate::protocol::Database; pub mod state; pub mod swap; @@ -19,5 +19,5 @@ pub struct Swap { pub monero_wallet: Arc, pub env_config: Config, pub swap_id: Uuid, - pub db: Arc, + pub db: Arc, } diff --git a/swap/src/protocol/alice/state.rs b/swap/src/protocol/alice/state.rs index 18c2a18d..9dcf3f46 100644 --- a/swap/src/protocol/alice/state.rs +++ b/swap/src/protocol/alice/state.rs @@ -16,7 +16,7 @@ use sigma_fun::ext::dl_secp256k1_ed25519_eq::CrossCurveDLEQProof; use std::fmt; use uuid::Uuid; -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub enum AliceState { Started { state3: Box, diff --git a/swap/src/protocol/alice/swap.rs b/swap/src/protocol/alice/swap.rs index 76324667..737d19f7 100644 --- a/swap/src/protocol/alice/swap.rs +++ b/swap/src/protocol/alice/swap.rs @@ -4,7 +4,7 @@ use crate::asb::{EventLoopHandle, LatestRate}; use crate::bitcoin::ExpiredTimelocks; use crate::env::Config; use crate::protocol::alice::{AliceState, Swap}; -use crate::{bitcoin, database, monero}; +use crate::{bitcoin, monero}; use anyhow::{bail, Context, Result}; use tokio::select; use tokio::time::timeout; @@ -14,6 +14,7 @@ pub async fn run(swap: Swap, rate_service: LR) -> Result where LR: LatestRate + Clone, { + run_until(swap, |_| false, rate_service).await } @@ -40,9 +41,8 @@ where ) .await?; - let db_state = (¤t_state).into(); swap.db - .insert_latest_state(swap.swap_id, database::Swap::Alice(db_state)) + .insert_latest_state(swap.swap_id, current_state.clone().into()) .await?; } @@ -398,7 +398,7 @@ where }) } -fn is_complete(state: &AliceState) -> bool { +pub(crate) fn is_complete(state: &AliceState) -> bool { matches!( state, AliceState::XmrRefunded diff --git a/swap/src/protocol/bob.rs b/swap/src/protocol/bob.rs index dc8f7aec..0ef3e241 100644 --- a/swap/src/protocol/bob.rs +++ b/swap/src/protocol/bob.rs @@ -3,11 +3,12 @@ use std::sync::Arc; use anyhow::Result; use uuid::Uuid; -use crate::database::Database; +use crate::protocol::Database; use crate::{bitcoin, cli, env, monero}; pub use self::state::*; pub use self::swap::{run, run_until}; +use std::convert::TryInto; pub mod state; pub mod swap; @@ -15,7 +16,7 @@ pub mod swap; pub struct Swap { pub state: BobState, pub event_loop_handle: cli::EventLoopHandle, - pub db: Database, + pub db: Arc, pub bitcoin_wallet: Arc, pub monero_wallet: Arc, pub env_config: env::Config, @@ -26,7 +27,7 @@ pub struct Swap { impl Swap { #[allow(clippy::too_many_arguments)] pub fn new( - db: Database, + db: Arc, id: Uuid, bitcoin_wallet: Arc, monero_wallet: Arc, @@ -52,8 +53,8 @@ impl Swap { } #[allow(clippy::too_many_arguments)] - pub fn from_db( - db: Database, + pub async fn from_db( + db: Arc, id: Uuid, bitcoin_wallet: Arc, monero_wallet: Arc, @@ -61,7 +62,7 @@ impl Swap { event_loop_handle: cli::EventLoopHandle, monero_receive_address: monero::Address, ) -> Result { - let state = db.get_state(id)?.try_into_bob()?.into(); + let state = db.get_state(id).await?.try_into()?; Ok(Self { state, diff --git a/swap/src/protocol/bob/state.rs b/swap/src/protocol/bob/state.rs index 8e0bac9a..effd05ee 100644 --- a/swap/src/protocol/bob/state.rs +++ b/swap/src/protocol/bob/state.rs @@ -21,7 +21,7 @@ use sigma_fun::ext::dl_secp256k1_ed25519_eq::CrossCurveDLEQProof; use std::fmt; use uuid::Uuid; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum BobState { Started { btc_amount: bitcoin::Amount, diff --git a/swap/src/protocol/bob/swap.rs b/swap/src/protocol/bob/swap.rs index e1e9f849..c0bccc69 100644 --- a/swap/src/protocol/bob/swap.rs +++ b/swap/src/protocol/bob/swap.rs @@ -1,6 +1,5 @@ use crate::bitcoin::{ExpiredTimelocks, TxCancel, TxRefund}; use crate::cli::EventLoopHandle; -use crate::database::Swap; use crate::network::swap_setup::bob::NewSwap; use crate::protocol::bob; use crate::protocol::bob::state::*; @@ -33,7 +32,7 @@ pub async fn run_until( while !is_target_state(¤t_state) { current_state = next_state( swap.id, - current_state, + current_state.clone(), &mut swap.event_loop_handle, swap.bitcoin_wallet.as_ref(), swap.monero_wallet.as_ref(), @@ -41,9 +40,8 @@ pub async fn run_until( ) .await?; - let db_state = current_state.clone().into(); swap.db - .insert_latest_state(swap.id, Swap::Bob(db_state)) + .insert_latest_state(swap.id, current_state.clone().into()) .await?; } diff --git a/swap/tests/concurrent_bobs_after_xmr_lock_proof_sent.rs b/swap/tests/concurrent_bobs_after_xmr_lock_proof_sent.rs index 45e27819..949d404e 100644 --- a/swap/tests/concurrent_bobs_after_xmr_lock_proof_sent.rs +++ b/swap/tests/concurrent_bobs_after_xmr_lock_proof_sent.rs @@ -2,8 +2,7 @@ pub mod harness; use harness::bob_run_until::is_xmr_locked; use harness::SlowCancelConfig; -use swap::asb::FixedRate; -use swap::protocol::alice::AliceState; +use swap::{asb::FixedRate, protocol::alice::AliceState}; use swap::protocol::bob::BobState; use swap::protocol::{alice, bob}; diff --git a/swap/tests/harness/mod.rs b/swap/tests/harness/mod.rs index ad43bd9e..768e03c0 100644 --- a/swap/tests/harness/mod.rs +++ b/swap/tests/harness/mod.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use std::time::Duration; use swap::asb::FixedRate; use swap::bitcoin::{CancelTimelock, PunishTimelock, TxCancel, TxPunish, TxRedeem, TxRefund}; -use swap::database::Database; +use swap::database::SledDatabase; use swap::env::{Config, GetConfig}; use swap::network::swarm; use swap::protocol::alice::{AliceState, Swap}; @@ -222,7 +222,7 @@ async fn start_alice( bitcoin_wallet: Arc, monero_wallet: Arc, ) -> (AliceApplicationHandle, Receiver) { - let db = Arc::new(Database::open(db_path.as_path()).unwrap()); + let db = Arc::new(SledDatabase::open(db_path.as_path()).await.unwrap()); let min_buy = bitcoin::Amount::from_sat(u64::MIN); let max_buy = bitcoin::Amount::from_sat(u64::MAX); @@ -402,7 +402,7 @@ struct BobParams { impl BobParams { pub async fn new_swap_from_db(&self, swap_id: Uuid) -> Result<(bob::Swap, cli::EventLoop)> { let (event_loop, handle) = self.new_eventloop(swap_id).await?; - let db = Database::open(&self.db_path)?; + let db = Arc::new(SledDatabase::open(&self.db_path).await?); let swap = bob::Swap::from_db( db, @@ -412,7 +412,7 @@ impl BobParams { self.env_config, handle, self.monero_wallet.get_main_address(), - )?; + ).await?; Ok((swap, event_loop)) } @@ -424,7 +424,7 @@ impl BobParams { let swap_id = Uuid::new_v4(); let (event_loop, handle) = self.new_eventloop(swap_id).await?; - let db = Database::open(&self.db_path)?; + let db = Arc::new(SledDatabase::open(&self.db_path).await?); let swap = bob::Swap::new( db,