Improve database type safety

The database is now bound to a type eg. alice::State or bob::State.
The caller cannot expect to retrieve a type that is different to
the type that was stored.
This commit is contained in:
rishflab 2020-10-22 13:34:01 +11:00
parent 8eda051087
commit e3b68a3864
4 changed files with 28 additions and 24 deletions

View File

@ -82,7 +82,7 @@ pub async fn next_state<
} }
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
#[derive(Debug)] #[derive(Debug, Deserialize, Serialize)]
pub enum State { pub enum State {
State0(State0), State0(State0),
State1(State1), State1(State1),

View File

@ -82,7 +82,7 @@ pub async fn next_state<
} }
} }
#[derive(Debug)] #[derive(Debug, Deserialize, Serialize)]
pub enum State { pub enum State {
State0(State0), State0(State0),
State1(State1), State1(State1),

View File

@ -18,6 +18,7 @@ mod tests {
use monero_harness::Monero; use monero_harness::Monero;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use crate::harness::storage::Database;
use std::{convert::TryInto, path::Path}; use std::{convert::TryInto, path::Path};
use testcontainers::clients::Cli; use testcontainers::clients::Cli;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
@ -251,8 +252,10 @@ mod tests {
let cli = Cli::default(); let cli = Cli::default();
let (monero, _container) = Monero::new(&cli); let (monero, _container) = Monero::new(&cli);
let bitcoind = init_bitcoind(&cli).await; let bitcoind = init_bitcoind(&cli).await;
let alice_db = harness::storage::Database::open(Path::new(ALICE_TEST_DB_FOLDER)).unwrap(); let alice_db: Database<alice::State> =
let bob_db = harness::storage::Database::open(Path::new(BOB_TEST_DB_FOLDER)).unwrap(); harness::storage::Database::open(Path::new(ALICE_TEST_DB_FOLDER)).unwrap();
let bob_db: Database<bob::State> =
harness::storage::Database::open(Path::new(BOB_TEST_DB_FOLDER)).unwrap();
let ( let (
alice_state0, alice_state0,
@ -281,29 +284,26 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let alice_state5: alice::State5 = alice_state.try_into().unwrap();
let bob_state3: bob::State3 = bob_state.try_into().unwrap();
// save state to db // save state to db
alice_db.insert_latest_state(&alice_state5).await.unwrap(); alice_db.insert_latest_state(&alice_state).await.unwrap();
bob_db.insert_latest_state(&bob_state3).await.unwrap(); bob_db.insert_latest_state(&bob_state).await.unwrap();
}; };
let (alice_state6, bob_state5) = { let (alice_state6, bob_state5) = {
// recover state from db // recover state from db
let alice_state5: alice::State5 = alice_db.get_latest_state().unwrap(); let alice_state = alice_db.get_latest_state().unwrap();
let bob_state3: bob::State3 = bob_db.get_latest_state().unwrap(); let bob_state = bob_db.get_latest_state().unwrap();
let (alice_state, bob_state) = future::try_join( let (alice_state, bob_state) = future::try_join(
run_alice_until( run_alice_until(
&mut alice_node, &mut alice_node,
alice_state5.into(), alice_state,
harness::alice::is_state6, harness::alice::is_state6,
&mut OsRng, &mut OsRng,
), ),
run_bob_until( run_bob_until(
&mut bob_node, &mut bob_node,
bob_state3.into(), bob_state,
harness::bob::is_state5, harness::bob::is_state5,
&mut OsRng, &mut OsRng,
), ),

View File

@ -2,24 +2,31 @@ use anyhow::{anyhow, Context, Result};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::path::Path; use std::path::Path;
pub struct Database { pub struct Database<T>
where
T: Serialize + DeserializeOwned,
{
db: sled::Db, db: sled::Db,
_marker: std::marker::PhantomData<T>,
} }
impl Database { impl<T> Database<T>
where
T: Serialize + DeserializeOwned,
{
const LAST_STATE_KEY: &'static str = "latest_state"; const LAST_STATE_KEY: &'static str = "latest_state";
pub fn open(path: &Path) -> Result<Self> { pub fn open(path: &Path) -> Result<Self> {
let db = let db =
sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?;
Ok(Database { db }) Ok(Database {
db,
_marker: Default::default(),
})
} }
pub async fn insert_latest_state<T>(&self, state: &T) -> Result<()> pub async fn insert_latest_state(&self, state: &T) -> Result<()> {
where
T: Serialize + DeserializeOwned,
{
let key = serialize(&Self::LAST_STATE_KEY)?; let key = serialize(&Self::LAST_STATE_KEY)?;
let new_value = serialize(&state).context("Could not serialize new state value")?; let new_value = serialize(&state).context("Could not serialize new state value")?;
@ -37,10 +44,7 @@ impl Database {
.context("Could not flush db") .context("Could not flush db")
} }
pub fn get_latest_state<T>(&self) -> anyhow::Result<T> pub fn get_latest_state(&self) -> anyhow::Result<T> {
where
T: DeserializeOwned,
{
let key = serialize(&Self::LAST_STATE_KEY)?; let key = serialize(&Self::LAST_STATE_KEY)?;
let encoded = self let encoded = self