use uuid::Uuid; use sqlx::sqlite::Sqlite; use anyhow::Result; use chrono::Utc; use crate::protocol::{State, Database}; use std::path::PathBuf; use libp2p::{PeerId, Multiaddr}; use crate::monero::Address; use sqlx::{SqlitePool, Pool}; use async_trait::async_trait; use crate::database::Swap; use std::str::FromStr; use anyhow::Context; use std::collections::HashMap; pub struct SqliteDatabase { pool: Pool } impl SqliteDatabase { pub async fn open(path: PathBuf) -> Result where Self: std::marker::Sized { let path_str = format!("sqlite:{}", path.as_path().display()); let pool = SqlitePool::connect(&path_str) .await?; Ok(Self { pool }) } pub async fn run_migrations(&mut self) -> anyhow::Result<()> { sqlx::migrate!("./migrations").run(&self.pool).await?; Ok(()) } } #[async_trait] impl Database for SqliteDatabase { async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> { let mut conn = self.pool.acquire().await?; let swap_id = swap_id.to_string(); let peer_id = peer_id.to_string(); sqlx::query!( r#" insert into peers ( swap_id, peer_id ) values (?, ?); "#, swap_id, peer_id ).execute(&mut conn).await?; Ok(()) } async fn get_peer_id(&self, swap_id: Uuid) -> Result { let mut conn = self.pool.acquire().await?; let swap_id = swap_id.to_string(); let row = sqlx::query!(r#" SELECT peer_id FROM peers WHERE swap_id = ? "#, swap_id ).fetch_one(&mut conn). await?; let peer_id = PeerId::from_str(&row.peer_id)?; Ok(peer_id) } async fn insert_monero_address(&self, swap_id: Uuid, address: Address) -> Result<()> { let mut conn = self.pool.acquire().await?; let swap_id = swap_id.to_string(); let address = address.to_string(); sqlx::query!( r#" insert into monero_addresses ( swap_id, address ) values (?, ?); "#, swap_id, address ).execute(&mut conn).await?; Ok(()) } async fn get_monero_address(&self, swap_id: Uuid) -> Result
{ let mut conn = self.pool.acquire().await?; let swap_id = swap_id.to_string(); let row = sqlx::query!(r#" SELECT address FROM monero_addresses WHERE swap_id = ? "#, swap_id ).fetch_one(&mut conn). await?; let address = row.address.parse()?; Ok(address) } async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> { let mut conn = self.pool.acquire().await?; let peer_id = peer_id.to_string(); let address = address.to_string(); sqlx::query!( r#" insert into peer_addresses ( peer_id, address ) values (?, ?); "#, peer_id, address ).execute(&mut conn).await?; Ok(()) } async fn get_addresses(&self, peer_id: PeerId) -> Result> { let mut conn = self.pool.acquire().await?; let peer_id = peer_id.to_string(); let rows = sqlx::query!(r#" SELECT address FROM peer_addresses WHERE peer_id = ? "#, peer_id, ).fetch_all(&mut conn). await?; let addresses = rows.iter() .map(|row| { let multiaddr = Multiaddr::from_str(&row.address)?; Ok(multiaddr) }) .collect::>>(); addresses } async fn insert_latest_state(&self, swap_id: Uuid, state: State) -> Result<()> { let mut conn = self.pool.acquire().await?; let entered_at = Utc::now(); let swap_id = swap_id.to_string(); let swap = serde_json::to_string(&Swap::from(state))?; let entered_at = entered_at.to_string(); sqlx::query!( r#" insert into swap_states ( swap_id, entered_at, state ) values (?, ?, ?); "#, swap_id, entered_at, swap ).execute(&mut conn).await?; Ok(()) } async fn get_state(&self, swap_id: Uuid) -> Result { let mut conn = self.pool.acquire().await?; let swap_id = swap_id.to_string(); let row = sqlx::query!( r#" SELECT state FROM swap_states WHERE swap_id = ? ORDER BY id desc LIMIT 1; "#, swap_id ).fetch_all(&mut conn).await?; let row = row.first().context(format!("No state in database for swap: {}", swap_id))?; let swap: Swap = serde_json::from_str(&row.state)?; Ok(swap.into()) } async fn all(&self) -> Result> { let mut conn = self.pool.acquire().await?; let rows = sqlx::query!( r#" SELECT swap_id, state FROM ( SELECT max(id), swap_id, state FROM swap_states GROUP BY swap_id ) "# ).fetch_all(&mut conn).await?; let result = rows.iter().map(|row|{ let swap_id = Uuid::from_str(&row.swap_id)?; let state = match serde_json::from_str::(&row.state) { Ok(a) => Ok(State::from(a)), Err(e) => Err(e) }?; Ok((swap_id, state)) }).collect::>>(); result } async fn unfinished(&self, unfinished: fn(State) -> bool) -> Result> { Ok(self.all() .await? .into_iter() .filter(|(_swap_id, state)| unfinished(state.clone())) .collect()) } } #[cfg(test)] mod tests { use super::*; use std::fs::File; use tempfile::tempdir; use crate::protocol::alice::AliceState; use crate::protocol::bob::BobState; #[tokio::test] async fn test_insert_and_load_state() { let db = setup_test_db().await.unwrap(); let state_1 = State::Alice(AliceState::BtcRedeemed); let swap_id_1 = Uuid::new_v4(); db.insert_latest_state( swap_id_1, state_1).await.unwrap(); let state_1 = State::Alice(AliceState::BtcRedeemed); db.insert_latest_state( swap_id_1, state_1.clone()).await.unwrap(); let state_1_loaded = db.get_state( swap_id_1).await.unwrap(); assert_eq!(state_1, state_1_loaded); } #[tokio::test] async fn test_retrieve_all_latest_states() { let db = setup_test_db().await.unwrap(); let state_1 = State::Alice(AliceState::BtcRedeemed); let state_2 = State::Alice(AliceState::BtcPunished); let state_3 = State::Alice(AliceState::SafelyAborted); let state_4 = State::Bob(BobState::SafelyAborted); let swap_id_1 = Uuid::new_v4(); let swap_id_2 = Uuid::new_v4(); db.insert_latest_state(swap_id_1, state_1.clone()).await.unwrap(); db.insert_latest_state(swap_id_1, state_2.clone()).await.unwrap(); db.insert_latest_state(swap_id_1, state_3.clone()).await.unwrap(); db.insert_latest_state(swap_id_2, state_4.clone()).await.unwrap(); let latest_loaded = db.all().await.unwrap(); assert_eq!(latest_loaded.len(), 2); assert!(latest_loaded.contains(&(swap_id_1, state_3))); assert!(latest_loaded.contains(&(swap_id_2, state_4))); assert!(!latest_loaded.contains(&(swap_id_1, state_1))); assert!(!latest_loaded.contains(&(swap_id_1, state_2))); } #[tokio::test] async fn test_insert_load_monero_address() -> Result<()> { let db = setup_test_db().await?; let swap_id = Uuid::new_v4(); let monero_address = "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?; db.insert_monero_address(swap_id, monero_address).await?; let loaded_monero_address = db.get_monero_address(swap_id).await?; assert_eq!(monero_address, loaded_monero_address); Ok(()) } #[tokio::test] async fn test_insert_and_load_multiaddr() -> Result<()> { let db = setup_test_db().await?; let peer_id = PeerId::random(); let multiaddr1 = "/ip4/127.0.0.1".parse::()?; let multiaddr2 = "/ip4/127.0.0.2".parse::()?; db.insert_address(peer_id, multiaddr1.clone()).await?; db.insert_address(peer_id, multiaddr2.clone()).await?; let loaded_multiaddr = db.get_addresses(peer_id).await?; assert!(loaded_multiaddr.contains(&multiaddr1)); assert!(loaded_multiaddr.contains(&multiaddr2)); assert_eq!(loaded_multiaddr.len(), 2); Ok(()) } #[tokio::test] async fn test_insert_and_load_peer_id() -> Result<()> { let db = setup_test_db().await?; let peer_id = PeerId::random(); let multiaddr1 = "/ip4/127.0.0.1".parse::()?; let multiaddr2 = "/ip4/127.0.0.2".parse::()?; db.insert_address(peer_id, multiaddr1.clone()).await?; db.insert_address(peer_id, multiaddr2.clone()).await?; let loaded_multiaddr = db.get_addresses(peer_id).await?; assert!(loaded_multiaddr.contains(&multiaddr1)); assert!(loaded_multiaddr.contains(&multiaddr2)); assert_eq!(loaded_multiaddr.len(), 2); Ok(()) } async fn setup_test_db() -> Result { let temp_db = tempdir().unwrap().into_path().join("tempdb"); // file has to exist in order to connect with sqlite File::create(temp_db.clone()).unwrap(); let mut db = SqliteDatabase::open(temp_db).await?; db.run_migrations().await.unwrap(); Ok(db) } }