diff --git a/swap/src/api.rs b/swap/src/api.rs index 0cf889dd..54b229e9 100644 --- a/swap/src/api.rs +++ b/swap/src/api.rs @@ -8,11 +8,14 @@ use crate::protocol::Database; use crate::seed::Seed; use crate::{bitcoin, cli, monero}; use anyhow::{bail, Context as AnyContext, Error, Result}; +use futures::future::try_join_all; use std::fmt; +use std::future::Future; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::{Arc, Once}; -use tokio::sync::{broadcast, broadcast::Sender, RwLock}; +use tokio::sync::{broadcast, broadcast::Sender, Mutex, RwLock}; +use tokio::task::JoinHandle; use url::Url; static START: Once = Once::new(); @@ -32,6 +35,38 @@ pub struct Config { use uuid::Uuid; +pub struct PendingTaskList(Mutex>>); + +impl PendingTaskList { + pub fn new() -> Self { + Self(Mutex::new(Vec::new())) + } + + pub async fn spawn(&self, future: F) + where + F: Future + Send + 'static, + T: Send + 'static, + { + let handle = tokio::spawn(async move { + let _ = future.await; + }); + + self.0.lock().await.push(handle); + } + + pub async fn wait_for_tasks(&self) -> Result<()> { + let tasks = { + // Scope for the lock, to avoid holding it for the entire duration of the async block + let mut guard = self.0.lock().await; + guard.drain(..).collect::>() + }; + + try_join_all(tasks).await?; + + Ok(()) + } +} + pub struct SwapLock { current_swap: RwLock>, suspension_trigger: Sender<()>, @@ -135,6 +170,7 @@ pub struct Context { monero_rpc_process: Option, pub swap_lock: Arc, pub config: Config, + pub tasks: Arc, } #[allow(clippy::too_many_arguments)] @@ -208,6 +244,7 @@ impl Context { data_dir, }, swap_lock: Arc::new(SwapLock::new()), + tasks: Arc::new(PendingTaskList::new()), }; Ok(context) @@ -231,6 +268,7 @@ impl Context { .expect("Could not open sqlite database"), monero_rpc_process: None, swap_lock: Arc::new(SwapLock::new()), + tasks: Arc::new(PendingTaskList::new()), } } } diff --git a/swap/src/api/request.rs b/swap/src/api/request.rs index f6eebf2d..8ecc9f25 100644 --- a/swap/src/api/request.rs +++ b/swap/src/api/request.rs @@ -403,7 +403,7 @@ impl Request { } }; - tokio::spawn(async move { + context.tasks.clone().spawn(async move { tokio::select! { biased; _ = context.swap_lock.listen_for_swap_force_suspension() => { @@ -481,7 +481,7 @@ impl Request { .await .expect("Could not release swap lock"); Ok::<_, anyhow::Error>(()) - }.in_current_span()); + }.in_current_span()).await; Ok(json!({ "swapId": swap_id.to_string(), @@ -555,7 +555,7 @@ impl Request { ) .await?; - tokio::spawn( + context.tasks.clone().spawn( async move { let handle = tokio::spawn(event_loop.run().in_current_span()); tokio::select! { @@ -596,7 +596,7 @@ impl Request { Ok::<(), anyhow::Error>(()) } .in_current_span(), - ); + ).await; Ok(json!({ "result": "ok", })) diff --git a/swap/src/bin/swap.rs b/swap/src/bin/swap.rs index a3c2a30f..931bcabf 100644 --- a/swap/src/bin/swap.rs +++ b/swap/src/bin/swap.rs @@ -31,7 +31,8 @@ async fn main() -> Result<()> { if let Err(e) = check_latest_version(env!("CARGO_PKG_VERSION")).await { eprintln!("{}", e); } - let _result = request.call(Arc::clone(&context)).await?; + request.call(context.clone()).await?; + context.tasks.wait_for_tasks().await?; Ok(()) }