From a94c320021b05225a6cc429102ed0ea7c8a0b273 Mon Sep 17 00:00:00 2001 From: rishflab Date: Wed, 22 Sep 2021 16:14:49 +1000 Subject: [PATCH] Reorganise modules for multiple database implementations --- swap/src/asb/event_loop.rs | 6 +- swap/src/asb/recovery/cancel.rs | 4 +- swap/src/asb/recovery/punish.rs | 4 +- swap/src/asb/recovery/redeem.rs | 4 +- swap/src/asb/recovery/refund.rs | 4 +- swap/src/asb/recovery/safely_abort.rs | 4 +- swap/src/bin/asb.rs | 4 +- swap/src/bin/swap.rs | 12 +- swap/src/cli/cancel.rs | 4 +- swap/src/cli/refund.rs | 8 +- swap/src/database.rs | 421 +------------------------ swap/src/database/sled.rs | 422 ++++++++++++++++++++++++++ swap/src/protocol/alice.rs | 4 +- swap/src/protocol/bob.rs | 8 +- swap/tests/harness/mod.rs | 8 +- 15 files changed, 464 insertions(+), 453 deletions(-) create mode 100644 swap/src/database/sled.rs diff --git a/swap/src/asb/event_loop.rs b/swap/src/asb/event_loop.rs index 1cc268bb..13b85de8 100644 --- a/swap/src/asb/event_loop.rs +++ b/swap/src/asb/event_loop.rs @@ -1,5 +1,5 @@ use crate::asb::{Behaviour, OutEvent, Rate}; -use crate::database::Database; +use crate::database::SledDatabase; use crate::network::quote::BidQuote; use crate::network::swap_setup::alice::WalletSnapshot; use crate::network::transfer_proof; @@ -39,7 +39,7 @@ where env_config: env::Config, bitcoin_wallet: Arc, monero_wallet: Arc, - db: Arc, + db: Arc, latest_rate: LR, min_buy: bitcoin::Amount, max_buy: bitcoin::Amount, @@ -71,7 +71,7 @@ where env_config: env::Config, bitcoin_wallet: Arc, monero_wallet: Arc, - db: Arc, + db: Arc, latest_rate: LR, min_buy: bitcoin::Amount, max_buy: bitcoin::Amount, diff --git a/swap/src/asb/recovery/cancel.rs b/swap/src/asb/recovery/cancel.rs index 32af014f..e571a0f1 100644 --- a/swap/src/asb/recovery/cancel.rs +++ b/swap/src/asb/recovery/cancel.rs @@ -1,5 +1,5 @@ use crate::bitcoin::{parse_rpc_error_code, RpcErrorCode, Txid, Wallet}; -use crate::database::{Database, Swap}; +use crate::database::{SledDatabase, Swap}; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; use std::sync::Arc; @@ -8,7 +8,7 @@ use uuid::Uuid; pub async fn cancel( swap_id: Uuid, bitcoin_wallet: Arc, - db: Arc, + db: Arc, ) -> Result<(Txid, AliceState)> { let state = db.get_state(swap_id)?.try_into_alice()?.into(); diff --git a/swap/src/asb/recovery/punish.rs b/swap/src/asb/recovery/punish.rs index ddd27c8f..e797717e 100644 --- a/swap/src/asb/recovery/punish.rs +++ b/swap/src/asb/recovery/punish.rs @@ -1,5 +1,5 @@ use crate::bitcoin::{self, Txid}; -use crate::database::{Database, Swap}; +use crate::database::{SledDatabase, Swap}; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; use std::sync::Arc; @@ -14,7 +14,7 @@ pub enum Error { pub async fn punish( swap_id: Uuid, bitcoin_wallet: Arc, - db: Arc, + db: Arc, ) -> Result<(Txid, AliceState)> { let state = db.get_state(swap_id)?.try_into_alice()?.into(); diff --git a/swap/src/asb/recovery/redeem.rs b/swap/src/asb/recovery/redeem.rs index dd8daa54..c4729ae0 100644 --- a/swap/src/asb/recovery/redeem.rs +++ b/swap/src/asb/recovery/redeem.rs @@ -1,5 +1,5 @@ use crate::bitcoin::{Txid, Wallet}; -use crate::database::{Database, Swap}; +use crate::database::{SledDatabase, Swap}; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; use std::sync::Arc; @@ -23,7 +23,7 @@ impl Finality { pub async fn redeem( swap_id: Uuid, bitcoin_wallet: Arc, - db: Arc, + db: Arc, finality: Finality, ) -> Result<(Txid, AliceState)> { let state = db.get_state(swap_id)?.try_into_alice()?.into(); diff --git a/swap/src/asb/recovery/refund.rs b/swap/src/asb/recovery/refund.rs index 1e91b49a..89dc30f6 100644 --- a/swap/src/asb/recovery/refund.rs +++ b/swap/src/asb/recovery/refund.rs @@ -1,5 +1,5 @@ use crate::bitcoin::{self}; -use crate::database::{Database, Swap}; +use crate::database::{SledDatabase, Swap}; use crate::monero; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; @@ -26,7 +26,7 @@ pub async fn refund( swap_id: Uuid, bitcoin_wallet: Arc, monero_wallet: Arc, - db: Arc, + db: Arc, ) -> Result { let state = db.get_state(swap_id)?.try_into_alice()?.into(); diff --git a/swap/src/asb/recovery/safely_abort.rs b/swap/src/asb/recovery/safely_abort.rs index 8105f068..f8336c8c 100644 --- a/swap/src/asb/recovery/safely_abort.rs +++ b/swap/src/asb/recovery/safely_abort.rs @@ -1,10 +1,10 @@ -use crate::database::{Database, Swap}; +use crate::database::{SledDatabase, Swap}; use crate::protocol::alice::AliceState; use anyhow::{bail, Result}; use std::sync::Arc; use uuid::Uuid; -pub async fn safely_abort(swap_id: Uuid, db: Arc) -> Result { +pub async fn safely_abort(swap_id: Uuid, db: Arc) -> Result { let state = db.get_state(swap_id)?.try_into_alice()?.into(); match state { diff --git a/swap/src/bin/asb.rs b/swap/src/bin/asb.rs index ae3ead70..10801c59 100644 --- a/swap/src/bin/asb.rs +++ b/swap/src/bin/asb.rs @@ -28,7 +28,7 @@ use swap::asb::config::{ initial_setup, query_user_for_initial_config, read_config, Config, ConfigNotInitialized, }; use swap::asb::{cancel, punish, redeem, refund, safely_abort, EventLoop, Finality, KrakenRate}; -use swap::database::Database; +use swap::database::SledDatabase; use swap::monero::Amount; use swap::network::rendezvous::XmrBtcNamespace; use swap::network::swarm; @@ -92,7 +92,7 @@ async fn main() -> Result<()> { let db_path = config.data.dir.join("database"); - let db = Database::open(config.data.dir.join(db_path).as_path()) + let db = SledDatabase::open(config.data.dir.join(db_path).as_path()) .context("Could not open database")?; let seed = diff --git a/swap/src/bin/swap.rs b/swap/src/bin/swap.rs index cd7c1e6a..3da42926 100644 --- a/swap/src/bin/swap.rs +++ b/swap/src/bin/swap.rs @@ -25,7 +25,7 @@ use std::time::Duration; use swap::bitcoin::TxLock; use swap::cli::command::{parse_args_and_apply_defaults, Arguments, Command, ParseResult}; use swap::cli::{list_sellers, EventLoop, SellerStatus}; -use swap::database::Database; +use swap::database::SledDatabase; use swap::env::Config; use swap::libp2p_ext::MultiAddrExt; use swap::network::quote::BidQuote; @@ -66,7 +66,7 @@ async fn main() -> Result<()> { let swap_id = Uuid::new_v4(); cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = Database::open(data_dir.join("database").as_path()) + let db = SledDatabase::open(data_dir.join("database").as_path()) .context("Failed to open database")?; let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; @@ -139,7 +139,7 @@ async fn main() -> Result<()> { } } Command::History => { - let db = Database::open(data_dir.join("database").as_path()) + let db = SledDatabase::open(data_dir.join("database").as_path()) .context("Failed to open database")?; let mut table = Table::new(); @@ -215,7 +215,7 @@ async fn main() -> Result<()> { tor_socks5_port, } => { cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = Database::open(data_dir.join("database").as_path()) + let db = SledDatabase::open(data_dir.join("database").as_path()) .context("Failed to open database")?; let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; @@ -277,7 +277,7 @@ async fn main() -> Result<()> { bitcoin_target_block, } => { cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = Database::open(data_dir.join("database").as_path()) + let db = SledDatabase::open(data_dir.join("database").as_path()) .context("Failed to open database")?; let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; @@ -300,7 +300,7 @@ async fn main() -> Result<()> { bitcoin_target_block, } => { cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = Database::open(data_dir.join("database").as_path()) + let db = SledDatabase::open(data_dir.join("database").as_path()) .context("Failed to open database")?; let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; diff --git a/swap/src/cli/cancel.rs b/swap/src/cli/cancel.rs index 5828b078..42381eb0 100644 --- a/swap/src/cli/cancel.rs +++ b/swap/src/cli/cancel.rs @@ -1,5 +1,5 @@ use crate::bitcoin::{parse_rpc_error_code, RpcErrorCode, Txid, Wallet}; -use crate::database::{Database, Swap}; +use crate::database::{SledDatabase, Swap}; use crate::protocol::bob::BobState; use anyhow::{bail, Result}; use std::sync::Arc; @@ -8,7 +8,7 @@ use uuid::Uuid; pub async fn cancel( swap_id: Uuid, bitcoin_wallet: Arc, - db: Database, + db: SledDatabase, ) -> Result<(Txid, BobState)> { let state = db.get_state(swap_id)?.try_into_bob()?.into(); diff --git a/swap/src/cli/refund.rs b/swap/src/cli/refund.rs index f94b9430..7b2244f0 100644 --- a/swap/src/cli/refund.rs +++ b/swap/src/cli/refund.rs @@ -1,11 +1,15 @@ use crate::bitcoin::Wallet; -use crate::database::{Database, Swap}; +use crate::database::{SledDatabase, Swap}; use crate::protocol::bob::BobState; use anyhow::{bail, Result}; use std::sync::Arc; use uuid::Uuid; -pub async fn refund(swap_id: Uuid, bitcoin_wallet: Arc, db: Database) -> Result { +pub async fn refund( + swap_id: Uuid, + bitcoin_wallet: Arc, + db: SledDatabase, +) -> Result { let state = db.get_state(swap_id)?.try_into_bob()?.into(); let state6 = match state { diff --git a/swap/src/database.rs b/swap/src/database.rs index 339a1e35..1b59d4c8 100644 --- a/swap/src/database.rs +++ b/swap/src/database.rs @@ -1,18 +1,14 @@ +pub use self::sled::SledDatabase; pub use alice::Alice; pub use bob::Bob; -use anyhow::{anyhow, bail, Context, Result}; -use itertools::Itertools; -use libp2p::{Multiaddr, PeerId}; -use serde::de::DeserializeOwned; +use anyhow::{bail, Result}; use serde::{Deserialize, Serialize}; use std::fmt::Display; -use std::path::Path; -use std::str::FromStr; -use uuid::Uuid; mod alice; mod bob; +mod sled; #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] pub enum Swap { @@ -64,414 +60,3 @@ impl Swap { } } } - -pub struct Database { - swaps: sled::Tree, - peers: sled::Tree, - addresses: sled::Tree, - monero_addresses: sled::Tree, -} - -impl Database { - pub fn open(path: &Path) -> Result { - tracing::debug!("Opening database at {}", path.display()); - - let db = - sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; - - let swaps = db.open_tree("swaps")?; - let peers = db.open_tree("peers")?; - let addresses = db.open_tree("addresses")?; - let monero_addresses = db.open_tree("monero_addresses")?; - - Ok(Database { - swaps, - peers, - addresses, - monero_addresses, - }) - } - - pub async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> { - let peer_id_str = peer_id.to_string(); - - let key = serialize(&swap_id)?; - let value = serialize(&peer_id_str).context("Could not serialize peer-id")?; - - self.peers.insert(key, value)?; - - self.peers - .flush_async() - .await - .map(|_| ()) - .context("Could not flush db") - } - - pub fn get_peer_id(&self, swap_id: Uuid) -> Result { - let key = serialize(&swap_id)?; - - let encoded = self - .peers - .get(&key)? - .ok_or_else(|| anyhow!("No peer-id found for swap id {} in database", swap_id))?; - - let peer_id: String = deserialize(&encoded).context("Could not deserialize peer-id")?; - Ok(PeerId::from_str(peer_id.as_str())?) - } - - pub async fn insert_monero_address( - &self, - swap_id: Uuid, - address: monero::Address, - ) -> Result<()> { - let key = swap_id.as_bytes(); - let value = serialize(&address)?; - - self.monero_addresses.insert(key, value)?; - - self.monero_addresses - .flush_async() - .await - .map(|_| ()) - .context("Could not flush db") - } - - pub fn get_monero_address(&self, swap_id: Uuid) -> Result { - let encoded = self - .monero_addresses - .get(swap_id.as_bytes())? - .ok_or_else(|| { - anyhow!( - "No Monero address found for swap id {} in database", - swap_id - ) - })?; - - let monero_address = deserialize(&encoded)?; - - Ok(monero_address) - } - - pub async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> { - let key = peer_id.to_bytes(); - - let existing_addresses = self.addresses.get(&key)?; - - let new_addresses = { - let existing_addresses = existing_addresses.clone(); - - Some(match existing_addresses { - Some(encoded) => { - let mut addresses = deserialize::>(&encoded)?; - addresses.push(address); - - serialize(&addresses)? - } - None => serialize(&[address])?, - }) - }; - - self.addresses - .compare_and_swap(key, existing_addresses, new_addresses)??; - - self.addresses - .flush_async() - .await - .map(|_| ()) - .context("Could not flush db") - } - - pub fn get_addresses(&self, peer_id: PeerId) -> Result> { - let key = peer_id.to_bytes(); - - let addresses = match self.addresses.get(&key)? { - Some(encoded) => deserialize(&encoded).context("Failed to deserialize addresses")?, - None => vec![], - }; - - Ok(addresses) - } - - pub async fn insert_latest_state(&self, swap_id: Uuid, state: Swap) -> Result<()> { - let key = serialize(&swap_id)?; - let new_value = serialize(&state).context("Could not serialize new state value")?; - - let old_value = self.swaps.get(&key)?; - - self.swaps - .compare_and_swap(key, old_value, Some(new_value)) - .context("Could not write in the DB")? - .context("Stored swap somehow changed, aborting saving")?; - - self.swaps - .flush_async() - .await - .map(|_| ()) - .context("Could not flush db") - } - - pub fn get_state(&self, swap_id: Uuid) -> Result { - let key = serialize(&swap_id)?; - - let encoded = self - .swaps - .get(&key)? - .ok_or_else(|| anyhow!("Swap with id {} not found in database", swap_id))?; - - let state = deserialize(&encoded).context("Could not deserialize state")?; - Ok(state) - } - - pub fn all_alice(&self) -> Result> { - self.all_alice_iter().collect() - } - - fn all_alice_iter(&self) -> impl Iterator> { - self.all_swaps_iter().map(|item| { - let (swap_id, swap) = item?; - Ok((swap_id, swap.try_into_alice()?)) - }) - } - - pub fn all_bob(&self) -> Result> { - self.all_bob_iter().collect() - } - - fn all_bob_iter(&self) -> impl Iterator> { - self.all_swaps_iter().map(|item| { - let (swap_id, swap) = item?; - Ok((swap_id, swap.try_into_bob()?)) - }) - } - - fn all_swaps_iter(&self) -> impl Iterator> { - self.swaps.iter().map(|item| { - let (key, value) = item.context("Failed to retrieve swap from DB")?; - - let swap_id = deserialize::(&key)?; - let swap = deserialize::(&value).context("Failed to deserialize swap")?; - - Ok((swap_id, swap)) - }) - } - - pub fn unfinished_alice(&self) -> Result> { - self.all_alice_iter() - .filter_ok(|(_swap_id, alice)| !matches!(alice, Alice::Done(_))) - .collect() - } -} - -pub fn serialize(t: &T) -> Result> -where - T: Serialize, -{ - Ok(serde_cbor::to_vec(t)?) -} - -pub fn deserialize(v: &[u8]) -> Result -where - T: DeserializeOwned, -{ - Ok(serde_cbor::from_slice(&v)?) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::database::alice::{Alice, AliceEndState}; - use crate::database::bob::{Bob, BobEndState}; - - #[tokio::test] - async fn can_write_and_read_to_multiple_keys() { - let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); - - let state_1 = Swap::Alice(Alice::Done(AliceEndState::BtcRedeemed)); - let swap_id_1 = Uuid::new_v4(); - db.insert_latest_state(swap_id_1, state_1.clone()) - .await - .expect("Failed to save second state"); - - let state_2 = Swap::Bob(Bob::Done(BobEndState::SafelyAborted)); - let swap_id_2 = Uuid::new_v4(); - db.insert_latest_state(swap_id_2, state_2.clone()) - .await - .expect("Failed to save first state"); - - let recovered_1 = db - .get_state(swap_id_1) - .expect("Failed to recover first state"); - - let recovered_2 = db - .get_state(swap_id_2) - .expect("Failed to recover second state"); - - assert_eq!(recovered_1, state_1); - assert_eq!(recovered_2, state_2); - } - - #[tokio::test] - async fn can_write_twice_to_one_key() { - let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); - - let state = Swap::Alice(Alice::Done(AliceEndState::SafelyAborted)); - - let swap_id = Uuid::new_v4(); - db.insert_latest_state(swap_id, state.clone()) - .await - .expect("Failed to save state the first time"); - let recovered = db - .get_state(swap_id) - .expect("Failed to recover state the first time"); - - // We insert and recover twice to ensure database implementation allows the - // caller to write to an existing key - db.insert_latest_state(swap_id, recovered) - .await - .expect("Failed to save state the second time"); - let recovered = db - .get_state(swap_id) - .expect("Failed to recover state the second time"); - - assert_eq!(recovered, state); - } - - #[tokio::test] - async fn all_swaps_as_alice() { - let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); - - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state.clone()); - let alice_swap_id = Uuid::new_v4(); - db.insert_latest_state(alice_swap_id, alice_swap) - .await - .expect("Failed to save alice state 1"); - - let alice_swaps = db.all_alice().unwrap(); - assert_eq!(alice_swaps.len(), 1); - assert!(alice_swaps.contains(&(alice_swap_id, alice_state))); - - let bob_state = Bob::Done(BobEndState::SafelyAborted); - let bob_swap = Swap::Bob(bob_state); - let bob_swap_id = Uuid::new_v4(); - db.insert_latest_state(bob_swap_id, bob_swap) - .await - .expect("Failed to save bob state 1"); - - let err = db.all_alice().unwrap_err(); - - assert_eq!(err.downcast_ref::().unwrap(), &NotAlice); - } - - #[tokio::test] - async fn all_swaps_as_bob() { - let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); - - let bob_state = Bob::Done(BobEndState::SafelyAborted); - let bob_swap = Swap::Bob(bob_state.clone()); - let bob_swap_id = Uuid::new_v4(); - db.insert_latest_state(bob_swap_id, bob_swap) - .await - .expect("Failed to save bob state 1"); - - let bob_swaps = db.all_bob().unwrap(); - assert_eq!(bob_swaps.len(), 1); - assert!(bob_swaps.contains(&(bob_swap_id, bob_state))); - - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state); - let alice_swap_id = Uuid::new_v4(); - db.insert_latest_state(alice_swap_id, alice_swap) - .await - .expect("Failed to save alice state 1"); - - let err = db.all_bob().unwrap_err(); - - assert_eq!(err.downcast_ref::().unwrap(), &NotBob); - } - - #[tokio::test] - async fn can_save_swap_state_and_peer_id_with_same_swap_id() -> Result<()> { - let db_dir = tempfile::tempdir().unwrap(); - let db = Database::open(db_dir.path()).unwrap(); - - let alice_id = Uuid::new_v4(); - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state); - let peer_id = PeerId::random(); - - db.insert_latest_state(alice_id, alice_swap.clone()).await?; - db.insert_peer_id(alice_id, peer_id).await?; - - let loaded_swap = db.get_state(alice_id)?; - let loaded_peer_id = db.get_peer_id(alice_id)?; - - assert_eq!(alice_swap, loaded_swap); - assert_eq!(peer_id, loaded_peer_id); - - Ok(()) - } - - #[tokio::test] - async fn test_reopen_db() -> Result<()> { - let db_dir = tempfile::tempdir().unwrap(); - let alice_id = Uuid::new_v4(); - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state); - - let peer_id = PeerId::random(); - - { - let db = Database::open(db_dir.path()).unwrap(); - db.insert_latest_state(alice_id, alice_swap.clone()).await?; - db.insert_peer_id(alice_id, peer_id).await?; - } - - let db = Database::open(db_dir.path()).unwrap(); - - let loaded_swap = db.get_state(alice_id)?; - let loaded_peer_id = db.get_peer_id(alice_id)?; - - assert_eq!(alice_swap, loaded_swap); - assert_eq!(peer_id, loaded_peer_id); - - Ok(()) - } - - #[tokio::test] - async fn save_and_load_addresses() -> Result<()> { - let db_dir = tempfile::tempdir()?; - let peer_id = PeerId::random(); - let home1 = "/ip4/127.0.0.1/tcp/1".parse::()?; - let home2 = "/ip4/127.0.0.1/tcp/2".parse::()?; - - { - let db = Database::open(db_dir.path())?; - db.insert_address(peer_id, home1.clone()).await?; - db.insert_address(peer_id, home2.clone()).await?; - } - - let addresses = Database::open(db_dir.path())?.get_addresses(peer_id)?; - - assert_eq!(addresses, vec![home1, home2]); - - Ok(()) - } - - #[tokio::test] - async fn save_and_load_monero_address() -> Result<()> { - let db_dir = tempfile::tempdir()?; - let swap_id = Uuid::new_v4(); - - Database::open(db_dir.path())?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?; - let loaded_monero_address = Database::open(db_dir.path())?.get_monero_address(swap_id)?; - - assert_eq!(loaded_monero_address.to_string(), "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a"); - - Ok(()) - } -} diff --git a/swap/src/database/sled.rs b/swap/src/database/sled.rs new file mode 100644 index 00000000..512b5039 --- /dev/null +++ b/swap/src/database/sled.rs @@ -0,0 +1,422 @@ +use crate::database::{Alice, Bob, Swap}; +use anyhow::{anyhow, Context, Result}; +use itertools::Itertools; +use libp2p::{Multiaddr, PeerId}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use std::path::Path; +use std::str::FromStr; +use uuid::Uuid; + +pub struct SledDatabase { + swaps: sled::Tree, + peers: sled::Tree, + addresses: sled::Tree, + monero_addresses: sled::Tree, +} + +impl SledDatabase { + pub fn open(path: &Path) -> Result { + tracing::debug!("Opening database at {}", path.display()); + + let db = + sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; + + let swaps = db.open_tree("swaps")?; + let peers = db.open_tree("peers")?; + let addresses = db.open_tree("addresses")?; + let monero_addresses = db.open_tree("monero_addresses")?; + + Ok(SledDatabase { + swaps, + peers, + addresses, + monero_addresses, + }) + } + + pub async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> { + let peer_id_str = peer_id.to_string(); + + let key = serialize(&swap_id)?; + let value = serialize(&peer_id_str).context("Could not serialize peer-id")?; + + self.peers.insert(key, value)?; + + self.peers + .flush_async() + .await + .map(|_| ()) + .context("Could not flush db") + } + + pub fn get_peer_id(&self, swap_id: Uuid) -> Result { + let key = serialize(&swap_id)?; + + let encoded = self + .peers + .get(&key)? + .ok_or_else(|| anyhow!("No peer-id found for swap id {} in database", swap_id))?; + + let peer_id: String = deserialize(&encoded).context("Could not deserialize peer-id")?; + Ok(PeerId::from_str(peer_id.as_str())?) + } + + pub async fn insert_monero_address( + &self, + swap_id: Uuid, + address: monero::Address, + ) -> Result<()> { + let key = swap_id.as_bytes(); + let value = serialize(&address)?; + + self.monero_addresses.insert(key, value)?; + + self.monero_addresses + .flush_async() + .await + .map(|_| ()) + .context("Could not flush db") + } + + pub fn get_monero_address(&self, swap_id: Uuid) -> Result { + let encoded = self + .monero_addresses + .get(swap_id.as_bytes())? + .ok_or_else(|| { + anyhow!( + "No Monero address found for swap id {} in database", + swap_id + ) + })?; + + let monero_address = deserialize(&encoded)?; + + Ok(monero_address) + } + + pub async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> { + let key = peer_id.to_bytes(); + + let existing_addresses = self.addresses.get(&key)?; + + let new_addresses = { + let existing_addresses = existing_addresses.clone(); + + Some(match existing_addresses { + Some(encoded) => { + let mut addresses = deserialize::>(&encoded)?; + addresses.push(address); + + serialize(&addresses)? + } + None => serialize(&[address])?, + }) + }; + + self.addresses + .compare_and_swap(key, existing_addresses, new_addresses)??; + + self.addresses + .flush_async() + .await + .map(|_| ()) + .context("Could not flush db") + } + + pub fn get_addresses(&self, peer_id: PeerId) -> Result> { + let key = peer_id.to_bytes(); + + let addresses = match self.addresses.get(&key)? { + Some(encoded) => deserialize(&encoded).context("Failed to deserialize addresses")?, + None => vec![], + }; + + Ok(addresses) + } + + pub async fn insert_latest_state(&self, swap_id: Uuid, state: Swap) -> Result<()> { + let key = serialize(&swap_id)?; + let new_value = serialize(&state).context("Could not serialize new state value")?; + + let old_value = self.swaps.get(&key)?; + + self.swaps + .compare_and_swap(key, old_value, Some(new_value)) + .context("Could not write in the DB")? + .context("Stored swap somehow changed, aborting saving")?; + + self.swaps + .flush_async() + .await + .map(|_| ()) + .context("Could not flush db") + } + + pub fn get_state(&self, swap_id: Uuid) -> Result { + let key = serialize(&swap_id)?; + + let encoded = self + .swaps + .get(&key)? + .ok_or_else(|| anyhow!("Swap with id {} not found in database", swap_id))?; + + let state = deserialize(&encoded).context("Could not deserialize state")?; + Ok(state) + } + + pub fn all_alice(&self) -> Result> { + self.all_alice_iter().collect() + } + + fn all_alice_iter(&self) -> impl Iterator> { + self.all_swaps_iter().map(|item| { + let (swap_id, swap) = item?; + Ok((swap_id, swap.try_into_alice()?)) + }) + } + + pub fn all_bob(&self) -> Result> { + self.all_bob_iter().collect() + } + + fn all_bob_iter(&self) -> impl Iterator> { + self.all_swaps_iter().map(|item| { + let (swap_id, swap) = item?; + Ok((swap_id, swap.try_into_bob()?)) + }) + } + + fn all_swaps_iter(&self) -> impl Iterator> { + self.swaps.iter().map(|item| { + let (key, value) = item.context("Failed to retrieve swap from DB")?; + + let swap_id = deserialize::(&key)?; + let swap = deserialize::(&value).context("Failed to deserialize swap")?; + + Ok((swap_id, swap)) + }) + } + + pub fn unfinished_alice(&self) -> Result> { + self.all_alice_iter() + .filter_ok(|(_swap_id, alice)| !matches!(alice, Alice::Done(_))) + .collect() + } +} + +pub fn serialize(t: &T) -> Result> +where + T: Serialize, +{ + Ok(serde_cbor::to_vec(t)?) +} + +pub fn deserialize(v: &[u8]) -> Result +where + T: DeserializeOwned, +{ + Ok(serde_cbor::from_slice(&v)?) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::database::alice::{Alice, AliceEndState}; + use crate::database::bob::{Bob, BobEndState}; + use crate::database::{NotAlice, NotBob}; + + #[tokio::test] + async fn can_write_and_read_to_multiple_keys() { + let db_dir = tempfile::tempdir().unwrap(); + let db = SledDatabase::open(db_dir.path()).unwrap(); + + let state_1 = Swap::Alice(Alice::Done(AliceEndState::BtcRedeemed)); + let swap_id_1 = Uuid::new_v4(); + db.insert_latest_state(swap_id_1, state_1.clone()) + .await + .expect("Failed to save second state"); + + let state_2 = Swap::Bob(Bob::Done(BobEndState::SafelyAborted)); + let swap_id_2 = Uuid::new_v4(); + db.insert_latest_state(swap_id_2, state_2.clone()) + .await + .expect("Failed to save first state"); + + let recovered_1 = db + .get_state(swap_id_1) + .expect("Failed to recover first state"); + + let recovered_2 = db + .get_state(swap_id_2) + .expect("Failed to recover second state"); + + assert_eq!(recovered_1, state_1); + assert_eq!(recovered_2, state_2); + } + + #[tokio::test] + async fn can_write_twice_to_one_key() { + let db_dir = tempfile::tempdir().unwrap(); + let db = SledDatabase::open(db_dir.path()).unwrap(); + + let state = Swap::Alice(Alice::Done(AliceEndState::SafelyAborted)); + + let swap_id = Uuid::new_v4(); + db.insert_latest_state(swap_id, state.clone()) + .await + .expect("Failed to save state the first time"); + let recovered = db + .get_state(swap_id) + .expect("Failed to recover state the first time"); + + // We insert and recover twice to ensure database implementation allows the + // caller to write to an existing key + db.insert_latest_state(swap_id, recovered) + .await + .expect("Failed to save state the second time"); + let recovered = db + .get_state(swap_id) + .expect("Failed to recover state the second time"); + + assert_eq!(recovered, state); + } + + #[tokio::test] + async fn all_swaps_as_alice() { + let db_dir = tempfile::tempdir().unwrap(); + let db = SledDatabase::open(db_dir.path()).unwrap(); + + let alice_state = Alice::Done(AliceEndState::BtcPunished); + let alice_swap = Swap::Alice(alice_state.clone()); + let alice_swap_id = Uuid::new_v4(); + db.insert_latest_state(alice_swap_id, alice_swap) + .await + .expect("Failed to save alice state 1"); + + let alice_swaps = db.all_alice().unwrap(); + assert_eq!(alice_swaps.len(), 1); + assert!(alice_swaps.contains(&(alice_swap_id, alice_state))); + + let bob_state = Bob::Done(BobEndState::SafelyAborted); + let bob_swap = Swap::Bob(bob_state); + let bob_swap_id = Uuid::new_v4(); + db.insert_latest_state(bob_swap_id, bob_swap) + .await + .expect("Failed to save bob state 1"); + + let err = db.all_alice().unwrap_err(); + + assert_eq!(err.downcast_ref::().unwrap(), &NotAlice); + } + + #[tokio::test] + async fn all_swaps_as_bob() { + let db_dir = tempfile::tempdir().unwrap(); + let db = SledDatabase::open(db_dir.path()).unwrap(); + + let bob_state = Bob::Done(BobEndState::SafelyAborted); + let bob_swap = Swap::Bob(bob_state.clone()); + let bob_swap_id = Uuid::new_v4(); + db.insert_latest_state(bob_swap_id, bob_swap) + .await + .expect("Failed to save bob state 1"); + + let bob_swaps = db.all_bob().unwrap(); + assert_eq!(bob_swaps.len(), 1); + assert!(bob_swaps.contains(&(bob_swap_id, bob_state))); + + let alice_state = Alice::Done(AliceEndState::BtcPunished); + let alice_swap = Swap::Alice(alice_state); + let alice_swap_id = Uuid::new_v4(); + db.insert_latest_state(alice_swap_id, alice_swap) + .await + .expect("Failed to save alice state 1"); + + let err = db.all_bob().unwrap_err(); + + assert_eq!(err.downcast_ref::().unwrap(), &NotBob); + } + + #[tokio::test] + async fn can_save_swap_state_and_peer_id_with_same_swap_id() -> Result<()> { + let db_dir = tempfile::tempdir().unwrap(); + let db = SledDatabase::open(db_dir.path()).unwrap(); + + let alice_id = Uuid::new_v4(); + let alice_state = Alice::Done(AliceEndState::BtcPunished); + let alice_swap = Swap::Alice(alice_state); + let peer_id = PeerId::random(); + + db.insert_latest_state(alice_id, alice_swap.clone()).await?; + db.insert_peer_id(alice_id, peer_id).await?; + + let loaded_swap = db.get_state(alice_id)?; + let loaded_peer_id = db.get_peer_id(alice_id)?; + + assert_eq!(alice_swap, loaded_swap); + assert_eq!(peer_id, loaded_peer_id); + + Ok(()) + } + + #[tokio::test] + async fn test_reopen_db() -> Result<()> { + let db_dir = tempfile::tempdir().unwrap(); + let alice_id = Uuid::new_v4(); + let alice_state = Alice::Done(AliceEndState::BtcPunished); + let alice_swap = Swap::Alice(alice_state); + + let peer_id = PeerId::random(); + + { + let db = SledDatabase::open(db_dir.path()).unwrap(); + db.insert_latest_state(alice_id, alice_swap.clone()).await?; + db.insert_peer_id(alice_id, peer_id).await?; + } + + let db = SledDatabase::open(db_dir.path()).unwrap(); + + let loaded_swap = db.get_state(alice_id)?; + let loaded_peer_id = db.get_peer_id(alice_id)?; + + assert_eq!(alice_swap, loaded_swap); + assert_eq!(peer_id, loaded_peer_id); + + Ok(()) + } + + #[tokio::test] + async fn save_and_load_addresses() -> Result<()> { + let db_dir = tempfile::tempdir()?; + let peer_id = PeerId::random(); + let home1 = "/ip4/127.0.0.1/tcp/1".parse::()?; + let home2 = "/ip4/127.0.0.1/tcp/2".parse::()?; + + { + let db = SledDatabase::open(db_dir.path())?; + db.insert_address(peer_id, home1.clone()).await?; + db.insert_address(peer_id, home2.clone()).await?; + } + + let addresses = SledDatabase::open(db_dir.path())?.get_addresses(peer_id)?; + + assert_eq!(addresses, vec![home1, home2]); + + Ok(()) + } + + #[tokio::test] + async fn save_and_load_monero_address() -> Result<()> { + let db_dir = tempfile::tempdir()?; + let swap_id = Uuid::new_v4(); + + SledDatabase::open(db_dir.path())?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?; + let loaded_monero_address = + SledDatabase::open(db_dir.path())?.get_monero_address(swap_id)?; + + assert_eq!(loaded_monero_address.to_string(), "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a"); + + Ok(()) + } +} diff --git a/swap/src/protocol/alice.rs b/swap/src/protocol/alice.rs index f8e80ca7..b0d610b9 100644 --- a/swap/src/protocol/alice.rs +++ b/swap/src/protocol/alice.rs @@ -1,6 +1,6 @@ //! Run an XMR/BTC swap in the role of Alice. //! Alice holds XMR and wishes receive BTC. -use crate::database::Database; +use crate::database::SledDatabase; use crate::env::Config; use crate::{asb, bitcoin, monero}; use std::sync::Arc; @@ -19,5 +19,5 @@ pub struct Swap { pub monero_wallet: Arc, pub env_config: Config, pub swap_id: Uuid, - pub db: Arc, + pub db: Arc, } diff --git a/swap/src/protocol/bob.rs b/swap/src/protocol/bob.rs index dc8f7aec..bbc35e7f 100644 --- a/swap/src/protocol/bob.rs +++ b/swap/src/protocol/bob.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use anyhow::Result; use uuid::Uuid; -use crate::database::Database; +use crate::database::SledDatabase; use crate::{bitcoin, cli, env, monero}; pub use self::state::*; @@ -15,7 +15,7 @@ pub mod swap; pub struct Swap { pub state: BobState, pub event_loop_handle: cli::EventLoopHandle, - pub db: Database, + pub db: SledDatabase, pub bitcoin_wallet: Arc, pub monero_wallet: Arc, pub env_config: env::Config, @@ -26,7 +26,7 @@ pub struct Swap { impl Swap { #[allow(clippy::too_many_arguments)] pub fn new( - db: Database, + db: SledDatabase, id: Uuid, bitcoin_wallet: Arc, monero_wallet: Arc, @@ -53,7 +53,7 @@ impl Swap { #[allow(clippy::too_many_arguments)] pub fn from_db( - db: Database, + db: SledDatabase, id: Uuid, bitcoin_wallet: Arc, monero_wallet: Arc, diff --git a/swap/tests/harness/mod.rs b/swap/tests/harness/mod.rs index ad43bd9e..6e56e22f 100644 --- a/swap/tests/harness/mod.rs +++ b/swap/tests/harness/mod.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use std::time::Duration; use swap::asb::FixedRate; use swap::bitcoin::{CancelTimelock, PunishTimelock, TxCancel, TxPunish, TxRedeem, TxRefund}; -use swap::database::Database; +use swap::database::SledDatabase; use swap::env::{Config, GetConfig}; use swap::network::swarm; use swap::protocol::alice::{AliceState, Swap}; @@ -222,7 +222,7 @@ async fn start_alice( bitcoin_wallet: Arc, monero_wallet: Arc, ) -> (AliceApplicationHandle, Receiver) { - let db = Arc::new(Database::open(db_path.as_path()).unwrap()); + let db = Arc::new(SledDatabase::open(db_path.as_path()).unwrap()); let min_buy = bitcoin::Amount::from_sat(u64::MIN); let max_buy = bitcoin::Amount::from_sat(u64::MAX); @@ -402,7 +402,7 @@ struct BobParams { impl BobParams { pub async fn new_swap_from_db(&self, swap_id: Uuid) -> Result<(bob::Swap, cli::EventLoop)> { let (event_loop, handle) = self.new_eventloop(swap_id).await?; - let db = Database::open(&self.db_path)?; + let db = SledDatabase::open(&self.db_path)?; let swap = bob::Swap::from_db( db, @@ -424,7 +424,7 @@ impl BobParams { let swap_id = Uuid::new_v4(); let (event_loop, handle) = self.new_eventloop(swap_id).await?; - let db = Database::open(&self.db_path)?; + let db = SledDatabase::open(&self.db_path)?; let swap = bob::Swap::new( db,