diff --git a/swap/src/database.rs b/swap/src/database.rs index a22cd7d3..056084f8 100644 --- a/swap/src/database.rs +++ b/swap/src/database.rs @@ -1,25 +1,17 @@ -pub use alice::Alice; -pub use bob::Bob; - -use async_trait::async_trait; -use anyhow::{anyhow, Context, Result}; -use itertools::Itertools; -use libp2p::{Multiaddr, PeerId}; -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; -use std::fmt::{Display, Debug}; -use std::path::Path; -use std::str::FromStr; -use uuid::Uuid; -use crate::protocol::alice::AliceState; +use crate::database::alice::Alice; +use crate::database::bob::Bob; +use crate::protocol::State; +use std::fmt::Display; use crate::protocol::bob::BobState; -use crate::protocol::{Database, State}; -use std::collections::HashMap; - +use crate::protocol::alice::AliceState; +use serde::{Deserialize, Serialize}; +mod sled; mod alice; mod bob; +pub use self::sled::SledDatabase; + #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] pub enum Swap { Alice(Alice), @@ -76,351 +68,3 @@ impl From for Swap { Self::Alice(Alice::from(state)) } } - -#[derive(Clone)] -pub struct SledDatabase { - swaps: sled::Tree, - peers: sled::Tree, - addresses: sled::Tree, - monero_addresses: sled::Tree, -} - -#[async_trait] -impl Database for SledDatabase { - async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> { - let peer_id_str = peer_id.to_string(); - - let key = serialize(&swap_id)?; - let value = serialize(&peer_id_str).context("Could not serialize peer-id")?; - - self.peers.insert(key, value)?; - - self.peers - .flush_async() - .await - .map(|_| ()) - .context("Could not flush db") - } - - async fn get_peer_id(&self, swap_id: Uuid) -> Result { - let key = serialize(&swap_id)?; - - let encoded = self - .peers - .get(&key)? - .ok_or_else(|| anyhow!("No peer-id found for swap id {} in database", swap_id))?; - - let peer_id: String = deserialize(&encoded).context("Could not deserialize peer-id")?; - Ok(PeerId::from_str(peer_id.as_str())?) - } - - async fn insert_monero_address( - &self, - swap_id: Uuid, - address: monero::Address, - ) -> Result<()> { - let key = swap_id.as_bytes(); - let value = serialize(&address)?; - - self.monero_addresses.insert(key, value)?; - - self.monero_addresses - .flush_async() - .await - .map(|_| ()) - .context("Could not flush db") - } - - async fn get_monero_address(&self, swap_id: Uuid) -> Result { - let encoded = self - .monero_addresses - .get(swap_id.as_bytes())? - .ok_or_else(|| { - anyhow!( - "No Monero address found for swap id {} in database", - swap_id - ) - })?; - - let monero_address = deserialize(&encoded)?; - - Ok(monero_address) - } - - async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> { - let key = peer_id.to_bytes(); - - let existing_addresses = self.addresses.get(&key)?; - - let new_addresses = { - let existing_addresses = existing_addresses.clone(); - - Some(match existing_addresses { - Some(encoded) => { - let mut addresses = deserialize::>(&encoded)?; - addresses.push(address); - - serialize(&addresses)? - } - None => serialize(&[address])?, - }) - }; - - self.addresses - .compare_and_swap(key, existing_addresses, new_addresses)??; - - self.addresses - .flush_async() - .await - .map(|_| ()) - .context("Could not flush db") - } - - async fn get_addresses(&self, peer_id: PeerId) -> Result> { - let key = peer_id.to_bytes(); - - let addresses = match self.addresses.get(&key)? { - Some(encoded) => deserialize(&encoded).context("Failed to deserialize addresses")?, - None => vec![], - }; - - Ok(addresses) - } - - async fn insert_latest_state(&self, swap_id: Uuid, state: State) -> Result<()> { - let key = serialize(&swap_id)?; - let swap = Swap::from(state); - let new_value = serialize(&swap).context("Could not serialize new state value")?; - - let old_value = self.swaps.get(&key)?; - - self.swaps - .compare_and_swap(key, old_value, Some(new_value)) - .context("Could not write in the DB")? - .context("Stored swap somehow changed, aborting saving")?; - - self.swaps - .flush_async() - .await - .map(|_| ()) - .context("Could not flush db") - } - - async fn get_state(&self, swap_id: Uuid) -> Result { - let key = serialize(&swap_id)?; - - let encoded = self - .swaps - .get(&key)? - .ok_or_else(|| anyhow!("Swap with id {} not found in database", swap_id))?; - - let swap = deserialize::(&encoded).context("Could not deserialize state")?; - - let state = State::from(swap); - - Ok(state) - } - - async fn all(&self) -> Result> { - self.all_iter().collect() - } - - async fn unfinished(&self, unfinished: fn(State) -> bool) -> Result> { - self.all_iter().into_iter() - .filter_ok(|(_swap_id, state)| unfinished(state.clone())) - .collect() - } -} - -impl SledDatabase { - pub async fn open(path: &Path) -> Result { - tracing::debug!("Opening database at {}", path.display()); - - let db = - sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; - - let swaps = db.open_tree("swaps")?; - let peers = db.open_tree("peers")?; - let addresses = db.open_tree("addresses")?; - let monero_addresses = db.open_tree("monero_addresses")?; - - Ok(SledDatabase { - swaps, - peers, - addresses, - monero_addresses, - }) - } - - fn all_iter(&self) -> impl Iterator> { - self.swaps.iter().map(|item| { - let (key, value) = item.context("Failed to retrieve swap from DB")?; - - let swap_id = deserialize::(&key)?; - let swap = deserialize::(&value).context("Failed to deserialize swap")?; - - let state = State::from(swap); - - Ok((swap_id, state)) - }) - } -} - -pub fn serialize(t: &T) -> Result> -where - T: Serialize, -{ - Ok(serde_cbor::to_vec(t)?) -} - -pub fn deserialize(v: &[u8]) -> Result -where - T: DeserializeOwned, -{ - Ok(serde_cbor::from_slice(&v)?) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::protocol::alice::AliceState; - - #[tokio::test] - async fn can_write_and_read_to_multiple_keys() { - let db_dir = tempfile::tempdir().unwrap(); - let db = SledDatabase::open(db_dir.path()).await.unwrap(); - - let state_1 = State::from(AliceState::BtcRedeemed); - let swap_id_1 = Uuid::new_v4(); - db.insert_latest_state(swap_id_1, state_1.clone()) - .await - .expect("Failed to save second state"); - - let state_2 = State::from(AliceState::BtcPunished); - let swap_id_2 = Uuid::new_v4(); - db.insert_latest_state(swap_id_2, state_2.clone()) - .await - .expect("Failed to save first state"); - - let recovered_1 = db - .get_state(swap_id_1) - .await - .expect("Failed to recover first state"); - - let recovered_2 = db - .get_state(swap_id_2) - .await - .expect("Failed to recover second state"); - - assert_eq!(recovered_1, state_1); - assert_eq!(recovered_2, state_2); - } - - #[tokio::test] - async fn can_write_twice_to_one_key() { - let db_dir = tempfile::tempdir().unwrap(); - let db = SledDatabase::open(db_dir.path()).await.unwrap(); - - let state = State::from(AliceState::SafelyAborted); - - let swap_id = Uuid::new_v4(); - db.insert_latest_state(swap_id, state.clone()) - .await - .expect("Failed to save state the first time"); - let recovered = db - .get_state(swap_id) - .await - .expect("Failed to recover state the first time"); - - // We insert and recover twice to ensure database implementation allows the - // caller to write to an existing key - db.insert_latest_state(swap_id, recovered) - .await - .expect("Failed to save state the second time"); - let recovered = db - .get_state(swap_id) - .await - .expect("Failed to recover state the second time"); - - assert_eq!(recovered, state); - } - - #[tokio::test] - async fn can_save_swap_state_and_peer_id_with_same_swap_id() -> Result<()> { - let db_dir = tempfile::tempdir().unwrap(); - let db = SledDatabase::open(db_dir.path()).await.unwrap(); - - let alice_id = Uuid::new_v4(); - let alice_state = State::from(AliceState::BtcPunished); - let peer_id = PeerId::random(); - - db.insert_latest_state(alice_id, alice_state.clone()).await?; - db.insert_peer_id(alice_id, peer_id).await?; - - let loaded_swap = db.get_state(alice_id).await?; - let loaded_peer_id = db.get_peer_id(alice_id).await?; - - assert_eq!(alice_state, loaded_swap); - assert_eq!(peer_id, loaded_peer_id); - - Ok(()) - } - - #[tokio::test] - async fn test_reopen_db() -> Result<()> { - let db_dir = tempfile::tempdir().unwrap(); - let alice_id = Uuid::new_v4(); - let alice_state = State::from(AliceState::BtcPunished); - - let peer_id = PeerId::random(); - - { - let db = SledDatabase::open(db_dir.path()).await.unwrap(); - db.insert_latest_state(alice_id,alice_state.clone()).await?; - db.insert_peer_id(alice_id, peer_id).await?; - } - - let db = SledDatabase::open(db_dir.path()).await.unwrap(); - - let loaded_swap = db.get_state(alice_id).await?; - let loaded_peer_id = db.get_peer_id(alice_id).await?; - - assert_eq!(alice_state, loaded_swap); - assert_eq!(peer_id, loaded_peer_id); - - Ok(()) - } - - #[tokio::test] - async fn save_and_load_addresses() -> Result<()> { - let db_dir = tempfile::tempdir()?; - let peer_id = PeerId::random(); - let home1 = "/ip4/127.0.0.1/tcp/1".parse::()?; - let home2 = "/ip4/127.0.0.1/tcp/2".parse::()?; - - { - let db = SledDatabase::open(db_dir.path()).await?; - db.insert_address(peer_id, home1.clone()).await?; - db.insert_address(peer_id, home2.clone()).await?; - } - - let addresses = SledDatabase::open(db_dir.path()).await?.get_addresses(peer_id).await?; - - assert_eq!(addresses, vec![home1, home2]); - - Ok(()) - } - - #[tokio::test] - async fn save_and_load_monero_address() -> Result<()> { - let db_dir = tempfile::tempdir()?; - let swap_id = Uuid::new_v4(); - - SledDatabase::open(db_dir.path()).await?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?; - let loaded_monero_address = SledDatabase::open(db_dir.path()).await?.get_monero_address(swap_id).await?; - - assert_eq!(loaded_monero_address.to_string(), "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a"); - - Ok(()) - } -} diff --git a/swap/src/database/sled.rs b/swap/src/database/sled.rs new file mode 100644 index 00000000..c5f9e712 --- /dev/null +++ b/swap/src/database/sled.rs @@ -0,0 +1,364 @@ +use std::path::Path; +use std::str::FromStr; + +use anyhow::{anyhow, Context, Result}; +use async_trait::async_trait; +use itertools::Itertools; +use libp2p::{Multiaddr, PeerId}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use uuid::Uuid; + +pub use crate::database::alice::Alice; +pub use crate::database::bob::Bob; +use crate::database::Swap; +use crate::protocol::{Database, State}; +use std::collections::HashMap; + +#[derive(Clone)] +pub struct SledDatabase { + swaps: sled::Tree, + peers: sled::Tree, + addresses: sled::Tree, + monero_addresses: sled::Tree, +} + +#[async_trait] +impl Database for SledDatabase { + async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> { + let peer_id_str = peer_id.to_string(); + + let key = serialize(&swap_id)?; + let value = serialize(&peer_id_str).context("Could not serialize peer-id")?; + + self.peers.insert(key, value)?; + + self.peers + .flush_async() + .await + .map(|_| ()) + .context("Could not flush db") + } + + async fn get_peer_id(&self, swap_id: Uuid) -> Result { + let key = serialize(&swap_id)?; + + let encoded = self + .peers + .get(&key)? + .ok_or_else(|| anyhow!("No peer-id found for swap id {} in database", swap_id))?; + + let peer_id: String = deserialize(&encoded).context("Could not deserialize peer-id")?; + Ok(PeerId::from_str(peer_id.as_str())?) + } + + async fn insert_monero_address( + &self, + swap_id: Uuid, + address: monero::Address, + ) -> Result<()> { + let key = swap_id.as_bytes(); + let value = serialize(&address)?; + + self.monero_addresses.insert(key, value)?; + + self.monero_addresses + .flush_async() + .await + .map(|_| ()) + .context("Could not flush db") + } + + async fn get_monero_address(&self, swap_id: Uuid) -> Result { + let encoded = self + .monero_addresses + .get(swap_id.as_bytes())? + .ok_or_else(|| { + anyhow!( + "No Monero address found for swap id {} in database", + swap_id + ) + })?; + + let monero_address = deserialize(&encoded)?; + + Ok(monero_address) + } + + async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> { + let key = peer_id.to_bytes(); + + let existing_addresses = self.addresses.get(&key)?; + + let new_addresses = { + let existing_addresses = existing_addresses.clone(); + + Some(match existing_addresses { + Some(encoded) => { + let mut addresses = deserialize::>(&encoded)?; + addresses.push(address); + + serialize(&addresses)? + } + None => serialize(&[address])?, + }) + }; + + self.addresses + .compare_and_swap(key, existing_addresses, new_addresses)??; + + self.addresses + .flush_async() + .await + .map(|_| ()) + .context("Could not flush db") + } + + async fn get_addresses(&self, peer_id: PeerId) -> Result> { + let key = peer_id.to_bytes(); + + let addresses = match self.addresses.get(&key)? { + Some(encoded) => deserialize(&encoded).context("Failed to deserialize addresses")?, + None => vec![], + }; + + Ok(addresses) + } + + async fn insert_latest_state(&self, swap_id: Uuid, state: State) -> Result<()> { + let key = serialize(&swap_id)?; + let swap = Swap::from(state); + let new_value = serialize(&swap).context("Could not serialize new state value")?; + + let old_value = self.swaps.get(&key)?; + + self.swaps + .compare_and_swap(key, old_value, Some(new_value)) + .context("Could not write in the DB")? + .context("Stored swap somehow changed, aborting saving")?; + + self.swaps + .flush_async() + .await + .map(|_| ()) + .context("Could not flush db") + } + + async fn get_state(&self, swap_id: Uuid) -> Result { + let key = serialize(&swap_id)?; + + let encoded = self + .swaps + .get(&key)? + .ok_or_else(|| anyhow!("Swap with id {} not found in database", swap_id))?; + + let swap = deserialize::(&encoded).context("Could not deserialize state")?; + + let state = State::from(swap); + + Ok(state) + } + + async fn all(&self) -> Result> { + self.all_iter().collect() + } + + async fn unfinished(&self, unfinished: fn(State) -> bool) -> Result> { + self.all_iter().into_iter() + .filter_ok(|(_swap_id, state)| unfinished(state.clone())) + .collect() + } +} + +impl SledDatabase { + pub async fn open(path: &Path) -> Result { + tracing::debug!("Opening database at {}", path.display()); + + let db = + sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; + + let swaps = db.open_tree("swaps")?; + let peers = db.open_tree("peers")?; + let addresses = db.open_tree("addresses")?; + let monero_addresses = db.open_tree("monero_addresses")?; + + Ok(SledDatabase { + swaps, + peers, + addresses, + monero_addresses, + }) + } + + fn all_iter(&self) -> impl Iterator> { + self.swaps.iter().map(|item| { + let (key, value) = item.context("Failed to retrieve swap from DB")?; + + let swap_id = deserialize::(&key)?; + let swap = deserialize::(&value).context("Failed to deserialize swap")?; + + let state = State::from(swap); + + Ok((swap_id, state)) + }) + } +} + +pub fn serialize(t: &T) -> Result> +where + T: Serialize, +{ + Ok(serde_cbor::to_vec(t)?) +} + +pub fn deserialize(v: &[u8]) -> Result +where + T: DeserializeOwned, +{ + Ok(serde_cbor::from_slice(&v)?) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::alice::AliceState; + + #[tokio::test] + async fn can_write_and_read_to_multiple_keys() { + let db_dir = tempfile::tempdir().unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); + + let state_1 = State::from(AliceState::BtcRedeemed); + let swap_id_1 = Uuid::new_v4(); + db.insert_latest_state(swap_id_1, state_1.clone()) + .await + .expect("Failed to save second state"); + + let state_2 = State::from(AliceState::BtcPunished); + let swap_id_2 = Uuid::new_v4(); + db.insert_latest_state(swap_id_2, state_2.clone()) + .await + .expect("Failed to save first state"); + + let recovered_1 = db + .get_state(swap_id_1) + .await + .expect("Failed to recover first state"); + + let recovered_2 = db + .get_state(swap_id_2) + .await + .expect("Failed to recover second state"); + + assert_eq!(recovered_1, state_1); + assert_eq!(recovered_2, state_2); + } + + #[tokio::test] + async fn can_write_twice_to_one_key() { + let db_dir = tempfile::tempdir().unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); + + let state = State::from(AliceState::SafelyAborted); + + let swap_id = Uuid::new_v4(); + db.insert_latest_state(swap_id, state.clone()) + .await + .expect("Failed to save state the first time"); + let recovered = db + .get_state(swap_id) + .await + .expect("Failed to recover state the first time"); + + // We insert and recover twice to ensure database implementation allows the + // caller to write to an existing key + db.insert_latest_state(swap_id, recovered) + .await + .expect("Failed to save state the second time"); + let recovered = db + .get_state(swap_id) + .await + .expect("Failed to recover state the second time"); + + assert_eq!(recovered, state); + } + + #[tokio::test] + async fn can_save_swap_state_and_peer_id_with_same_swap_id() -> Result<()> { + let db_dir = tempfile::tempdir().unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); + + let alice_id = Uuid::new_v4(); + let alice_state = State::from(AliceState::BtcPunished); + let peer_id = PeerId::random(); + + db.insert_latest_state(alice_id, alice_state.clone()).await?; + db.insert_peer_id(alice_id, peer_id).await?; + + let loaded_swap = db.get_state(alice_id).await?; + let loaded_peer_id = db.get_peer_id(alice_id).await?; + + assert_eq!(alice_state, loaded_swap); + assert_eq!(peer_id, loaded_peer_id); + + Ok(()) + } + + #[tokio::test] + async fn test_reopen_db() -> Result<()> { + let db_dir = tempfile::tempdir().unwrap(); + let alice_id = Uuid::new_v4(); + let alice_state = State::from(AliceState::BtcPunished); + + let peer_id = PeerId::random(); + + { + let db = SledDatabase::open(db_dir.path()).await.unwrap(); + db.insert_latest_state(alice_id,alice_state.clone()).await?; + db.insert_peer_id(alice_id, peer_id).await?; + } + + let db = SledDatabase::open(db_dir.path()).await.unwrap(); + + let loaded_swap = db.get_state(alice_id).await?; + let loaded_peer_id = db.get_peer_id(alice_id).await?; + + assert_eq!(alice_state, loaded_swap); + assert_eq!(peer_id, loaded_peer_id); + + Ok(()) + } + + #[tokio::test] + async fn save_and_load_addresses() -> Result<()> { + let db_dir = tempfile::tempdir()?; + let peer_id = PeerId::random(); + let home1 = "/ip4/127.0.0.1/tcp/1".parse::()?; + let home2 = "/ip4/127.0.0.1/tcp/2".parse::()?; + + { + let db = SledDatabase::open(db_dir.path()).await?; + db.insert_address(peer_id, home1.clone()).await?; + db.insert_address(peer_id, home2.clone()).await?; + } + + let addresses = SledDatabase::open(db_dir.path()).await?.get_addresses(peer_id).await?; + + assert_eq!(addresses, vec![home1, home2]); + + Ok(()) + } + + #[tokio::test] + async fn save_and_load_monero_address() -> Result<()> { + let db_dir = tempfile::tempdir()?; + let swap_id = Uuid::new_v4(); + + SledDatabase::open(db_dir.path()).await?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?; + let loaded_monero_address = SledDatabase::open(db_dir.path()).await?.get_monero_address(swap_id).await?; + + assert_eq!(loaded_monero_address.to_string(), "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a"); + + Ok(()) + } +}