Make the whole protocol timeout after 60s

This commit is contained in:
Thomas Eizinger 2021-06-24 19:47:09 +10:00
parent 90bb4a2f6e
commit 76035dc488
No known key found for this signature in database
GPG Key ID: 651AC83A6C6C8B96

View File

@ -19,6 +19,7 @@ use std::collections::VecDeque;
use std::fmt::Debug; use std::fmt::Debug;
use std::future; use std::future;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::Duration;
use uuid::Uuid; use uuid::Uuid;
use void::Void; use void::Void;
@ -34,8 +35,8 @@ pub enum OutEvent {
}, },
Error { Error {
peer_id: PeerId, peer_id: PeerId,
error: Error error: Error,
} },
} }
#[derive(Debug)] #[derive(Debug)]
@ -174,21 +175,20 @@ where
fn inject_event(&mut self, peer_id: PeerId, connection: ConnectionId, event: HandlerOutEvent) { fn inject_event(&mut self, peer_id: PeerId, connection: ConnectionId, event: HandlerOutEvent) {
match event { match event {
HandlerOutEvent::Initiated(send_wallet_snapshot) => { HandlerOutEvent::Initiated(send_wallet_snapshot) => {
self.events.push_back(OutEvent::Initiated { send_wallet_snapshot }) self.events.push_back(OutEvent::Initiated {
send_wallet_snapshot,
})
} }
HandlerOutEvent::Completed(Ok((swap_id, state3))) => { HandlerOutEvent::Completed(Ok((swap_id, state3))) => {
self.events.push_back(OutEvent::Completed { self.events.push_back(OutEvent::Completed {
peer_id, peer_id,
swap_id, swap_id,
state3 state3,
}) })
}, }
HandlerOutEvent::Completed(Err(error)) => { HandlerOutEvent::Completed(Err(error)) => {
self.events.push_back(OutEvent::Error { self.events.push_back(OutEvent::Error { peer_id, error })
peer_id, }
error
})
},
} }
} }
@ -267,6 +267,8 @@ pub struct Handler<LR> {
latest_rate: LR, latest_rate: LR,
resume_only: bool, resume_only: bool,
timeout: Duration,
} }
impl<LR> Handler<LR> { impl<LR> Handler<LR> {
@ -285,6 +287,7 @@ impl<LR> Handler<LR> {
env_config, env_config,
latest_rate, latest_rate,
resume_only, resume_only,
timeout: Duration::from_secs(60),
} }
} }
} }
@ -329,11 +332,14 @@ where
let latest_rate = self.latest_rate.latest_rate(); let latest_rate = self.latest_rate.latest_rate();
let env_config = self.env_config; let env_config = self.env_config;
// TODO: Put a timeout on the whole future let protocol = tokio::time::timeout(self.timeout, async move {
self.inbound_stream = OptionFuture::from(Some( let request = read_cbor_message::<SpotPriceRequest>(&mut substream)
async move { .await
let request = read_cbor_message::<SpotPriceRequest>(&mut substream).await.map_err(|e| Error::Io(e))?; .map_err(|e| Error::Io(e))?;
let wallet_snapshot = sender.send_receive(request.btc).await.map_err(|e| Error::WalletSnapshotFailed(anyhow!(e)))?; let wallet_snapshot = sender
.send_receive(request.btc)
.await
.map_err(|e| Error::WalletSnapshotFailed(anyhow!(e)))?;
// wrap all of these into another future so we can `return` from all the // wrap all of these into another future so we can `return` from all the
// different blocks // different blocks
@ -370,8 +376,7 @@ where
}); });
} }
let rate = let rate = latest_rate.map_err(|e| Error::LatestRateFetchFailed(Box::new(e)))?;
latest_rate.map_err(|e| Error::LatestRateFetchFailed(Box::new(e)))?;
let xmr = rate let xmr = rate
.sell_quote(btc) .sell_quote(btc)
.map_err(|e| Error::SellQuoteCalculationFailed(e))?; .map_err(|e| Error::SellQuoteCalculationFailed(e))?;
@ -388,7 +393,9 @@ where
let xmr = match validate.await { let xmr = match validate.await {
Ok(xmr) => { Ok(xmr) => {
write_cbor_message(&mut substream, SpotPriceResponse::Xmr(xmr)).await.map_err(|e| Error::Io(e))?; write_cbor_message(&mut substream, SpotPriceResponse::Xmr(xmr))
.await
.map_err(|e| Error::Io(e))?;
xmr xmr
} }
@ -420,7 +427,9 @@ where
.map_err(|e| Error::Io(e))?; .map_err(|e| Error::Io(e))?;
let (swap_id, state1) = state0.receive(message0).map_err(|e| Error::Io(e))?; let (swap_id, state1) = state0.receive(message0).map_err(|e| Error::Io(e))?;
write_cbor_message(&mut substream, state1.next_message()).await.map_err(|e| Error::Io(e))?; write_cbor_message(&mut substream, state1.next_message())
.await
.map_err(|e| Error::Io(e))?;
let message2 = read_cbor_message::<Message2>(&mut substream) let message2 = read_cbor_message::<Message2>(&mut substream)
.await .await
@ -431,7 +440,9 @@ where
.context("Failed to receive Message2") .context("Failed to receive Message2")
.map_err(|e| Error::Io(e))?; .map_err(|e| Error::Io(e))?;
write_cbor_message(&mut substream, state2.next_message()).await.map_err(|e| Error::Io(e))?; write_cbor_message(&mut substream, state2.next_message())
.await
.map_err(|e| Error::Io(e))?;
let message4 = read_cbor_message::<Message4>(&mut substream) let message4 = read_cbor_message::<Message4>(&mut substream)
.await .await
@ -443,6 +454,14 @@ where
.map_err(|e| Error::Io(e))?; .map_err(|e| Error::Io(e))?;
Ok((swap_id, state3)) Ok((swap_id, state3))
});
let max_seconds = self.timeout.as_secs();
self.inbound_stream = OptionFuture::from(Some(
async move {
protocol.await.map_err(|_| Error::Timeout {
seconds: max_seconds,
})?
} }
.boxed(), .boxed(),
)); ));
@ -540,6 +559,8 @@ mod protocol {
>; >;
} }
// TODO: Differentiate between errors that we send back and shit that happens on
// our side (IO, timeout)
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum Error { pub enum Error {
#[error("ASB is running in resume-only mode")] #[error("ASB is running in resume-only mode")]
@ -571,7 +592,9 @@ pub enum Error {
#[error("Io Error: {0}")] #[error("Io Error: {0}")]
Io(anyhow::Error), Io(anyhow::Error),
#[error("Failed to request wallet snapshot: {0}")] #[error("Failed to request wallet snapshot: {0}")]
WalletSnapshotFailed(anyhow::Error) WalletSnapshotFailed(anyhow::Error),
#[error("Failed to complete execution setup within {seconds}s")]
Timeout { seconds: u64 },
} }
impl Error { impl Error {
@ -596,9 +619,8 @@ impl Error {
Error::LatestRateFetchFailed(_) Error::LatestRateFetchFailed(_)
| Error::SellQuoteCalculationFailed(_) | Error::SellQuoteCalculationFailed(_)
| Error::WalletSnapshotFailed(_) | Error::WalletSnapshotFailed(_)
| Error::Io(_) => { | Error::Timeout { .. }
SpotPriceError::Other | Error::Io(_) => SpotPriceError::Other,
}
} }
} }
} }