Remove generics from Database

This commit is contained in:
Lucas Soriano del Pino 2020-11-03 15:26:47 +11:00 committed by rishflab
parent 02075c2a1d
commit f9cfc2abe3
6 changed files with 103 additions and 111 deletions

View file

@ -45,7 +45,7 @@ use xmr_btc::{
pub async fn swap( pub async fn swap(
bitcoin_wallet: Arc<bitcoin::Wallet>, bitcoin_wallet: Arc<bitcoin::Wallet>,
monero_wallet: Arc<monero::Wallet>, monero_wallet: Arc<monero::Wallet>,
db: Database<storage::Alice>, db: Database,
listen: Multiaddr, listen: Multiaddr,
transport: SwapTransport, transport: SwapTransport,
behaviour: Alice, behaviour: Alice,
@ -177,7 +177,7 @@ pub async fn swap(
}; };
let swap_id = Uuid::new_v4(); let swap_id = Uuid::new_v4();
db.insert_latest_state(swap_id, &storage::Alice::Handshaken(state3.clone())) db.insert_latest_state(swap_id, storage::Alice::Handshaken(state3.clone()).into())
.await?; .await?;
info!("Handshake complete, we now have State3 for Alice."); info!("Handshake complete, we now have State3 for Alice.");
@ -205,14 +205,14 @@ pub async fn swap(
public_spend_key, public_spend_key,
public_view_key, public_view_key,
}) => { }) => {
db.insert_latest_state(swap_id, &storage::Alice::BtcLocked(state3.clone())) db.insert_latest_state(swap_id, storage::Alice::BtcLocked(state3.clone()).into())
.await?; .await?;
let (transfer_proof, _) = monero_wallet let (transfer_proof, _) = monero_wallet
.transfer(public_spend_key, public_view_key, amount) .transfer(public_spend_key, public_view_key, amount)
.await?; .await?;
db.insert_latest_state(swap_id, &storage::Alice::XmrLocked(state3.clone())) db.insert_latest_state(swap_id, storage::Alice::XmrLocked(state3.clone()).into())
.await?; .await?;
let mut guard = network.as_ref().lock().await; let mut guard = network.as_ref().lock().await;
@ -221,10 +221,14 @@ pub async fn swap(
} }
GeneratorState::Yielded(Action::RedeemBtc(tx)) => { GeneratorState::Yielded(Action::RedeemBtc(tx)) => {
db.insert_latest_state(swap_id, &storage::Alice::BtcRedeemable { db.insert_latest_state(
swap_id,
storage::Alice::BtcRedeemable {
state: state3.clone(), state: state3.clone(),
redeem_tx: tx.clone(), redeem_tx: tx.clone(),
}) }
.into(),
)
.await?; .await?;
let _ = bitcoin_wallet.broadcast_signed_transaction(tx).await?; let _ = bitcoin_wallet.broadcast_signed_transaction(tx).await?;
@ -233,7 +237,10 @@ pub async fn swap(
let _ = bitcoin_wallet.broadcast_signed_transaction(tx).await?; let _ = bitcoin_wallet.broadcast_signed_transaction(tx).await?;
} }
GeneratorState::Yielded(Action::PunishBtc(tx)) => { GeneratorState::Yielded(Action::PunishBtc(tx)) => {
db.insert_latest_state(swap_id, &storage::Alice::BtcPunishable(state3.clone())) db.insert_latest_state(
swap_id,
storage::Alice::BtcPunishable(state3.clone()).into(),
)
.await?; .await?;
let _ = bitcoin_wallet.broadcast_signed_transaction(tx).await?; let _ = bitcoin_wallet.broadcast_signed_transaction(tx).await?;
@ -242,11 +249,15 @@ pub async fn swap(
spend_key, spend_key,
view_key, view_key,
}) => { }) => {
db.insert_latest_state(swap_id, &storage::Alice::BtcRefunded { db.insert_latest_state(
swap_id,
storage::Alice::BtcRefunded {
state: state3.clone(), state: state3.clone(),
spend_key, spend_key,
view_key, view_key,
}) }
.into(),
)
.await?; .await?;
monero_wallet monero_wallet
@ -254,7 +265,7 @@ pub async fn swap(
.await?; .await?;
} }
GeneratorState::Complete(()) => { GeneratorState::Complete(()) => {
db.insert_latest_state(swap_id, &storage::Alice::SwapComplete) db.insert_latest_state(swap_id, storage::Alice::SwapComplete.into())
.await?; .await?;
return Ok(()); return Ok(());

View file

@ -44,7 +44,7 @@ use xmr_btc::{
pub async fn swap( pub async fn swap(
bitcoin_wallet: Arc<bitcoin::Wallet>, bitcoin_wallet: Arc<bitcoin::Wallet>,
monero_wallet: Arc<monero::Wallet>, monero_wallet: Arc<monero::Wallet>,
db: Database<storage::Bob>, db: Database,
btc: u64, btc: u64,
addr: Multiaddr, addr: Multiaddr,
mut cmd_tx: Sender<Cmd>, mut cmd_tx: Sender<Cmd>,
@ -144,7 +144,7 @@ pub async fn swap(
}; };
let swap_id = Uuid::new_v4(); let swap_id = Uuid::new_v4();
db.insert_latest_state(swap_id, &storage::Bob::Handshaken(state2.clone())) db.insert_latest_state(swap_id, storage::Bob::Handshaken(state2.clone()).into())
.await?; .await?;
swarm.send_message2(alice.clone(), state2.next_message()); swarm.send_message2(alice.clone(), state2.next_message());
@ -172,11 +172,11 @@ pub async fn swap(
let _ = bitcoin_wallet let _ = bitcoin_wallet
.broadcast_signed_transaction(signed_tx_lock) .broadcast_signed_transaction(signed_tx_lock)
.await?; .await?;
db.insert_latest_state(swap_id, &storage::Bob::BtcLocked(state2.clone())) db.insert_latest_state(swap_id, storage::Bob::BtcLocked(state2.clone()).into())
.await?; .await?;
} }
GeneratorState::Yielded(bob::Action::SendBtcRedeemEncsig(tx_redeem_encsig)) => { GeneratorState::Yielded(bob::Action::SendBtcRedeemEncsig(tx_redeem_encsig)) => {
db.insert_latest_state(swap_id, &storage::Bob::XmrLocked(state2.clone())) db.insert_latest_state(swap_id, storage::Bob::XmrLocked(state2.clone()).into())
.await?; .await?;
let mut guard = network.as_ref().lock().await; let mut guard = network.as_ref().lock().await;
@ -196,7 +196,7 @@ pub async fn swap(
spend_key, spend_key,
view_key, view_key,
}) => { }) => {
db.insert_latest_state(swap_id, &storage::Bob::BtcRedeemed(state2.clone())) db.insert_latest_state(swap_id, storage::Bob::BtcRedeemed(state2.clone()).into())
.await?; .await?;
monero_wallet monero_wallet
@ -204,7 +204,7 @@ pub async fn swap(
.await?; .await?;
} }
GeneratorState::Yielded(bob::Action::CancelBtc(tx_cancel)) => { GeneratorState::Yielded(bob::Action::CancelBtc(tx_cancel)) => {
db.insert_latest_state(swap_id, &storage::Bob::BtcRefundable(state2.clone())) db.insert_latest_state(swap_id, storage::Bob::BtcRefundable(state2.clone()).into())
.await?; .await?;
let _ = bitcoin_wallet let _ = bitcoin_wallet
@ -212,7 +212,7 @@ pub async fn swap(
.await?; .await?;
} }
GeneratorState::Yielded(bob::Action::RefundBtc(tx_refund)) => { GeneratorState::Yielded(bob::Action::RefundBtc(tx_refund)) => {
db.insert_latest_state(swap_id, &storage::Bob::BtcRefundable(state2.clone())) db.insert_latest_state(swap_id, storage::Bob::BtcRefundable(state2.clone()).into())
.await?; .await?;
let _ = bitcoin_wallet let _ = bitcoin_wallet
@ -220,7 +220,7 @@ pub async fn swap(
.await?; .await?;
} }
GeneratorState::Complete(()) => { GeneratorState::Complete(()) => {
db.insert_latest_state(swap_id, &storage::Bob::SwapComplete) db.insert_latest_state(swap_id, storage::Bob::SwapComplete.into())
.await?; .await?;
return Ok(()); return Ok(());

View file

@ -156,7 +156,7 @@ async fn create_tor_service(
async fn swap_as_alice( async fn swap_as_alice(
bitcoin_wallet: Arc<swap::bitcoin::Wallet>, bitcoin_wallet: Arc<swap::bitcoin::Wallet>,
monero_wallet: Arc<swap::monero::Wallet>, monero_wallet: Arc<swap::monero::Wallet>,
db: Database<storage::Alice>, db: Database,
addr: Multiaddr, addr: Multiaddr,
transport: SwapTransport, transport: SwapTransport,
behaviour: Alice, behaviour: Alice,
@ -167,7 +167,7 @@ async fn swap_as_alice(
async fn swap_as_bob( async fn swap_as_bob(
bitcoin_wallet: Arc<swap::bitcoin::Wallet>, bitcoin_wallet: Arc<swap::bitcoin::Wallet>,
monero_wallet: Arc<swap::monero::Wallet>, monero_wallet: Arc<swap::monero::Wallet>,
db: Database<storage::Bob>, db: Database,
sats: u64, sats: u64,
alice: Multiaddr, alice: Multiaddr,
transport: SwapTransport, transport: SwapTransport,

View file

@ -5,7 +5,14 @@ use uuid::Uuid;
use xmr_btc::{alice, bob, monero, serde::monero_private_key}; use xmr_btc::{alice, bob, monero, serde::monero_private_key};
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum Swap {
Alice(Alice),
Bob(Bob),
}
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum Alice { pub enum Alice {
Handshaken(alice::State3), Handshaken(alice::State3),
BtcLocked(alice::State3), BtcLocked(alice::State3),
@ -24,7 +31,7 @@ pub enum Alice {
SwapComplete, SwapComplete,
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum Bob { pub enum Bob {
Handshaken(bob::State2), Handshaken(bob::State2),
BtcLocked(bob::State2), BtcLocked(bob::State2),
@ -34,54 +41,54 @@ pub enum Bob {
SwapComplete, SwapComplete,
} }
pub struct Database<T> impl From<Alice> for Swap {
where fn from(from: Alice) -> Self {
T: Serialize + DeserializeOwned, Swap::Alice(from)
{ }
db: sled::Db,
_marker: std::marker::PhantomData<T>,
} }
impl<T> Database<T> impl From<Bob> for Swap {
where fn from(from: Bob) -> Self {
T: Serialize + DeserializeOwned, Swap::Bob(from)
{ }
}
pub struct Database(sled::Db);
impl Database {
pub fn open(path: &Path) -> Result<Self> { pub fn open(path: &Path) -> Result<Self> {
let db = let db =
sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?;
Ok(Database { Ok(Database(db))
db,
_marker: Default::default(),
})
} }
// TODO: Add method to update state // TODO: Add method to update state
pub async fn insert_latest_state(&self, swap_id: Uuid, state: &T) -> Result<()> { pub async fn insert_latest_state(&self, swap_id: Uuid, state: Swap) -> Result<()> {
let key = serialize(&swap_id)?; let key = serialize(&swap_id)?;
let new_value = serialize(&state).context("Could not serialize new state value")?; let new_value = serialize(&state).context("Could not serialize new state value")?;
let old_value = self.db.get(&key)?; let old_value = self.0.get(&key)?;
self.db self.0
.compare_and_swap(key, old_value, Some(new_value)) .compare_and_swap(key, old_value, Some(new_value))
.context("Could not write in the DB")? .context("Could not write in the DB")?
.context("Stored swap somehow changed, aborting saving")?; .context("Stored swap somehow changed, aborting saving")?;
// TODO: see if this can be done through sled config // TODO: see if this can be done through sled config
self.db self.0
.flush_async() .flush_async()
.await .await
.map(|_| ()) .map(|_| ())
.context("Could not flush db") .context("Could not flush db")
} }
pub fn get_latest_state(&self, swap_id: Uuid) -> anyhow::Result<T> { pub fn get_latest_state(&self, swap_id: Uuid) -> anyhow::Result<Swap> {
let key = serialize(&swap_id)?; let key = serialize(&swap_id)?;
let encoded = self let encoded = self
.db .0
.get(&key)? .get(&key)?
.ok_or_else(|| anyhow!("State does not exist {:?}", key))?; .ok_or_else(|| anyhow!("State does not exist {:?}", key))?;
@ -106,87 +113,61 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#![allow(non_snake_case)]
use super::*; use super::*;
use bitcoin::SigHash;
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use xmr_btc::{cross_curve_dleq, monero, serde::monero_private_key};
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct TestState {
A: xmr_btc::bitcoin::PublicKey,
a: xmr_btc::bitcoin::SecretKey,
s_a: cross_curve_dleq::Scalar,
#[serde(with = "monero_private_key")]
s_b: monero::PrivateKey,
S_a_monero: ::monero::PublicKey,
S_a_bitcoin: xmr_btc::bitcoin::PublicKey,
v: xmr_btc::monero::PrivateViewKey,
#[serde(with = "::bitcoin::util::amount::serde::as_sat")]
btc: ::bitcoin::Amount,
xmr: xmr_btc::monero::Amount,
refund_timelock: u32,
refund_address: ::bitcoin::Address,
transaction: ::bitcoin::Transaction,
tx_punish_sig: xmr_btc::bitcoin::Signature,
}
#[tokio::test] #[tokio::test]
async fn recover_state_from_db() { async fn can_write_and_read_to_multiple_keys() {
let db_dir = tempfile::tempdir().unwrap(); let db_dir = tempfile::tempdir().unwrap();
let db = Database::open(db_dir.path()).unwrap(); let db = Database::open(db_dir.path()).unwrap();
let a = xmr_btc::bitcoin::SecretKey::new_random(&mut OsRng); let state_1 = Swap::Alice(Alice::SwapComplete);
let s_a = cross_curve_dleq::Scalar::random(&mut OsRng); let swap_id_1 = Uuid::new_v4();
let s_b = monero::PrivateKey::from_scalar(monero::Scalar::random(&mut OsRng)); db.insert_latest_state(swap_id_1, state_1.clone())
let v_a = xmr_btc::monero::PrivateViewKey::new_random(&mut OsRng); .await
let S_a_monero = monero::PublicKey::from_private_key(&monero::PrivateKey { .expect("Failed to save second state");
scalar: s_a.into_ed25519(),
});
let S_a_bitcoin = s_a.into_secp256k1().into();
let tx_punish_sig = a.sign(SigHash::default());
let state = TestState { let state_2 = Swap::Bob(Bob::SwapComplete);
A: a.public(), let swap_id_2 = Uuid::new_v4();
a, db.insert_latest_state(swap_id_2, state_2.clone())
s_b, .await
s_a, .expect("Failed to save first state");
S_a_monero,
S_a_bitcoin, let recovered_1 = db
v: v_a, .get_latest_state(swap_id_1)
btc: ::bitcoin::Amount::from_sat(100), .expect("Failed to recover first state");
xmr: xmr_btc::monero::Amount::from_piconero(1000),
refund_timelock: 0, let recovered_2 = db
refund_address: ::bitcoin::Address::from_str("1L5wSMgerhHg8GZGcsNmAx5EXMRXSKR3He") .get_latest_state(swap_id_2)
.unwrap(), .expect("Failed to recover second state");
transaction: ::bitcoin::Transaction {
version: 0, assert_eq!(recovered_1, state_1);
lock_time: 0, assert_eq!(recovered_2, state_2);
input: vec![::bitcoin::TxIn::default()], }
output: vec![::bitcoin::TxOut::default()],
}, #[tokio::test]
tx_punish_sig, async fn can_write_twice_to_one_key() {
}; let db_dir = tempfile::tempdir().unwrap();
let db = Database::open(db_dir.path()).unwrap();
let state = Swap::Alice(Alice::SwapComplete);
let swap_id = Uuid::new_v4(); let swap_id = Uuid::new_v4();
db.insert_latest_state(swap_id, &state) db.insert_latest_state(swap_id, state.clone())
.await .await
.expect("Failed to save state the first time"); .expect("Failed to save state the first time");
let recovered: TestState = db let recovered = db
.get_latest_state(swap_id) .get_latest_state(swap_id)
.expect("Failed to recover state the first time"); .expect("Failed to recover state the first time");
// We insert and recover twice to ensure database implementation allows the // We insert and recover twice to ensure database implementation allows the
// caller to write to an existing key // caller to write to an existing key
db.insert_latest_state(swap_id, &recovered) db.insert_latest_state(swap_id, recovered)
.await .await
.expect("Failed to save state the second time"); .expect("Failed to save state the second time");
let recovered: TestState = db let recovered = db
.get_latest_state(swap_id) .get_latest_state(swap_id)
.expect("Failed to recover state the second time"); .expect("Failed to recover state the second time");
assert_eq!(state, recovered); assert_eq!(recovered, state);
} }
} }

View file

@ -684,7 +684,7 @@ impl State2 {
} }
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct State3 { pub struct State3 {
pub a: bitcoin::SecretKey, pub a: bitcoin::SecretKey,
pub B: bitcoin::PublicKey, pub B: bitcoin::PublicKey,

View file

@ -495,7 +495,7 @@ impl State1 {
} }
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct State2 { pub struct State2 {
pub A: bitcoin::PublicKey, pub A: bitcoin::PublicKey,
pub b: bitcoin::SecretKey, pub b: bitcoin::SecretKey,