Create Database trait

Use domain types in database API to prevent leaking of database types.
This trait will allow us to smoothly introduce the sqlite database.
This commit is contained in:
rishflab 2021-09-28 10:15:31 +10:00
parent a94c320021
commit da9d09aa5e
21 changed files with 351 additions and 253 deletions

View file

@ -78,8 +78,8 @@ pub enum AliceEndState {
BtcPunished,
}
impl From<&AliceState> for Alice {
fn from(alice_state: &AliceState) -> Self {
impl From<AliceState> for Alice {
fn from(alice_state: AliceState) -> Self {
match alice_state {
AliceState::Started { state3 } => Alice::Started {
state3: state3.as_ref().clone(),
@ -95,8 +95,8 @@ impl From<&AliceState> for Alice {
transfer_proof,
state3,
} => Alice::XmrLockTransactionSent {
monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight,
transfer_proof: transfer_proof.clone(),
monero_wallet_restore_blockheight,
transfer_proof,
state3: state3.as_ref().clone(),
},
AliceState::XmrLocked {
@ -104,8 +104,8 @@ impl From<&AliceState> for Alice {
transfer_proof,
state3,
} => Alice::XmrLocked {
monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight,
transfer_proof: transfer_proof.clone(),
monero_wallet_restore_blockheight,
transfer_proof,
state3: state3.as_ref().clone(),
},
AliceState::XmrLockTransferProofSent {
@ -113,8 +113,8 @@ impl From<&AliceState> for Alice {
transfer_proof,
state3,
} => Alice::XmrLockTransferProofSent {
monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight,
transfer_proof: transfer_proof.clone(),
monero_wallet_restore_blockheight,
transfer_proof,
state3: state3.as_ref().clone(),
},
AliceState::EncSigLearned {
@ -123,10 +123,10 @@ impl From<&AliceState> for Alice {
state3,
encrypted_signature,
} => Alice::EncSigLearned {
monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight,
transfer_proof: transfer_proof.clone(),
monero_wallet_restore_blockheight,
transfer_proof,
state3: state3.as_ref().clone(),
encrypted_signature: *encrypted_signature.clone(),
encrypted_signature: encrypted_signature.as_ref().clone(),
},
AliceState::BtcRedeemTransactionPublished { state3 } => {
Alice::BtcRedeemTransactionPublished {
@ -139,8 +139,8 @@ impl From<&AliceState> for Alice {
transfer_proof,
state3,
} => Alice::BtcCancelled {
monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight,
transfer_proof: transfer_proof.clone(),
monero_wallet_restore_blockheight,
transfer_proof,
state3: state3.as_ref().clone(),
},
AliceState::BtcRefunded {
@ -149,9 +149,9 @@ impl From<&AliceState> for Alice {
spend_key,
state3,
} => Alice::BtcRefunded {
monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight,
transfer_proof: transfer_proof.clone(),
spend_key: *spend_key,
monero_wallet_restore_blockheight,
transfer_proof,
spend_key,
state3: state3.as_ref().clone(),
},
AliceState::BtcPunishable {
@ -159,8 +159,8 @@ impl From<&AliceState> for Alice {
transfer_proof,
state3,
} => Alice::BtcPunishable {
monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight,
transfer_proof: transfer_proof.clone(),
monero_wallet_restore_blockheight,
transfer_proof,
state3: state3.as_ref().clone(),
},
AliceState::XmrRefunded => Alice::Done(AliceEndState::XmrRefunded),
@ -169,8 +169,8 @@ impl From<&AliceState> for Alice {
transfer_proof,
state3,
} => Alice::CancelTimelockExpired {
monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight,
transfer_proof: transfer_proof.clone(),
monero_wallet_restore_blockheight,
transfer_proof,
state3: state3.as_ref().clone(),
},
AliceState::BtcPunished => Alice::Done(AliceEndState::BtcPunished),

View file

@ -1,6 +1,7 @@
use crate::database::{Alice, Bob, Swap};
use crate::database::Swap;
use crate::protocol::{Database, State};
use anyhow::{anyhow, Context, Result};
use itertools::Itertools;
use async_trait::async_trait;
use libp2p::{Multiaddr, PeerId};
use serde::de::DeserializeOwned;
use serde::Serialize;
@ -8,6 +9,9 @@ use std::path::Path;
use std::str::FromStr;
use uuid::Uuid;
pub use crate::database::alice::Alice;
pub use crate::database::bob::Bob;
pub struct SledDatabase {
swaps: sled::Tree,
peers: sled::Tree,
@ -15,27 +19,9 @@ pub struct SledDatabase {
monero_addresses: sled::Tree,
}
impl SledDatabase {
pub fn open(path: &Path) -> Result<Self> {
tracing::debug!("Opening database at {}", path.display());
let db =
sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?;
let swaps = db.open_tree("swaps")?;
let peers = db.open_tree("peers")?;
let addresses = db.open_tree("addresses")?;
let monero_addresses = db.open_tree("monero_addresses")?;
Ok(SledDatabase {
swaps,
peers,
addresses,
monero_addresses,
})
}
pub async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> {
#[async_trait]
impl Database for SledDatabase {
async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> {
let peer_id_str = peer_id.to_string();
let key = serialize(&swap_id)?;
@ -50,7 +36,7 @@ impl SledDatabase {
.context("Could not flush db")
}
pub fn get_peer_id(&self, swap_id: Uuid) -> Result<PeerId> {
async fn get_peer_id(&self, swap_id: Uuid) -> Result<PeerId> {
let key = serialize(&swap_id)?;
let encoded = self
@ -62,11 +48,7 @@ impl SledDatabase {
Ok(PeerId::from_str(peer_id.as_str())?)
}
pub async fn insert_monero_address(
&self,
swap_id: Uuid,
address: monero::Address,
) -> Result<()> {
async fn insert_monero_address(&self, swap_id: Uuid, address: monero::Address) -> Result<()> {
let key = swap_id.as_bytes();
let value = serialize(&address)?;
@ -79,7 +61,7 @@ impl SledDatabase {
.context("Could not flush db")
}
pub fn get_monero_address(&self, swap_id: Uuid) -> Result<monero::Address> {
async fn get_monero_address(&self, swap_id: Uuid) -> Result<monero::Address> {
let encoded = self
.monero_addresses
.get(swap_id.as_bytes())?
@ -95,7 +77,7 @@ impl SledDatabase {
Ok(monero_address)
}
pub async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> {
async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> {
let key = peer_id.to_bytes();
let existing_addresses = self.addresses.get(&key)?;
@ -124,7 +106,7 @@ impl SledDatabase {
.context("Could not flush db")
}
pub fn get_addresses(&self, peer_id: PeerId) -> Result<Vec<Multiaddr>> {
async fn get_addresses(&self, peer_id: PeerId) -> Result<Vec<Multiaddr>> {
let key = peer_id.to_bytes();
let addresses = match self.addresses.get(&key)? {
@ -135,9 +117,10 @@ impl SledDatabase {
Ok(addresses)
}
pub async fn insert_latest_state(&self, swap_id: Uuid, state: Swap) -> Result<()> {
async fn insert_latest_state(&self, swap_id: Uuid, state: State) -> Result<()> {
let key = serialize(&swap_id)?;
let new_value = serialize(&state).context("Could not serialize new state value")?;
let swap = Swap::from(state);
let new_value = serialize(&swap).context("Could not serialize new state value")?;
let old_value = self.swaps.get(&key)?;
@ -153,7 +136,7 @@ impl SledDatabase {
.context("Could not flush db")
}
pub fn get_state(&self, swap_id: Uuid) -> Result<Swap> {
async fn get_state(&self, swap_id: Uuid) -> Result<State> {
let key = serialize(&swap_id)?;
let encoded = self
@ -161,47 +144,91 @@ impl SledDatabase {
.get(&key)?
.ok_or_else(|| anyhow!("Swap with id {} not found in database", swap_id))?;
let state = deserialize(&encoded).context("Could not deserialize state")?;
let swap = deserialize::<Swap>(&encoded).context("Could not deserialize state")?;
let state = State::from(swap);
Ok(state)
}
pub fn all_alice(&self) -> Result<Vec<(Uuid, Alice)>> {
self.all_alice_iter().collect()
async fn all(&self) -> Result<Vec<(Uuid, State)>> {
self.all_iter().collect()
}
}
fn all_alice_iter(&self) -> impl Iterator<Item = Result<(Uuid, Alice)>> {
self.all_swaps_iter().map(|item| {
let (swap_id, swap) = item?;
Ok((swap_id, swap.try_into_alice()?))
impl SledDatabase {
pub async fn open(path: &Path) -> Result<Self> {
tracing::debug!("Opening database at {}", path.display());
let db =
sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?;
let swaps = db.open_tree("swaps")?;
let peers = db.open_tree("peers")?;
let addresses = db.open_tree("addresses")?;
let monero_addresses = db.open_tree("monero_addresses")?;
Ok(SledDatabase {
swaps,
peers,
addresses,
monero_addresses,
})
}
pub fn all_bob(&self) -> Result<Vec<(Uuid, Bob)>> {
self.all_bob_iter().collect()
}
pub fn get_all_peers(&self) -> impl Iterator<Item = Result<(Uuid, PeerId)>> {
self.peers.iter().map(|item| {
let (key, value) = item.context("Failed to retrieve peer id from DB")?;
fn all_bob_iter(&self) -> impl Iterator<Item = Result<(Uuid, Bob)>> {
self.all_swaps_iter().map(|item| {
let (swap_id, swap) = item?;
Ok((swap_id, swap.try_into_bob()?))
let swap_id = deserialize::<Uuid>(&key)?;
let peer_id_bytes =
deserialize::<Vec<u8>>(&value).context("Failed to deserialize swap")?;
let peer_id = PeerId::from_bytes(&peer_id_bytes)?;
Ok((swap_id, peer_id))
})
}
fn all_swaps_iter(&self) -> impl Iterator<Item = Result<(Uuid, Swap)>> {
pub fn get_all_addresses(&self) -> impl Iterator<Item = Result<(PeerId, Vec<Multiaddr>)>> {
self.addresses.iter().map(|item| {
let (key, value) = item.context("Failed to retrieve peer address from DB")?;
let peer_id_bytes = deserialize::<Vec<u8>>(&key)?;
let addr =
deserialize::<Vec<Multiaddr>>(&value).context("Failed to deserialize swap")?;
let peer_id = PeerId::from_bytes(&peer_id_bytes)?;
Ok((peer_id, addr))
})
}
pub fn get_all_monero_addresses(
&self,
) -> impl Iterator<Item = Result<(Uuid, monero::Address)>> {
self.monero_addresses.iter().map(|item| {
let (key, value) = item.context("Failed to retrieve monero address from DB")?;
let swap_id = deserialize::<Uuid>(&key)?;
let addr =
deserialize::<monero::Address>(&value).context("Failed to deserialize swap")?;
Ok((swap_id, addr))
})
}
fn all_iter(&self) -> impl Iterator<Item = Result<(Uuid, State)>> {
self.swaps.iter().map(|item| {
let (key, value) = item.context("Failed to retrieve swap from DB")?;
let swap_id = deserialize::<Uuid>(&key)?;
let swap = deserialize::<Swap>(&value).context("Failed to deserialize swap")?;
Ok((swap_id, swap))
})
}
let state = State::from(swap);
pub fn unfinished_alice(&self) -> Result<Vec<(Uuid, Alice)>> {
self.all_alice_iter()
.filter_ok(|(_swap_id, alice)| !matches!(alice, Alice::Done(_)))
.collect()
Ok((swap_id, state))
})
}
}
@ -222,22 +249,20 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::database::alice::{Alice, AliceEndState};
use crate::database::bob::{Bob, BobEndState};
use crate::database::{NotAlice, NotBob};
use crate::protocol::alice::AliceState;
#[tokio::test]
async fn can_write_and_read_to_multiple_keys() {
let db_dir = tempfile::tempdir().unwrap();
let db = SledDatabase::open(db_dir.path()).unwrap();
let db = SledDatabase::open(db_dir.path()).await.unwrap();
let state_1 = Swap::Alice(Alice::Done(AliceEndState::BtcRedeemed));
let state_1 = State::from(AliceState::BtcRedeemed);
let swap_id_1 = Uuid::new_v4();
db.insert_latest_state(swap_id_1, state_1.clone())
.await
.expect("Failed to save second state");
let state_2 = Swap::Bob(Bob::Done(BobEndState::SafelyAborted));
let state_2 = State::from(AliceState::BtcPunished);
let swap_id_2 = Uuid::new_v4();
db.insert_latest_state(swap_id_2, state_2.clone())
.await
@ -245,10 +270,12 @@ mod tests {
let recovered_1 = db
.get_state(swap_id_1)
.await
.expect("Failed to recover first state");
let recovered_2 = db
.get_state(swap_id_2)
.await
.expect("Failed to recover second state");
assert_eq!(recovered_1, state_1);
@ -258,9 +285,9 @@ mod tests {
#[tokio::test]
async fn can_write_twice_to_one_key() {
let db_dir = tempfile::tempdir().unwrap();
let db = SledDatabase::open(db_dir.path()).unwrap();
let db = SledDatabase::open(db_dir.path()).await.unwrap();
let state = Swap::Alice(Alice::Done(AliceEndState::SafelyAborted));
let state = State::from(AliceState::SafelyAborted);
let swap_id = Uuid::new_v4();
db.insert_latest_state(swap_id, state.clone())
@ -268,6 +295,7 @@ mod tests {
.expect("Failed to save state the first time");
let recovered = db
.get_state(swap_id)
.await
.expect("Failed to recover state the first time");
// We insert and recover twice to ensure database implementation allows the
@ -277,84 +305,29 @@ mod tests {
.expect("Failed to save state the second time");
let recovered = db
.get_state(swap_id)
.await
.expect("Failed to recover state the second time");
assert_eq!(recovered, state);
}
#[tokio::test]
async fn all_swaps_as_alice() {
let db_dir = tempfile::tempdir().unwrap();
let db = SledDatabase::open(db_dir.path()).unwrap();
let alice_state = Alice::Done(AliceEndState::BtcPunished);
let alice_swap = Swap::Alice(alice_state.clone());
let alice_swap_id = Uuid::new_v4();
db.insert_latest_state(alice_swap_id, alice_swap)
.await
.expect("Failed to save alice state 1");
let alice_swaps = db.all_alice().unwrap();
assert_eq!(alice_swaps.len(), 1);
assert!(alice_swaps.contains(&(alice_swap_id, alice_state)));
let bob_state = Bob::Done(BobEndState::SafelyAborted);
let bob_swap = Swap::Bob(bob_state);
let bob_swap_id = Uuid::new_v4();
db.insert_latest_state(bob_swap_id, bob_swap)
.await
.expect("Failed to save bob state 1");
let err = db.all_alice().unwrap_err();
assert_eq!(err.downcast_ref::<NotAlice>().unwrap(), &NotAlice);
}
#[tokio::test]
async fn all_swaps_as_bob() {
let db_dir = tempfile::tempdir().unwrap();
let db = SledDatabase::open(db_dir.path()).unwrap();
let bob_state = Bob::Done(BobEndState::SafelyAborted);
let bob_swap = Swap::Bob(bob_state.clone());
let bob_swap_id = Uuid::new_v4();
db.insert_latest_state(bob_swap_id, bob_swap)
.await
.expect("Failed to save bob state 1");
let bob_swaps = db.all_bob().unwrap();
assert_eq!(bob_swaps.len(), 1);
assert!(bob_swaps.contains(&(bob_swap_id, bob_state)));
let alice_state = Alice::Done(AliceEndState::BtcPunished);
let alice_swap = Swap::Alice(alice_state);
let alice_swap_id = Uuid::new_v4();
db.insert_latest_state(alice_swap_id, alice_swap)
.await
.expect("Failed to save alice state 1");
let err = db.all_bob().unwrap_err();
assert_eq!(err.downcast_ref::<NotBob>().unwrap(), &NotBob);
}
#[tokio::test]
async fn can_save_swap_state_and_peer_id_with_same_swap_id() -> Result<()> {
let db_dir = tempfile::tempdir().unwrap();
let db = SledDatabase::open(db_dir.path()).unwrap();
let db = SledDatabase::open(db_dir.path()).await.unwrap();
let alice_id = Uuid::new_v4();
let alice_state = Alice::Done(AliceEndState::BtcPunished);
let alice_swap = Swap::Alice(alice_state);
let alice_state = State::from(AliceState::BtcPunished);
let peer_id = PeerId::random();
db.insert_latest_state(alice_id, alice_swap.clone()).await?;
db.insert_latest_state(alice_id, alice_state.clone())
.await?;
db.insert_peer_id(alice_id, peer_id).await?;
let loaded_swap = db.get_state(alice_id)?;
let loaded_peer_id = db.get_peer_id(alice_id)?;
let loaded_swap = db.get_state(alice_id).await?;
let loaded_peer_id = db.get_peer_id(alice_id).await?;
assert_eq!(alice_swap, loaded_swap);
assert_eq!(alice_state, loaded_swap);
assert_eq!(peer_id, loaded_peer_id);
Ok(())
@ -364,23 +337,23 @@ mod tests {
async fn test_reopen_db() -> Result<()> {
let db_dir = tempfile::tempdir().unwrap();
let alice_id = Uuid::new_v4();
let alice_state = Alice::Done(AliceEndState::BtcPunished);
let alice_swap = Swap::Alice(alice_state);
let alice_state = State::from(AliceState::BtcPunished);
let peer_id = PeerId::random();
{
let db = SledDatabase::open(db_dir.path()).unwrap();
db.insert_latest_state(alice_id, alice_swap.clone()).await?;
let db = SledDatabase::open(db_dir.path()).await.unwrap();
db.insert_latest_state(alice_id, alice_state.clone())
.await?;
db.insert_peer_id(alice_id, peer_id).await?;
}
let db = SledDatabase::open(db_dir.path()).unwrap();
let db = SledDatabase::open(db_dir.path()).await.unwrap();
let loaded_swap = db.get_state(alice_id)?;
let loaded_peer_id = db.get_peer_id(alice_id)?;
let loaded_swap = db.get_state(alice_id).await?;
let loaded_peer_id = db.get_peer_id(alice_id).await?;
assert_eq!(alice_swap, loaded_swap);
assert_eq!(alice_state, loaded_swap);
assert_eq!(peer_id, loaded_peer_id);
Ok(())
@ -394,12 +367,15 @@ mod tests {
let home2 = "/ip4/127.0.0.1/tcp/2".parse::<Multiaddr>()?;
{
let db = SledDatabase::open(db_dir.path())?;
let db = SledDatabase::open(db_dir.path()).await?;
db.insert_address(peer_id, home1.clone()).await?;
db.insert_address(peer_id, home2.clone()).await?;
}
let addresses = SledDatabase::open(db_dir.path())?.get_addresses(peer_id)?;
let addresses = SledDatabase::open(db_dir.path())
.await?
.get_addresses(peer_id)
.await?;
assert_eq!(addresses, vec![home1, home2]);
@ -411,9 +387,11 @@ mod tests {
let db_dir = tempfile::tempdir()?;
let swap_id = Uuid::new_v4();
SledDatabase::open(db_dir.path())?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?;
let loaded_monero_address =
SledDatabase::open(db_dir.path())?.get_monero_address(swap_id)?;
SledDatabase::open(db_dir.path()).await?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?;
let loaded_monero_address = SledDatabase::open(db_dir.path())
.await?
.get_monero_address(swap_id)
.await?;
assert_eq!(loaded_monero_address.to_string(), "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a");