Distinguish loading all swaps for alice or bob on db level

This commit is contained in:
Daniel Karzel 2021-03-26 15:16:19 +11:00 committed by Thomas Eizinger
parent 183e8f02de
commit 1c129d58c4
No known key found for this signature in database
GPG Key ID: 651AC83A6C6C8B96
5 changed files with 112 additions and 33 deletions

14
Cargo.lock generated
View File

@ -1552,6 +1552,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37d572918e350e82412fe766d24b15e6682fb2ed2bbe018280caa810397cb319"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "0.4.7"
@ -2547,7 +2556,7 @@ checksum = "32d3ebd75ac2679c2af3a92246639f9fcc8a442ee420719cc4fe195b98dd5fa3"
dependencies = [
"bytes 1.0.1",
"heck",
"itertools",
"itertools 0.9.0",
"log 0.4.14",
"multimap",
"petgraph",
@ -2564,7 +2573,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "169a15f3008ecb5160cba7d37bcd690a7601b6d30cfb87a117d45e59d52af5d4"
dependencies = [
"anyhow",
"itertools",
"itertools 0.9.0",
"proc-macro2",
"quote",
"syn",
@ -3522,6 +3531,7 @@ dependencies = [
"futures",
"get-port",
"hyper 0.14.5",
"itertools 0.10.0",
"libp2p",
"libp2p-async-await",
"miniscript",

View File

@ -25,6 +25,7 @@ dialoguer = "0.8"
directories-next = "2"
ecdsa_fun = { git = "https://github.com/LLFourn/secp256kfun", features = ["libsecp_compat", "serde"] }
futures = { version = "0.3", default-features = false }
itertools = "0.10"
libp2p = { version = "0.36", default-features = false, features = ["tcp-tokio", "yamux", "mplex", "dns-tokio", "noise", "request-response"] }
libp2p-async-await = { git = "https://github.com/comit-network/rust-libp2p-async-await" }
miniscript = { version = "5", features = ["serde"] }

View File

@ -130,7 +130,7 @@ async fn main() -> Result<()> {
table.add_row(row!["SWAP ID", "STATE"]);
for (swap_id, state) in db.all()? {
for (swap_id, state) in db.all_alice()? {
table.add_row(row![swap_id, state]);
}

View File

@ -158,7 +158,7 @@ async fn main() -> Result<()> {
table.add_row(row!["SWAP ID", "STATE"]);
for (swap_id, state) in db.all()? {
for (swap_id, state) in db.all_bob()? {
table.add_row(row![swap_id, state]);
}

View File

@ -2,6 +2,7 @@ pub use alice::Alice;
pub use bob::Bob;
use anyhow::{anyhow, bail, Context, Result};
use itertools::Itertools;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
@ -38,11 +39,26 @@ impl Display for Swap {
}
}
#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)]
#[error("Not in the role of Alice")]
struct NotAlice;
#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)]
#[error("Not in the role of Bob")]
struct NotBob;
impl Swap {
pub fn try_into_alice(self) -> Result<Alice> {
match self {
Swap::Alice(alice) => Ok(alice),
Swap::Bob(_) => bail!(NotAlice),
}
}
pub fn try_into_bob(self) -> Result<Bob> {
match self {
Swap::Bob(bob) => Ok(bob),
Swap::Alice(_) => bail!("Swap instance is not Bob"),
Swap::Alice(_) => bail!(NotBob),
}
}
}
@ -90,22 +106,42 @@ impl Database {
Ok(state)
}
pub fn all(&self) -> Result<Vec<(Uuid, Swap)>> {
self.0
.iter()
.map(|item| match item {
Ok((key, value)) => {
let swap_id = deserialize::<Uuid>(&key);
let swap = deserialize::<Swap>(&value).context("Failed to deserialize swap");
pub fn all_alice(&self) -> Result<Vec<(Uuid, Alice)>> {
self.all_alice_iter().collect()
}
match (swap_id, swap) {
(Ok(swap_id), Ok(swap)) => Ok((swap_id, swap)),
(Ok(_), Err(err)) => Err(err),
_ => bail!("Failed to deserialize swap"),
}
}
Err(err) => Err(err).context("Failed to retrieve swap from DB"),
})
fn all_alice_iter(&self) -> impl Iterator<Item = Result<(Uuid, Alice)>> {
self.all_swaps_iter().map(|item| {
let (swap_id, swap) = item?;
Ok((swap_id, swap.try_into_alice()?))
})
}
pub fn all_bob(&self) -> Result<Vec<(Uuid, Bob)>> {
self.all_bob_iter().collect()
}
fn all_bob_iter(&self) -> impl Iterator<Item = Result<(Uuid, Bob)>> {
self.all_swaps_iter().map(|item| {
let (swap_id, swap) = item?;
Ok((swap_id, swap.try_into_bob()?))
})
}
fn all_swaps_iter(&self) -> impl Iterator<Item = Result<(Uuid, Swap)>> {
self.0.iter().map(|item| {
let (key, value) = item.context("Failed to retrieve swap from DB")?;
let swap_id = deserialize::<Uuid>(&key)?;
let swap = deserialize::<Swap>(&value).context("Failed to deserialize swap")?;
Ok((swap_id, swap))
})
}
pub fn unfinished_alice(&self) -> Result<Vec<(Uuid, Alice)>> {
self.all_alice_iter()
.filter_ok(|(_swap_id, alice)| !matches!(alice, Alice::Done(_)))
.collect()
}
}
@ -187,26 +223,58 @@ mod tests {
}
#[tokio::test]
async fn can_fetch_all_keys() {
async fn all_swaps_as_alice() {
let db_dir = tempfile::tempdir().unwrap();
let db = Database::open(db_dir.path()).unwrap();
let state_1 = Swap::Alice(Alice::Done(AliceEndState::BtcPunished));
let swap_id_1 = Uuid::new_v4();
db.insert_latest_state(swap_id_1, state_1.clone())
let alice_state = Alice::Done(AliceEndState::BtcPunished);
let alice_swap = Swap::Alice(alice_state.clone());
let alice_swap_id = Uuid::new_v4();
db.insert_latest_state(alice_swap_id, alice_swap)
.await
.expect("Failed to save second state");
.expect("Failed to save alice state 1");
let state_2 = Swap::Bob(Bob::Done(BobEndState::SafelyAborted));
let swap_id_2 = Uuid::new_v4();
db.insert_latest_state(swap_id_2, state_2.clone())
let alice_swaps = db.all_alice().unwrap();
assert_eq!(alice_swaps.len(), 1);
assert!(alice_swaps.contains(&(alice_swap_id, alice_state)));
let bob_state = Bob::Done(BobEndState::SafelyAborted);
let bob_swap = Swap::Bob(bob_state);
let bob_swap_id = Uuid::new_v4();
db.insert_latest_state(bob_swap_id, bob_swap)
.await
.expect("Failed to save first state");
.expect("Failed to save bob state 1");
let swaps = db.all().unwrap();
let err = db.all_alice().unwrap_err();
assert_eq!(swaps.len(), 2);
assert!(swaps.contains(&(swap_id_1, state_1)));
assert!(swaps.contains(&(swap_id_2, state_2)));
assert_eq!(err.downcast_ref::<NotAlice>().unwrap(), &NotAlice);
}
#[tokio::test]
async fn all_swaps_as_bob() {
let db_dir = tempfile::tempdir().unwrap();
let db = Database::open(db_dir.path()).unwrap();
let bob_state = Bob::Done(BobEndState::SafelyAborted);
let bob_swap = Swap::Bob(bob_state.clone());
let bob_swap_id = Uuid::new_v4();
db.insert_latest_state(bob_swap_id, bob_swap)
.await
.expect("Failed to save bob state 1");
let bob_swaps = db.all_bob().unwrap();
assert_eq!(bob_swaps.len(), 1);
assert!(bob_swaps.contains(&(bob_swap_id, bob_state)));
let alice_state = Alice::Done(AliceEndState::BtcPunished);
let alice_swap = Swap::Alice(alice_state);
let alice_swap_id = Uuid::new_v4();
db.insert_latest_state(alice_swap_id, alice_swap)
.await
.expect("Failed to save alice state 1");
let err = db.all_bob().unwrap_err();
assert_eq!(err.downcast_ref::<NotBob>().unwrap(), &NotBob);
}
}