Mirror of https://gitlab.com/veilid/veilid.git

refactor for cooperative cancellation

parent bcc1bfc1a3, commit 180628beef

Cargo.lock (generated): 13 lines changed
@@ -4410,6 +4410,18 @@ version = "1.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"

+[[package]]
+name = "stop-token"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af91f480ee899ab2d9f8435bfdfc14d08a5754bd9d3fef1f1a1c23336aad6c8b"
+dependencies = [
+ "async-channel",
+ "cfg-if 1.0.0",
+ "futures-core",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "strsim"
 version = "0.10.0"

@@ -5209,6 +5221,7 @@ dependencies = [
  "simplelog",
  "socket2",
  "static_assertions",
+ "stop-token",
  "thiserror",
  "tracing",
  "tracing-error",

@@ -41,6 +41,7 @@ flume = { version = "^0", features = ["async"] }
 enumset = { version= "^1", features = ["serde"] }
 backtrace = { version = "^0", optional = true }
 owo-colors = "^3"
+stop-token = "^0"

 ed25519-dalek = { version = "^1", default_features = false, features = ["alloc", "u64_backend"] }
 x25519-dalek = { package = "x25519-dalek-ng", version = "^1", default_features = false, features = ["u64_backend"] }
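
For context on how this new dependency is used below: a StopSource owns the cancellation, each task holds a cloned StopToken, and wrapping any blocking await in timeout_at() makes it resolve with Err as soon as the source is dropped. A minimal sketch under those assumptions, using async-std and flume as this crate does (the worker and channel are illustrative, not repository code):

    use stop_token::future::FutureExt; // provides .timeout_at()
    use stop_token::{StopSource, StopToken};

    async fn worker(stop_token: StopToken, receiver: flume::Receiver<u32>) {
        // Err(_) from timeout_at() means "stop requested"; Ok(Err(_)) means
        // the channel closed. Either way the loop ends and the task returns.
        while let Ok(Ok(msg)) = receiver.recv_async().timeout_at(stop_token.clone()).await {
            println!("got {}", msg);
        }
    }

    async fn demo() {
        let (tx, rx) = flume::unbounded();
        let stop_source = StopSource::new();
        let handle = async_std::task::spawn(worker(stop_source.token(), rx));
        tx.send_async(1).await.ok();
        drop(stop_source); // cancels every outstanding StopToken...
        handle.await; // ...so joining the worker completes promptly
    }
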
@@ -109,7 +109,7 @@ pub struct AttachmentManagerInner {
     maintain_peers: bool,
     attach_timestamp: Option<u64>,
     update_callback: Option<UpdateCallback>,
-    attachment_maintainer_jh: Option<JoinHandle<()>>,
+    attachment_maintainer_jh: Option<MustJoinHandle<()>>,
 }

 #[derive(Clone)]

@@ -306,8 +306,9 @@ impl AttachmentManager {
         // Create long-running connection maintenance routine
         let this = self.clone();
         self.inner.lock().maintain_peers = true;
-        self.inner.lock().attachment_maintainer_jh =
-            Some(intf::spawn(this.attachment_maintainer()));
+        self.inner.lock().attachment_maintainer_jh = Some(MustJoinHandle::new(intf::spawn(
+            this.attachment_maintainer(),
+        )));
     }

     #[instrument(level = "trace", skip(self))]
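
The MustJoinHandle type that replaces JoinHandle here is defined elsewhere in this commit and is not part of this extract. A plausible minimal shape, assuming async-std-style join handles (a sketch of the idea, not the repository's definition): it forwards poll() to the wrapped handle and treats dropping an un-joined handle as a bug, so a task can no longer be silently detached during shutdown.

    use core::future::Future;
    use core::pin::Pin;
    use core::task::{Context, Poll};

    pub struct MustJoinHandle<T> {
        join_handle: Option<async_std::task::JoinHandle<T>>,
    }

    impl<T> MustJoinHandle<T> {
        pub fn new(join_handle: async_std::task::JoinHandle<T>) -> Self {
            Self { join_handle: Some(join_handle) }
        }
    }

    impl<T> Drop for MustJoinHandle<T> {
        fn drop(&mut self) {
            // In this sketch, dropping before joining is a programming error.
            assert!(self.join_handle.is_none(), "MustJoinHandle dropped without being joined");
        }
    }

    impl<T> Future for MustJoinHandle<T> {
        type Output = T;
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
            let jh = self.join_handle.as_mut().expect("polled after completion");
            match Pin::new(jh).poll(cx) {
                Poll::Ready(t) => {
                    // Mark as joined so drop() no longer complains.
                    self.join_handle = None;
                    Poll::Ready(t)
                }
                Poll::Pending => Poll::Pending,
            }
        }
    }
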
@@ -9,11 +9,12 @@ use network_connection::*;
 #[derive(Debug)]
 struct ConnectionManagerInner {
     connection_table: ConnectionTable,
+    stop_source: Option<StopSource>,
 }

 struct ConnectionManagerArc {
     network_manager: NetworkManager,
-    inner: AsyncMutex<ConnectionManagerInner>,
+    inner: AsyncMutex<Option<ConnectionManagerInner>>,
 }
 impl core::fmt::Debug for ConnectionManagerArc {
     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {

@@ -31,14 +32,14 @@ pub struct ConnectionManager {
 impl ConnectionManager {
     fn new_inner(config: VeilidConfig) -> ConnectionManagerInner {
         ConnectionManagerInner {
+            stop_source: Some(StopSource::new()),
             connection_table: ConnectionTable::new(config),
         }
     }
     fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc {
-        let config = network_manager.config();
         ConnectionManagerArc {
             network_manager,
-            inner: AsyncMutex::new(Self::new_inner(config)),
+            inner: AsyncMutex::new(None),
         }
     }
     pub fn new(network_manager: NetworkManager) -> Self {

@@ -53,12 +54,32 @@ impl ConnectionManager {

     pub async fn startup(&self) {
         trace!("startup connection manager");
-        //let mut inner = self.arc.inner.lock().await;
+        let mut inner = self.arc.inner.lock().await;
+        if inner.is_some() {
+            panic!("shouldn't start connection manager twice without shutting it down first");
+        }
+
+        *inner = Some(Self::new_inner(self.network_manager().config()));
     }

     pub async fn shutdown(&self) {
-        // Drops connection table, which drops all connections in it
-        *self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config());
+        // Remove the inner from the lock
+        let mut inner = {
+            let mut inner_lock = self.arc.inner.lock().await;
+            let inner = match inner_lock.take() {
+                Some(v) => v,
+                None => {
+                    panic!("not started");
+                }
+            };
+            inner
+        };
+
+        // Stop all the connections
+        drop(inner.stop_source.take());
+
+        // Wait for the connections to complete
+        inner.connection_table.join().await;
     }

     // Returns a network connection if one already is established
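
The design choice behind the startup()/shutdown() rewrite above: keeping the inner state in an Option lets shutdown() take() it out under the lock, then drop the StopSource and wait for connections without holding the mutex, while any late caller that finds None knows the manager is stopping. Reduced to its essentials (illustrative names, not repository code):

    use async_std::sync::Mutex;
    use stop_token::StopSource;

    struct Inner {
        stop_source: Option<StopSource>,
        // ... connection table, worker handles, etc.
    }

    struct Manager {
        inner: Mutex<Option<Inner>>,
    }

    impl Manager {
        async fn startup(&self) {
            let mut inner = self.inner.lock().await;
            assert!(inner.is_none(), "started twice without shutdown");
            *inner = Some(Inner { stop_source: Some(StopSource::new()) });
        }

        async fn shutdown(&self) {
            // Take the state out so concurrent callers observe "stopping" (None).
            let mut inner = self.inner.lock().await.take().expect("not started");
            // Dropping the StopSource asks every worker holding a token to finish;
            // the real code then joins the connection table outside the lock.
            drop(inner.stop_source.take());
        }
    }
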
@@ -67,6 +88,12 @@ impl ConnectionManager {
         descriptor: ConnectionDescriptor,
     ) -> Option<ConnectionHandle> {
         let mut inner = self.arc.inner.lock().await;
+        let inner = match &mut *inner {
+            Some(v) => v,
+            None => {
+                panic!("not started");
+            }
+        };
         inner.connection_table.get_connection(descriptor)
     }

@@ -81,24 +108,18 @@ impl ConnectionManager {
         log_net!("on_new_protocol_network_connection: {:?}", conn);

         // Wrap with NetworkConnection object to start the connection processing loop
-        let conn = NetworkConnection::from_protocol(self.clone(), conn);
+        let stop_token = match &inner.stop_source {
+            Some(ss) => ss.token(),
+            None => return Err("not creating connection because we are stopping".to_owned()),
+        };
+
+        let conn = NetworkConnection::from_protocol(self.clone(), stop_token, conn);
         let handle = conn.get_handle();
         // Add to the connection table
         inner.connection_table.add_connection(conn)?;
         Ok(handle)
     }

-    // Called by low-level network when any connection-oriented protocol connection appears
-    // either from incoming connections.
-    pub(super) async fn on_accepted_protocol_network_connection(
-        &self,
-        conn: ProtocolNetworkConnection,
-    ) -> Result<(), String> {
-        let mut inner = self.arc.inner.lock().await;
-        self.on_new_protocol_network_connection(&mut *inner, conn)
-            .map(drop)
-    }
-
     // Called when we want to create a new connection or get the current one that already exists
     // This will kill off any connections that are in conflict with the new connection to be made
     // in order to make room for the new connection in the system's connection table

@@ -107,6 +128,14 @@ impl ConnectionManager {
         local_addr: Option<SocketAddr>,
         dial_info: DialInfo,
     ) -> Result<ConnectionHandle, String> {
+        let mut inner = self.arc.inner.lock().await;
+        let inner = match &mut *inner {
+            Some(v) => v,
+            None => {
+                panic!("not started");
+            }
+        };
+
         log_net!(
             "== get_or_create_connection local_addr={:?} dial_info={:?}",
             local_addr.green(),

@@ -123,7 +152,6 @@ impl ConnectionManager {

         // If any connection to this remote exists that has the same protocol, return it
         // Any connection will do, we don't have to match the local address
-        let mut inner = self.arc.inner.lock().await;

         if let Some(conn) = inner
             .connection_table

@@ -197,10 +225,39 @@ impl ConnectionManager {
         self.on_new_protocol_network_connection(&mut *inner, conn)
     }

+    ///////////////////////////////////////////////////////////////////////////////////////////////////////
+    /// Callbacks
+
+    // Called by low-level network when any connection-oriented protocol connection appears
+    // either from incoming connections.
+    pub(super) async fn on_accepted_protocol_network_connection(
+        &self,
+        conn: ProtocolNetworkConnection,
+    ) -> Result<(), String> {
+        let mut inner = self.arc.inner.lock().await;
+        let inner = match &mut *inner {
+            Some(v) => v,
+            None => {
+                // If we are shutting down, just drop this and return
+                return Ok(());
+            }
+        };
+        self.on_new_protocol_network_connection(inner, conn)
+            .map(drop)
+    }
+
     // Callback from network connection receive loop when it exits
     // cleans up the entry in the connection table
     pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) {
         let mut inner = self.arc.inner.lock().await;
+        let inner = match &mut *inner {
+            Some(v) => v,
+            None => {
+                // If we're shutting down, do nothing here
+                return;
+            }
+        };
+
         if let Err(e) = inner.connection_table.remove_connection(descriptor) {
             log_net!(error e);
         }

@@ -1,5 +1,6 @@
 use super::*;
 use alloc::collections::btree_map::Entry;
+use futures_util::StreamExt;
 use hashlink::LruCache;

 #[derive(Debug)]

@@ -41,6 +42,16 @@ impl ConnectionTable {
         }
     }

+    pub async fn join(&mut self) {
+        let mut unord = FuturesUnordered::new();
+        for table in &mut self.conn_by_descriptor {
+            for (_, v) in table.drain() {
+                unord.push(v);
+            }
+        }
+        while unord.next().await.is_some() {}
+    }
+
     pub fn add_connection(&mut self, conn: NetworkConnection) -> Result<(), String> {
         let descriptor = conn.connection_descriptor();
         let ip_addr = descriptor.remote_address().to_ip_addr();
@@ -171,8 +171,8 @@ impl NetworkManager {
             let this2 = this.clone();
             this.unlocked_inner
                 .rolling_transfers_task
-                .set_routine(move |l, t| {
-                    Box::pin(this2.clone().rolling_transfers_task_routine(l, t))
+                .set_routine(move |s, l, t| {
+                    Box::pin(this2.clone().rolling_transfers_task_routine(s, l, t))
                 });
         }
         // Set relay management tick task

@@ -180,8 +180,8 @@ impl NetworkManager {
             let this2 = this.clone();
             this.unlocked_inner
                 .relay_management_task
-                .set_routine(move |l, t| {
-                    Box::pin(this2.clone().relay_management_task_routine(l, t))
+                .set_routine(move |s, l, t| {
+                    Box::pin(this2.clone().relay_management_task_routine(s, l, t))
                 });
         }
         this
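
TickTask itself is not part of this extract, but the call sites above show its routine closure growing from two parameters to three, the new first one being a StopToken. What a routine does with that token, sketched with a sleep standing in for real work (illustrative, not repository code):

    use std::time::Duration;
    use stop_token::future::FutureExt;
    use stop_token::StopToken;

    async fn example_task_routine(
        stop_token: StopToken,
        _last_ts: u64,
        _cur_ts: u64,
    ) -> Result<(), String> {
        // Any long await inside the routine is wrapped with the stop token, so
        // a shutdown that drops the StopSource unblocks it immediately.
        if async_std::task::sleep(Duration::from_secs(1))
            .timeout_at(stop_token)
            .await
            .is_err()
        {
            return Ok(()); // stop requested: exit quietly, this is not an error
        }
        // ... do the periodic work here ...
        Ok(())
    }
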
@@ -275,10 +275,10 @@ impl NetworkManager {
         });

         // Start network components
+        connection_manager.startup().await;
+        net.startup().await?;
         rpc_processor.startup().await?;
         receipt_manager.startup().await?;
-        net.startup().await?;
-        connection_manager.startup().await;

         trace!("NetworkManager::internal_startup end");

@@ -302,20 +302,20 @@ impl NetworkManager {
         trace!("NetworkManager::shutdown begin");

         // Cancel all tasks
-        if let Err(e) = self.unlocked_inner.rolling_transfers_task.cancel().await {
-            warn!("rolling_transfers_task not cancelled: {}", e);
+        if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
+            warn!("rolling_transfers_task not stopped: {}", e);
         }
-        if let Err(e) = self.unlocked_inner.relay_management_task.cancel().await {
-            warn!("relay_management_task not cancelled: {}", e);
+        if let Err(e) = self.unlocked_inner.relay_management_task.stop().await {
+            warn!("relay_management_task not stopped: {}", e);
         }

         // Shutdown network components if they started up
         let components = self.inner.lock().components.clone();
         if let Some(components) = components {
-            components.connection_manager.shutdown().await;
-            components.net.shutdown().await;
             components.receipt_manager.shutdown().await;
             components.rpc_processor.shutdown().await;
+            components.net.shutdown().await;
+            components.connection_manager.shutdown().await;
         }

         // reset the state

@@ -1202,7 +1202,12 @@ impl NetworkManager {

     // Keep relays assigned and accessible
     #[instrument(level = "trace", skip(self), err)]
-    async fn relay_management_task_routine(self, _last_ts: u64, cur_ts: u64) -> Result<(), String> {
+    async fn relay_management_task_routine(
+        self,
+        stop_token: StopToken,
+        _last_ts: u64,
+        cur_ts: u64,
+    ) -> Result<(), String> {
         // log_net!("--- network manager relay_management task");

         // Get our node's current node info and network class and do the right thing

@@ -1255,7 +1260,12 @@ impl NetworkManager {

     // Compute transfer statistics for the low level network
     #[instrument(level = "trace", skip(self), err)]
-    async fn rolling_transfers_task_routine(self, last_ts: u64, cur_ts: u64) -> Result<(), String> {
+    async fn rolling_transfers_task_routine(
+        self,
+        stop_token: StopToken,
+        last_ts: u64,
+        cur_ts: u64,
+    ) -> Result<(), String> {
         // log_net!("--- network manager rolling_transfers task");
         {
             let inner = &mut *self.inner.lock();
@@ -42,7 +42,8 @@ struct NetworkInner {
     protocol_config: Option<ProtocolConfig>,
     static_public_dialinfo: ProtocolSet,
     network_class: Option<NetworkClass>,
-    join_handles: Vec<JoinHandle<()>>,
+    join_handles: Vec<MustJoinHandle<()>>,
+    stop_source: Option<StopSource>,
     udp_port: u16,
     tcp_port: u16,
     ws_port: u16,

@@ -82,6 +83,7 @@ impl Network {
             static_public_dialinfo: ProtocolSet::empty(),
             network_class: None,
             join_handles: Vec::new(),
+            stop_source: None,
             udp_port: 0u16,
             tcp_port: 0u16,
             ws_port: 0u16,

@@ -115,8 +117,8 @@ impl Network {
             let this2 = this.clone();
             this.unlocked_inner
                 .update_network_class_task
-                .set_routine(move |l, t| {
-                    Box::pin(this2.clone().update_network_class_task_routine(l, t))
+                .set_routine(move |s, l, t| {
+                    Box::pin(this2.clone().update_network_class_task_routine(s, l, t))
                 });
         }

@@ -200,7 +202,7 @@ impl Network {
         Ok(config)
     }

-    fn add_to_join_handles(&self, jh: JoinHandle<()>) {
+    fn add_to_join_handles(&self, jh: MustJoinHandle<()>) {
         let mut inner = self.inner.lock();
         inner.join_handles.push(jh);
     }

@@ -506,17 +508,28 @@ impl Network {
         let network_manager = self.network_manager();
         let routing_table = self.routing_table();

-        // Cancel all tasks
-        if let Err(e) = self.unlocked_inner.update_network_class_task.cancel().await {
-            warn!("update_network_class_task not cancelled: {}", e);
+        // Stop all tasks
+        if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
+            error!("update_network_class_task not cancelled: {}", e);
         }
+        let mut unord = FuturesUnordered::new();
+        {
+            let mut inner = self.inner.lock();
+            // Drop the stop
+            drop(inner.stop_source.take());
+            // take the join handles out
+            for h in inner.join_handles.drain(..) {
+                unord.push(h);
+            }
+        }
+        // Wait for everything to stop
+        while unord.next().await.is_some() {}

         // Drop all dial info
         routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
         routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);

         // Reset state including network class
-        // Cancels all async background tasks by dropping join handles
         *self.inner.lock() = Self::new_inner(network_manager);

         info!("network stopped");

@@ -465,7 +465,12 @@ impl Network {
     }

     #[instrument(level = "trace", skip(self), err)]
-    pub async fn update_network_class_task_routine(self, _l: u64, _t: u64) -> Result<(), String> {
+    pub async fn update_network_class_task_routine(
+        self,
+        stop_token: StopToken,
+        _l: u64,
+        _t: u64,
+    ) -> Result<(), String> {
         // Ensure we aren't trying to update this without clearing it first
         let old_network_class = self.inner.lock().network_class;
         assert_eq!(old_network_class, None);

@@ -2,6 +2,7 @@ use super::*;
 use crate::intf::*;
 use async_tls::TlsAcceptor;
 use sockets::*;
+use stop_token::future::FutureExt;

 /////////////////////////////////////////////////////////////////

@@ -91,48 +92,24 @@ impl Network {
         Ok(None)
     }

-    async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
-        // Get config
-        let (connection_initial_timeout, tls_connection_initial_timeout) = {
-            let c = self.config.get();
-            (
-                ms_to_us(c.network.connection_initial_timeout_ms),
-                ms_to_us(c.network.tls.connection_initial_timeout_ms),
-            )
+    async fn tcp_acceptor(
+        self,
+        tcp_stream: async_std::io::Result<TcpStream>,
+        listener_state: Arc<RwLock<ListenerState>>,
+        connection_manager: ConnectionManager,
+        connection_initial_timeout: u64,
+        tls_connection_initial_timeout: u64,
+    ) {
+        let tcp_stream = match tcp_stream {
+            Ok(v) => v,
+            Err(_) => {
+                // If this happened our low-level listener socket probably died
+                // so it's time to restart the network
+                self.inner.lock().network_needs_restart = true;
+                return;
+            }
         };

-        // Create a reusable socket with no linger time, and no delay
-        let socket = new_bound_shared_tcp_socket(addr)?;
-        // Listen on the socket
-        socket
-            .listen(128)
-            .map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
-
-        // Make an async tcplistener from the socket2 socket
-        let std_listener: std::net::TcpListener = socket.into();
-        let listener = TcpListener::from(std_listener);
-
-        debug!("spawn_socket_listener: binding successful to {}", addr);
-
-        // Create protocol handler records
-        let listener_state = Arc::new(RwLock::new(ListenerState::new()));
-        self.inner
-            .lock()
-            .listener_states
-            .insert(addr, listener_state.clone());
-
-        // Spawn the socket task
-        let this = self.clone();
-        let connection_manager = self.connection_manager();
-
-        ////////////////////////////////////////////////////////////
-        let jh = spawn(async move {
-            // moves listener object in and get incoming iterator
-            // when this task exists, the listener will close the socket
-            listener
-                .incoming()
-                .for_each_concurrent(None, |tcp_stream| async {
-                    let tcp_stream = tcp_stream.unwrap();
         let listener_state = listener_state.clone();
         let connection_manager = connection_manager.clone();

@@ -175,7 +152,7 @@ impl Network {
         let ls = listener_state.read().clone();

         let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
-            this.try_tls_handlers(
+            self.try_tls_handlers(
                 ls.tls_acceptor.as_ref().unwrap(),
                 ps,
                 tcp_stream,

@@ -185,7 +162,7 @@ impl Network {
             )
             .await
         } else {
-            this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
+            self.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
                 .await
         };

@@ -213,21 +190,74 @@ impl Network {
         {
             log_net!(error "failed to register new connection: {}", e);
         }
+    }
+
+    async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
+        // Get config
+        let (connection_initial_timeout, tls_connection_initial_timeout) = {
+            let c = self.config.get();
+            (
+                ms_to_us(c.network.connection_initial_timeout_ms),
+                ms_to_us(c.network.tls.connection_initial_timeout_ms),
+            )
+        };
+
+        // Create a reusable socket with no linger time, and no delay
+        let socket = new_bound_shared_tcp_socket(addr)?;
+        // Listen on the socket
+        socket
+            .listen(128)
+            .map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
+
+        // Make an async tcplistener from the socket2 socket
+        let std_listener: std::net::TcpListener = socket.into();
+        let listener = TcpListener::from(std_listener);
+
+        debug!("spawn_socket_listener: binding successful to {}", addr);
+
+        // Create protocol handler records
+        let listener_state = Arc::new(RwLock::new(ListenerState::new()));
+        self.inner
+            .lock()
+            .listener_states
+            .insert(addr, listener_state.clone());
+
+        // Spawn the socket task
+        let this = self.clone();
+        let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token();
+        let connection_manager = self.connection_manager();
+
+        ////////////////////////////////////////////////////////////
+        let jh = spawn(async move {
+            // moves listener object in and get incoming iterator
+            // when this task exists, the listener will close the socket
+            let _ = listener
+                .incoming()
+                .for_each_concurrent(None, |tcp_stream| {
+                    let this = this.clone();
+                    let listener_state = listener_state.clone();
+                    let connection_manager = connection_manager.clone();
+                    Self::tcp_acceptor(
+                        this,
+                        tcp_stream,
+                        listener_state,
+                        connection_manager,
+                        connection_initial_timeout,
+                        tls_connection_initial_timeout,
+                    )
                 })
+                .timeout_at(stop_token)
                 .await;

             log_net!(debug "exited incoming loop for {}", addr);
             // Remove our listener state from this address if we're stopping
             this.inner.lock().listener_states.remove(&addr);
             log_net!(debug "listener state removed for {}", addr);
-
-            // If this happened our low-level listener socket probably died
-            // so it's time to restart the network
-            this.inner.lock().network_needs_restart = true;
         });
         ////////////////////////////////////////////////////////////

         // Add to join handles
-        self.add_to_join_handles(jh);
+        self.add_to_join_handles(MustJoinHandle::new(jh));

         Ok(())
     }
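
Note the shape of the listener rewrite above: the whole accept loop, not each connection, is bounded by the stop token. When the StopSource drops, timeout_at() resolves, the for_each_concurrent future is dropped, and dropping the listener closes the socket. The same shape in isolation (assuming async-std; not repository code):

    use async_std::net::TcpListener;
    use futures_util::StreamExt;
    use stop_token::future::FutureExt;
    use stop_token::StopToken;

    async fn accept_loop(listener: TcpListener, stop_token: StopToken) {
        let _ = listener
            .incoming()
            .for_each_concurrent(None, |stream| async move {
                if let Ok(stream) = stream {
                    // hand the stream off to a connection handler here
                    let _ = stream.peer_addr();
                }
            })
            .timeout_at(stop_token) // Err(_) here just means "stop requested"
            .await;
        // listener is dropped on return, which closes the socket
    }
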
@@ -1,5 +1,6 @@
 use super::*;
 use sockets::*;
+use stop_token::future::FutureExt;

 impl Network {
     pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {

@@ -43,13 +44,21 @@ impl Network {
             // Spawn a local async task for each socket
             let mut protocol_handlers_unordered = FuturesUnordered::new();
             let network_manager = this.network_manager();
+            let stop_token = this.inner.lock().stop_source.as_ref().unwrap().token();
+
             for ph in protocol_handlers {
                 let network_manager = network_manager.clone();
+                let stop_token = stop_token.clone();
                 let jh = spawn_local(async move {
                     let mut data = vec![0u8; 65536];
-                    while let Ok((size, descriptor)) = ph.recv_message(&mut data).await {
+                    loop {
+                        match ph
+                            .recv_message(&mut data)
+                            .timeout_at(stop_token.clone())
+                            .await
+                        {
+                            Ok(Ok((size, descriptor))) => {
                         // XXX: Limit the number of packets from the same IP address?
                         log_net!("UDP packet: {:?}", descriptor);

@@ -67,23 +76,43 @@ impl Network {
                             log_net!(error "failed to process received udp envelope: {}", e);
                         }
                     }
+                            Ok(Err(_)) => {
+                                return false;
+                            }
+                            Err(_) => {
+                                return true;
+                            }
+                        }
+                    }
                 });

                 protocol_handlers_unordered.push(jh);
             }
-            // Now we wait for any join handle to exit,
-            // which would indicate an error needing
+            // Now we wait for join handles to exit,
+            // if any error out it indicates an error needing
             // us to completely restart the network
-            let _ = protocol_handlers_unordered.next().await;
+            loop {
+                match protocol_handlers_unordered.next().await {
+                    Some(v) => {
+                        // true = stopped, false = errored
+                        if !v {
+                            // If any protocol handler fails, our socket died and we need to restart the network
+                            this.inner.lock().network_needs_restart = true;
+                        }
+                    }
+                    None => {
+                        // All protocol handlers exited
+                        break;
+                    }
+                }
+            }

             trace!("UDP listener task stopped");
-            // If this loop fails, our socket died and we need to restart the network
-            this.inner.lock().network_needs_restart = true;
         });
         ////////////////////////////////////////////////////////////

         // Add to join handle
-        self.add_to_join_handles(jh);
+        self.add_to_join_handles(MustJoinHandle::new(jh));
     }

     Ok(())
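
The nested results in the UDP loop above are easy to misread, so spelled out: the outer Result comes from timeout_at() (Err means the stop token fired) and the inner one from the receive itself (Err means the socket died). A standalone sketch of the same decision table, substituting async-std's UdpSocket for the repository's protocol handler:

    use async_std::net::UdpSocket;
    use stop_token::future::FutureExt;
    use stop_token::StopToken;

    /// Returns true for a clean stop, false when networking needs a restart.
    async fn recv_until_stopped(socket: UdpSocket, stop_token: StopToken) -> bool {
        let mut buf = vec![0u8; 65536];
        loop {
            match socket.recv_from(&mut buf).timeout_at(stop_token.clone()).await {
                Ok(Ok((len, addr))) => {
                    // got a datagram: process it and keep looping
                    let _ = (len, addr);
                }
                Ok(Err(_)) => return false, // socket error: restart the network
                Err(_) => return true,      // stop requested: shut down quietly
            }
        }
    }
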
@@ -1,5 +1,6 @@
 use super::*;
 use futures_util::{FutureExt, StreamExt};
+use stop_token::prelude::*;

 cfg_if::cfg_if! {
     if #[cfg(target_arch = "wasm32")] {

@@ -84,7 +85,7 @@ pub struct NetworkConnectionStats {
 #[derive(Debug)]
 pub struct NetworkConnection {
     descriptor: ConnectionDescriptor,
-    _processor: Option<JoinHandle<()>>,
+    processor: Option<MustJoinHandle<()>>,
     established_time: u64,
     stats: Arc<Mutex<NetworkConnectionStats>>,
     sender: flume::Sender<Vec<u8>>,

@@ -97,7 +98,7 @@ impl NetworkConnection {

         Self {
             descriptor,
-            _processor: None,
+            processor: None,
             established_time: intf::get_timestamp(),
             stats: Arc::new(Mutex::new(NetworkConnectionStats {
                 last_message_sent_time: None,

@@ -109,6 +110,7 @@ impl NetworkConnection {

     pub(super) fn from_protocol(
         connection_manager: ConnectionManager,
+        stop_token: StopToken,
         protocol_connection: ProtocolNetworkConnection,
     ) -> Self {
         // Get timeout

@@ -132,19 +134,20 @@ impl NetworkConnection {
         }));

         // Spawn connection processor and pass in protocol connection
-        let processor = intf::spawn_local(Self::process_connection(
+        let processor = MustJoinHandle::new(intf::spawn_local(Self::process_connection(
             connection_manager,
+            stop_token,
             descriptor.clone(),
             receiver,
             protocol_connection,
             inactivity_timeout,
             stats.clone(),
-        ));
+        )));

         // Return the connection
         Self {
             descriptor,
-            _processor: Some(processor),
+            processor: Some(processor),
             established_time: intf::get_timestamp(),
             stats,
             sender,

@@ -197,6 +200,7 @@ impl NetworkConnection {
     // Connection receiver loop
     fn process_connection(
         connection_manager: ConnectionManager,
+        stop_token: StopToken,
         descriptor: ConnectionDescriptor,
         receiver: flume::Receiver<Vec<u8>>,
         protocol_connection: ProtocolNetworkConnection,

@@ -289,26 +293,28 @@ impl NetworkConnection {
                 }

                 // Process futures
-                match unord.next().await {
-                    Some(RecvLoopAction::Send) => {
+                match unord.next().timeout_at(stop_token.clone()).await {
+                    Ok(Some(RecvLoopAction::Send)) => {
                         // Don't reset inactivity timer if we're only sending

                         need_sender = true;
                     }
-                    Some(RecvLoopAction::Recv) => {
+                    Ok(Some(RecvLoopAction::Recv)) => {
                         // Reset inactivity timer since we got something from this connection
                         timer.set(new_timer());

                         need_receiver = true;
                     }
-                    Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout) => {
+                    Ok(Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout)) => {
                         break;
                     }
-                    None => {
+                    Ok(None) => {
                         // Should not happen
                         unreachable!();
                     }
+                    Err(_) => {
+                        // Stop token
+                        break;
+                    }
                 }
             }

@@ -317,9 +323,23 @@ impl NetworkConnection {
                 descriptor.green()
             );

+            // Let the connection manager know the receive loop exited
             connection_manager
                 .report_connection_finished(descriptor)
-                .await
+                .await;
         })
     }
 }
+
+// Resolves ready when the connection loop has terminated
+impl Future for NetworkConnection {
+    type Output = ();
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
+        if let Some(mut processor) = self.processor.as_mut() {
+            Pin::new(&mut processor).poll(cx)
+        } else {
+            task::Poll::Ready(())
+        }
+    }
+}
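
The new Future impl above is what lets ConnectionTable::join() push drained NetworkConnection values straight into a FuturesUnordered and await them: awaiting a connection now means awaiting its processor task. The generic shape of that join, as a sketch:

    use futures_util::stream::{FuturesUnordered, StreamExt};
    use std::future::Future;

    async fn join_all<F: Future<Output = ()>>(futures: Vec<F>) {
        let mut unord: FuturesUnordered<F> = futures.into_iter().collect();
        // Polls every future concurrently; finishes when the set is empty.
        while unord.next().await.is_some() {}
    }
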
@@ -4,6 +4,7 @@ use dht::*;
 use futures_util::stream::{FuturesUnordered, StreamExt};
 use network_manager::*;
 use routing_table::*;
+use stop_token::future::FutureExt;
 use xx::*;

 #[derive(Clone, Debug, PartialEq, Eq)]

@@ -170,7 +171,8 @@ pub struct ReceiptManagerInner {
     network_manager: NetworkManager,
     records_by_nonce: BTreeMap<ReceiptNonce, Arc<Mutex<ReceiptRecord>>>,
     next_oldest_ts: Option<u64>,
-    timeout_task: SingleFuture<()>,
+    stop_source: Option<StopSource>,
+    timeout_task: MustJoinSingleFuture<()>,
 }

 #[derive(Clone)]

@@ -184,7 +186,8 @@ impl ReceiptManager {
             network_manager,
             records_by_nonce: BTreeMap::new(),
             next_oldest_ts: None,
-            timeout_task: SingleFuture::new(),
+            stop_source: None,
+            timeout_task: MustJoinSingleFuture::new(),
         }
     }

@@ -201,13 +204,14 @@ impl ReceiptManager {
     pub async fn startup(&self) -> Result<(), String> {
         trace!("startup receipt manager");
         // Retrieve config
-        /*
         {
-            let config = self.core().config();
-            let c = config.get();
+            // let config = self.core().config();
+            // let c = config.get();
             let mut inner = self.inner.lock();
+            inner.stop_source = Some(StopSource::new());
         }
-        */
         Ok(())
     }

@@ -235,7 +239,7 @@ impl ReceiptManager {
     }

     #[instrument(level = "trace", skip(self))]
-    pub async fn timeout_task_routine(self, now: u64) {
+    pub async fn timeout_task_routine(self, now: u64, stop_token: StopToken) {
         // Go through all receipts and build a list of expired nonces
         let mut new_next_oldest_ts: Option<u64> = None;
         let mut expired_records = Vec::new();

@@ -276,13 +280,25 @@ impl ReceiptManager {
         }

         // Wait on all the multi-call callbacks
-        while callbacks.next().await.is_some() {}
+        loop {
+            match callbacks.next().timeout_at(stop_token.clone()).await {
+                Ok(Some(_)) => {}
+                Ok(None) | Err(_) => break,
+            }
+        }
     }

     pub async fn tick(&self) -> Result<(), String> {
-        let (next_oldest_ts, timeout_task) = {
+        let (next_oldest_ts, timeout_task, stop_token) = {
             let inner = self.inner.lock();
-            (inner.next_oldest_ts, inner.timeout_task.clone())
+            let stop_token = match inner.stop_source.as_ref() {
+                Some(ss) => ss.token(),
+                None => {
+                    // Do nothing if we're shutting down
+                    return Ok(());
+                }
+            };
+            (inner.next_oldest_ts, inner.timeout_task.clone(), stop_token)
         };
         let now = intf::get_timestamp();
         // If we have at least one timestamp to expire, lets do it

@@ -290,7 +306,7 @@ impl ReceiptManager {
             if now >= next_oldest_ts {
                 // Single-spawn the timeout task routine
                 let _ = timeout_task
-                    .single_spawn(self.clone().timeout_task_routine(now))
+                    .single_spawn(self.clone().timeout_task_routine(now, stop_token))
                     .await;
             }
         }

@@ -299,6 +315,20 @@ impl ReceiptManager {

     pub async fn shutdown(&self) {
         let network_manager = self.network_manager();
+
+        // Stop all tasks
+        let timeout_task = {
+            let mut inner = self.inner.lock();
+            // Drop the stop
+            drop(inner.stop_source.take());
+            inner.timeout_task.clone()
+        };
+
+        // Wait for everything to stop
+        if !timeout_task.join().await.is_ok() {
+            panic!("joining timeout task failed");
+        }
+
         *self.inner.lock() = Self::new_inner(network_manager);
     }

@@ -410,9 +440,16 @@ impl ReceiptManager {
         );

         // Increment return count
-        let callback_future = {
+        let (callback_future, stop_token) = {
             // Look up the receipt record from the nonce
             let mut inner = self.inner.lock();
+            let stop_token = match inner.stop_source.as_ref() {
+                Some(ss) => ss.token(),
+                None => {
+                    // If we're stopping do nothing here
+                    return Ok(());
+                }
+            };
             let record = match inner.records_by_nonce.get(&receipt_nonce) {
                 Some(r) => r.clone(),
                 None => {

@@ -438,12 +475,12 @@ impl ReceiptManager {

                 Self::update_next_oldest_timestamp(&mut *inner);
             }
-            callback_future
+            (callback_future, stop_token)
         };

         // Issue the callback
         if let Some(callback_future) = callback_future {
-            callback_future.await;
+            let _ = callback_future.timeout_at(stop_token).await;
         }

         Ok(())
@@ -71,7 +71,7 @@ struct RoutingTableUnlockedInner {
     bootstrap_task: TickTask,
     peer_minimum_refresh_task: TickTask,
     ping_validator_task: TickTask,
-    node_info_update_single_future: SingleFuture<()>,
+    node_info_update_single_future: MustJoinSingleFuture<()>,
 }

 #[derive(Clone)]

@@ -103,7 +103,7 @@ impl RoutingTable {
             bootstrap_task: TickTask::new(1),
             peer_minimum_refresh_task: TickTask::new_ms(c.network.dht.min_peer_refresh_time_ms),
             ping_validator_task: TickTask::new(1),
-            node_info_update_single_future: SingleFuture::new(),
+            node_info_update_single_future: MustJoinSingleFuture::new(),
         }
     }
     pub fn new(network_manager: NetworkManager) -> Self {

@@ -118,8 +118,8 @@ impl RoutingTable {
             let this2 = this.clone();
             this.unlocked_inner
                 .rolling_transfers_task
-                .set_routine(move |l, t| {
-                    Box::pin(this2.clone().rolling_transfers_task_routine(l, t))
+                .set_routine(move |s, l, t| {
+                    Box::pin(this2.clone().rolling_transfers_task_routine(s, l, t))
                 });
         }
         // Set bootstrap tick task

@@ -127,15 +127,15 @@ impl RoutingTable {
             let this2 = this.clone();
             this.unlocked_inner
                 .bootstrap_task
-                .set_routine(move |_l, _t| Box::pin(this2.clone().bootstrap_task_routine()));
+                .set_routine(move |s, _l, _t| Box::pin(this2.clone().bootstrap_task_routine(s)));
         }
         // Set peer minimum refresh tick task
         {
             let this2 = this.clone();
             this.unlocked_inner
                 .peer_minimum_refresh_task
-                .set_routine(move |_l, _t| {
-                    Box::pin(this2.clone().peer_minimum_refresh_task_routine())
+                .set_routine(move |s, _l, _t| {
+                    Box::pin(this2.clone().peer_minimum_refresh_task_routine(s))
                 });
         }
         // Set ping validator tick task

@@ -143,7 +143,9 @@ impl RoutingTable {
             let this2 = this.clone();
             this.unlocked_inner
                 .ping_validator_task
-                .set_routine(move |l, t| Box::pin(this2.clone().ping_validator_task_routine(l, t)));
+                .set_routine(move |s, l, t| {
+                    Box::pin(this2.clone().ping_validator_task_routine(s, l, t))
+                });
         }
         this
     }

@@ -373,26 +375,26 @@ impl RoutingTable {

     pub async fn terminate(&self) {
         // Cancel all tasks being ticked
-        if let Err(e) = self.unlocked_inner.rolling_transfers_task.cancel().await {
-            warn!("rolling_transfers_task not cancelled: {}", e);
+        if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
+            error!("rolling_transfers_task not stopped: {}", e);
         }
-        if let Err(e) = self.unlocked_inner.bootstrap_task.cancel().await {
-            warn!("bootstrap_task not cancelled: {}", e);
+        if let Err(e) = self.unlocked_inner.bootstrap_task.stop().await {
+            error!("bootstrap_task not stopped: {}", e);
         }
-        if let Err(e) = self.unlocked_inner.peer_minimum_refresh_task.cancel().await {
-            warn!("peer_minimum_refresh_task not cancelled: {}", e);
+        if let Err(e) = self.unlocked_inner.peer_minimum_refresh_task.stop().await {
+            error!("peer_minimum_refresh_task not stopped: {}", e);
         }
-        if let Err(e) = self.unlocked_inner.ping_validator_task.cancel().await {
-            warn!("ping_validator_task not cancelled: {}", e);
+        if let Err(e) = self.unlocked_inner.ping_validator_task.stop().await {
+            error!("ping_validator_task not stopped: {}", e);
         }
         if self
             .unlocked_inner
             .node_info_update_single_future
-            .cancel()
+            .join()
             .await
             .is_err()
         {
-            warn!("node_info_update_single_future not cancelled");
+            error!("node_info_update_single_future not stopped");
         }

         *self.inner.lock() = Self::new_inner(self.network_manager());

@@ -990,7 +992,7 @@ impl RoutingTable {
     }

     #[instrument(level = "trace", skip(self), err)]
-    async fn bootstrap_task_routine(self) -> Result<(), String> {
+    async fn bootstrap_task_routine(self, stop_token: StopToken) -> Result<(), String> {
         let (bootstrap, bootstrap_nodes) = {
             let c = self.config.get();
             (

@@ -1093,7 +1095,7 @@ impl RoutingTable {
     // Ask our remaining peers to give us more peers before we go
     // back to the bootstrap servers to keep us from bothering them too much
     #[instrument(level = "trace", skip(self), err)]
-    async fn peer_minimum_refresh_task_routine(self) -> Result<(), String> {
+    async fn peer_minimum_refresh_task_routine(self, stop_token: StopToken) -> Result<(), String> {
         // get list of all peers we know about, even the unreliable ones, and ask them to find nodes close to our node too
         let noderefs = {
             let mut inner = self.inner.lock();

@@ -1125,7 +1127,12 @@ impl RoutingTable {
     // Ping each node in the routing table if they need to be pinged
     // to determine their reliability
     #[instrument(level = "trace", skip(self), err)]
-    async fn ping_validator_task_routine(self, _last_ts: u64, cur_ts: u64) -> Result<(), String> {
+    async fn ping_validator_task_routine(
+        self,
+        stop_token: StopToken,
+        _last_ts: u64,
+        cur_ts: u64,
+    ) -> Result<(), String> {
         // log_rtab!("--- ping_validator task");

         let rpc = self.rpc_processor();

@@ -1144,7 +1151,9 @@ impl RoutingTable {
                     nr,
                     e.state_debug_info(cur_ts)
                 );
-                unord.push(intf::spawn_local(rpc.clone().rpc_call_status(nr)));
+                unord.push(MustJoinHandle::new(intf::spawn_local(
+                    rpc.clone().rpc_call_status(nr),
+                )));
             }
             Option::<()>::None
         });

@@ -1158,7 +1167,12 @@ impl RoutingTable {

     // Compute transfer statistics to determine how 'fast' a node is
     #[instrument(level = "trace", skip(self), err)]
-    async fn rolling_transfers_task_routine(self, last_ts: u64, cur_ts: u64) -> Result<(), String> {
+    async fn rolling_transfers_task_routine(
+        self,
+        stop_token: StopToken,
+        last_ts: u64,
+        cur_ts: u64,
+    ) -> Result<(), String> {
         // log_rtab!("--- rolling_transfers task");
         let inner = &mut *self.inner.lock();

@ -10,9 +10,11 @@ use crate::intf::*;
|
|||||||
use crate::xx::*;
|
use crate::xx::*;
|
||||||
use capnp::message::ReaderSegments;
|
use capnp::message::ReaderSegments;
|
||||||
use coders::*;
|
use coders::*;
|
||||||
|
use futures_util::StreamExt;
|
||||||
use network_manager::*;
|
use network_manager::*;
|
||||||
use receipt_manager::*;
|
use receipt_manager::*;
|
||||||
use routing_table::*;
|
use routing_table::*;
|
||||||
|
use stop_token::future::FutureExt;
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
@ -167,7 +169,8 @@ pub struct RPCProcessorInner {
|
|||||||
timeout: u64,
|
timeout: u64,
|
||||||
max_route_hop_count: usize,
|
max_route_hop_count: usize,
|
||||||
waiting_rpc_table: BTreeMap<OperationId, EventualValue<RPCMessageReader>>,
|
waiting_rpc_table: BTreeMap<OperationId, EventualValue<RPCMessageReader>>,
|
||||||
worker_join_handles: Vec<JoinHandle<()>>,
|
stop_source: Option<StopSource>,
|
||||||
|
worker_join_handles: Vec<MustJoinHandle<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -189,6 +192,7 @@ impl RPCProcessor {
|
|||||||
timeout: 10000000,
|
timeout: 10000000,
|
||||||
max_route_hop_count: 7,
|
max_route_hop_count: 7,
|
||||||
waiting_rpc_table: BTreeMap::new(),
|
waiting_rpc_table: BTreeMap::new(),
|
||||||
|
stop_source: None,
|
||||||
worker_join_handles: Vec::new(),
|
worker_join_handles: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1368,8 +1372,8 @@ impl RPCProcessor {
        }
    }

-    async fn rpc_worker(self, receiver: flume::Receiver<RPCMessage>) {
-        while let Ok(msg) = receiver.recv_async().await {
+    async fn rpc_worker(self, stop_token: StopToken, receiver: flume::Receiver<RPCMessage>) {
+        while let Ok(Ok(msg)) = receiver.recv_async().timeout_at(stop_token.clone()).await {
            let _ = self
                .process_rpc_message(msg)
                .await
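The nested pattern match above does two jobs at once. A standalone sketch (hypothetical worker over a flume channel) of the same cooperative-exit loop: the outer Result comes from timeout_at() and fails when the StopSource is dropped; the inner Result comes from the channel and fails when every sender is gone. Either failure ends the loop, so the worker joins promptly on shutdown.

use stop_token::StopToken;
use stop_token::future::FutureExt;

async fn worker(stop_token: StopToken, receiver: flume::Receiver<u32>) {
    while let Ok(Ok(item)) = receiver.recv_async().timeout_at(stop_token.clone()).await {
        // process one message
        println!("processing {}", item);
    }
    // reached on cancellation or channel close
}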
@@ -1409,20 +1413,37 @@ impl RPCProcessor {
        inner.max_route_hop_count = max_route_hop_count;
        let channel = flume::bounded(queue_size as usize);
        inner.send_channel = Some(channel.0.clone());
+        inner.stop_source = Some(StopSource::new());

        // spin up N workers
        trace!("Spinning up {} RPC workers", concurrency);
        for _ in 0..concurrency {
            let this = self.clone();
            let receiver = channel.1.clone();
-            let jh = spawn(Self::rpc_worker(this, receiver));
-            inner.worker_join_handles.push(jh);
+            let jh = spawn(Self::rpc_worker(this, inner.stop_source.as_ref().unwrap().token(), receiver));
+            inner.worker_join_handles.push(MustJoinHandle::new(jh));
        }

        Ok(())
    }

    pub async fn shutdown(&self) {
+        // Stop the rpc workers
+        let mut unord = FuturesUnordered::new();
+        {
+            let mut inner = self.inner.lock();
+            // drop the stop
+            drop(inner.stop_source.take());
+            // take the join handles out
+            for h in inner.worker_join_handles.drain(..) {
+                unord.push(h);
+            }
+        }
+
+        // Wait for them to complete
+        while unord.next().await.is_some() {}
+
+        // Release the rpc processor
        *self.inner.lock() = Self::new_inner(self.network_manager());
    }

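Note the ordering inside shutdown(): the StopSource is dropped and the handles drained while the lock is held, but the join happens only after the guard is released. A minimal sketch of why (hypothetical helper; awaiting while holding a parking_lot guard would stall every other user of the lock):

use futures_util::stream::FuturesUnordered;
use futures_util::StreamExt;

async fn drain_and_join(handles: &parking_lot::Mutex<Vec<MustJoinHandle<()>>>) {
    let mut unord = FuturesUnordered::new();
    {
        let mut inner = handles.lock();
        for h in inner.drain(..) {
            unord.push(h);
        }
    } // guard dropped here, before any .await
    while unord.next().await.is_some() {}
}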
@@ -545,6 +545,43 @@ pub async fn test_single_future() {
    assert_eq!(sf.check().await, Ok(None));
 }

+pub async fn test_must_join_single_future() {
+    info!("testing must join single future");
+    let sf = MustJoinSingleFuture::<u32>::new();
+    assert_eq!(sf.check().await, Ok(None));
+    assert_eq!(
+        sf.single_spawn(async {
+            intf::sleep(2000).await;
+            69
+        })
+        .await,
+        Ok(None)
+    );
+    assert_eq!(sf.check().await, Ok(None));
+    assert_eq!(sf.single_spawn(async { panic!() }).await, Ok(None));
+    assert_eq!(sf.join().await, Ok(Some(69)));
+    assert_eq!(
+        sf.single_spawn(async {
+            intf::sleep(1000).await;
+            37
+        })
+        .await,
+        Ok(None)
+    );
+    intf::sleep(2000).await;
+    assert_eq!(
+        sf.single_spawn(async {
+            intf::sleep(1000).await;
+            27
+        })
+        .await,
+        Ok(Some(37))
+    );
+    intf::sleep(2000).await;
+    assert_eq!(sf.join().await, Ok(Some(27)));
+    assert_eq!(sf.check().await, Ok(None));
+}
+
 pub async fn test_tools() {
    info!("testing retry_falloff_log");
    let mut last_us = 0u64;
@@ -568,6 +605,7 @@ pub async fn test_all() {
    #[cfg(not(target_arch = "wasm32"))]
    test_network_interfaces().await;
    test_single_future().await;
+    test_must_join_single_future().await;
    test_eventual().await;
    test_eventual_value().await;
    test_eventual_value_clone().await;
@@ -8,6 +8,8 @@ mod eventual_value_clone;
 mod ip_addr_port;
 mod ip_extra;
 mod log_thru;
+mod must_join_handle;
+mod must_join_single_future;
 mod mutable_future;
 mod single_future;
 mod single_shot_eventual;
@@ -25,6 +27,7 @@ pub use owo_colors::OwoColorize;
 pub use parking_lot::*;
 pub use split_url::*;
 pub use static_assertions::*;
+pub use stop_token::*;
 pub use tracing::*;

 pub type PinBox<T> = Pin<Box<T>>;
@@ -105,6 +108,8 @@ pub use eventual_value::*;
 pub use eventual_value_clone::*;
 pub use ip_addr_port::*;
 pub use ip_extra::*;
+pub use must_join_handle::*;
+pub use must_join_single_future::*;
 pub use mutable_future::*;
 pub use single_future::*;
 pub use single_shot_eventual::*;
veilid-core/src/xx/must_join_handle.rs (new file, 43 lines)
@@ -0,0 +1,43 @@
+use async_executors::JoinHandle;
+use core::future::Future;
+use core::pin::Pin;
+use core::sync::atomic::{AtomicBool, Ordering};
+use core::task::{Context, Poll};
+
+#[derive(Debug)]
+pub struct MustJoinHandle<T> {
+    join_handle: JoinHandle<T>,
+    completed: AtomicBool,
+}
+
+impl<T> MustJoinHandle<T> {
+    pub fn new(join_handle: JoinHandle<T>) -> Self {
+        Self {
+            join_handle,
+            completed: AtomicBool::new(false),
+        }
+    }
+}
+
+impl<T> Drop for MustJoinHandle<T> {
+    fn drop(&mut self) {
+        // panic if we haven't completed
+        if !self.completed.load(Ordering::Relaxed) {
+            panic!("MustJoinHandle was not completed upon drop. Add cooperative cancellation where appropriate to ensure this is completed before drop.")
+        }
+    }
+}
+
+impl<T: 'static> Future for MustJoinHandle<T> {
+    type Output = T;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        match Pin::new(&mut self.join_handle).poll(cx) {
+            Poll::Ready(t) => {
+                self.completed.store(true, Ordering::Relaxed);
+                Poll::Ready(t)
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
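A minimal usage sketch for the type above (assumed wiring): wrap the raw JoinHandle at spawn time, then make sure some shutdown path awaits the wrapper. Dropping an un-awaited MustJoinHandle panics by design, turning a leaked task into a loud failure instead of a silent one.

async fn demo() {
    let handle = MustJoinHandle::new(intf::spawn(async {
        // task body
    }));
    // later, on shutdown:
    handle.await; // marks the handle completed, so Drop stays quiet
}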
veilid-core/src/xx/must_join_single_future.rs (new file, 213 lines)
@@ -0,0 +1,213 @@
+use super::*;
+use crate::intf::*;
+use cfg_if::*;
+use core::task::Poll;
+use futures_util::poll;
+
+#[derive(Debug)]
+struct MustJoinSingleFutureInner<T>
+where
+    T: 'static,
+{
+    locked: bool,
+    join_handle: Option<MustJoinHandle<T>>,
+}
+
+/// Spawns a single background processing task idempotently, possibly returning the return value of the previously executed background task
+/// This does not queue, just ensures that no more than a single copy of the task is running at a time, while allowing tasks to be retriggered
+#[derive(Debug, Clone)]
+pub struct MustJoinSingleFuture<T>
+where
+    T: 'static,
+{
+    inner: Arc<Mutex<MustJoinSingleFutureInner<T>>>,
+}
+
+impl<T> Default for MustJoinSingleFuture<T>
+where
+    T: 'static,
+{
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<T> MustJoinSingleFuture<T>
+where
+    T: 'static,
+{
+    pub fn new() -> Self {
+        Self {
+            inner: Arc::new(Mutex::new(MustJoinSingleFutureInner {
+                locked: false,
+                join_handle: None,
+            })),
+        }
+    }
+
+    fn try_lock(&self) -> Result<Option<MustJoinHandle<T>>, ()> {
+        let mut inner = self.inner.lock();
+        if inner.locked {
+            // If already locked error out
+            return Err(());
+        }
+        inner.locked = true;
+        // If we got the lock, return what we have for a join handle if anything
+        Ok(inner.join_handle.take())
+    }
+
+    fn unlock(&self, jh: Option<MustJoinHandle<T>>) {
+        let mut inner = self.inner.lock();
+        assert!(inner.locked);
+        assert!(inner.join_handle.is_none());
+        inner.locked = false;
+        inner.join_handle = jh;
+    }
+
+    // Check the result
+    pub async fn check(&self) -> Result<Option<T>, ()> {
+        let mut out: Option<T> = None;
+
+        // See if we have a result we can return
+        let maybe_jh = match self.try_lock() {
+            Ok(v) => v,
+            Err(_) => {
+                // If we are already polling somewhere else, don't hand back a result
+                return Err(());
+            }
+        };
+        if maybe_jh.is_some() {
+            let mut jh = maybe_jh.unwrap();
+
+            // See if we finished, if so, return the value of the last execution
+            if let Poll::Ready(r) = poll!(&mut jh) {
+                out = Some(r);
+                // Task finished, unlock with nothing
+                self.unlock(None);
+            } else {
+                // Still running, put the join handle back so we can check on it later
+                self.unlock(Some(jh));
+            }
+        } else {
+            // No task, unlock with nothing
+            self.unlock(None);
+        }
+
+        // Return the prior result if we have one
+        Ok(out)
+    }
+
+    // Wait for the result
+    pub async fn join(&self) -> Result<Option<T>, ()> {
+        let mut out: Option<T> = None;
+
+        // See if we have a result we can return
+        let maybe_jh = match self.try_lock() {
+            Ok(v) => v,
+            Err(_) => {
+                // If we are already polling somewhere else,
+                // that's an error because you can only join
+                // these things once
+                return Err(());
+            }
+        };
+        if maybe_jh.is_some() {
+            let jh = maybe_jh.unwrap();
+            // Wait for return value of the last execution
+            out = Some(jh.await);
+            // Task finished, unlock with nothing
+        } else {
+            // No task, unlock with nothing
+        }
+        self.unlock(None);
+
+        // Return the prior result if we have one
+        Ok(out)
+    }
+
+    // Possibly spawn the future, possibly returning the value of the last execution
+    cfg_if! {
+        if #[cfg(target_arch = "wasm32")] {
+            pub async fn single_spawn(
+                &self,
+                future: impl Future<Output = T> + 'static,
+            ) -> Result<Option<T>, ()> {
+                let mut out: Option<T> = None;
+
+                // See if we have a result we can return
+                let maybe_jh = match self.try_lock() {
+                    Ok(v) => v,
+                    Err(_) => {
+                        // If we are already polling somewhere else, don't hand back a result
+                        return Err(());
+                    }
+                };
+                let mut run = true;
+
+                if maybe_jh.is_some() {
+                    let mut jh = maybe_jh.unwrap();
+
+                    // See if we finished, if so, return the value of the last execution
+                    if let Poll::Ready(r) = poll!(&mut jh) {
+                        out = Some(r);
+                        // Task finished, unlock with a new task
+                    } else {
+                        // Still running, don't run again, unlock with the current join handle
+                        run = false;
+                        self.unlock(Some(jh));
+                    }
+                }
+
+                // Run if we should do that
+                if run {
+                    self.unlock(Some(MustJoinHandle::new(spawn_local(future))));
+                }
+
+                // Return the prior result if we have one
+                Ok(out)
+            }
+        }
+    }
+}
+cfg_if! {
+    if #[cfg(not(target_arch = "wasm32"))] {
+        impl<T> MustJoinSingleFuture<T>
+        where
+            T: 'static + Send,
+        {
+            pub async fn single_spawn(
+                &self,
+                future: impl Future<Output = T> + Send + 'static,
+            ) -> Result<Option<T>, ()> {
+                let mut out: Option<T> = None;
+                // See if we have a result we can return
+                let maybe_jh = match self.try_lock() {
+                    Ok(v) => v,
+                    Err(_) => {
+                        // If we are already polling somewhere else, don't hand back a result
+                        return Err(());
+                    }
+                };
+                let mut run = true;
+                if maybe_jh.is_some() {
+                    let mut jh = maybe_jh.unwrap();
+                    // See if we finished, if so, return the value of the last execution
+                    if let Poll::Ready(r) = poll!(&mut jh) {
+                        out = Some(r);
+                        // Task finished, unlock with a new task
+                    } else {
+                        // Still running, don't run again, unlock with the current join handle
+                        run = false;
+                        self.unlock(Some(jh));
+                    }
+                }
+                // Run if we should do that
+                if run {
+                    self.unlock(Some(MustJoinHandle::new(spawn(future))));
+                }
+                // Return the prior result if we have one
+                Ok(out)
+            }
+        }
+    }
+}
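A short sketch of the single_spawn() contract this file implements, as exercised by the tests earlier in the diff: spawning while idle starts the task and returns Ok(None); spawning while a previous run is still in flight returns Ok(None) without starting a second copy; spawning after a run has finished hands back Ok(Some(previous_result)).

async fn demo() {
    let sf = MustJoinSingleFuture::<u32>::new();
    // idle: starts the task, nothing prior to report
    assert_eq!(sf.single_spawn(async { 1 }).await, Ok(None));
    // join() awaits the in-flight run and yields its result
    assert_eq!(sf.join().await, Ok(Some(1)));
}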
@@ -6,10 +6,10 @@ use once_cell::sync::OnceCell;
 cfg_if! {
    if #[cfg(target_arch = "wasm32")] {
        type TickTaskRoutine =
-            dyn Fn(u64, u64) -> PinBoxFuture<Result<(), String>> + 'static;
+            dyn Fn(StopToken, u64, u64) -> PinBoxFuture<Result<(), String>> + 'static;
    } else {
        type TickTaskRoutine =
-            dyn Fn(u64, u64) -> SendPinBoxFuture<Result<(), String>> + Send + Sync + 'static;
+            dyn Fn(StopToken, u64, u64) -> SendPinBoxFuture<Result<(), String>> + Send + Sync + 'static;
    }
 }

@@ -20,7 +20,8 @@ pub struct TickTask {
    last_timestamp_us: AtomicU64,
    tick_period_us: u64,
    routine: OnceCell<Box<TickTaskRoutine>>,
-    single_future: SingleFuture<Result<(), String>>,
+    stop_source: AsyncMutex<Option<StopSource>>,
+    single_future: MustJoinSingleFuture<Result<(), String>>,
 }

 impl TickTask {
@@ -29,7 +30,8 @@ impl TickTask {
            last_timestamp_us: AtomicU64::new(0),
            tick_period_us,
            routine: OnceCell::new(),
-            single_future: SingleFuture::new(),
+            stop_source: AsyncMutex::new(None),
+            single_future: MustJoinSingleFuture::new(),
        }
    }
    pub fn new_ms(tick_period_ms: u32) -> Self {
@@ -37,7 +39,8 @@ impl TickTask {
            last_timestamp_us: AtomicU64::new(0),
            tick_period_us: (tick_period_ms as u64) * 1000u64,
            routine: OnceCell::new(),
-            single_future: SingleFuture::new(),
+            stop_source: AsyncMutex::new(None),
+            single_future: MustJoinSingleFuture::new(),
        }
    }
    pub fn new(tick_period_sec: u32) -> Self {
@@ -45,7 +48,8 @@ impl TickTask {
            last_timestamp_us: AtomicU64::new(0),
            tick_period_us: (tick_period_sec as u64) * 1000000u64,
            routine: OnceCell::new(),
-            single_future: SingleFuture::new(),
+            stop_source: AsyncMutex::new(None),
+            single_future: MustJoinSingleFuture::new(),
        }
    }

@@ -53,22 +57,31 @@ impl TickTask {
    if #[cfg(target_arch = "wasm32")] {
        pub fn set_routine(
            &self,
-            routine: impl Fn(u64, u64) -> PinBoxFuture<Result<(), String>> + 'static,
+            routine: impl Fn(StopToken, u64, u64) -> PinBoxFuture<Result<(), String>> + 'static,
        ) {
            self.routine.set(Box::new(routine)).map_err(drop).unwrap();
        }
    } else {
        pub fn set_routine(
            &self,
-            routine: impl Fn(u64, u64) -> SendPinBoxFuture<Result<(), String>> + Send + Sync + 'static,
+            routine: impl Fn(StopToken, u64, u64) -> SendPinBoxFuture<Result<(), String>> + Send + Sync + 'static,
        ) {
            self.routine.set(Box::new(routine)).map_err(drop).unwrap();
        }
    }
 }

-    pub async fn cancel(&self) -> Result<(), String> {
-        match self.single_future.cancel().await {
+    pub async fn stop(&self) -> Result<(), String> {
+        // drop the stop source if we have one
+        let opt_stop_source = &mut *self.stop_source.lock().await;
+        if opt_stop_source.is_none() {
+            // already stopped, just return
+            return Ok(());
+        }
+        *opt_stop_source = None;
+
+        // wait for completion of the tick task
+        match self.single_future.join().await {
            Ok(Some(Err(err))) => Err(err),
            _ => Ok(()),
        }
@@ -80,27 +93,35 @@ impl TickTask {

        if last_timestamp_us == 0u64 || (now - last_timestamp_us) >= self.tick_period_us {
            // Run the singlefuture
+            let opt_stop_source = &mut *self.stop_source.lock().await;
+            let stop_source = StopSource::new();
            match self
                .single_future
-                .single_spawn(self.routine.get().unwrap()(last_timestamp_us, now))
+                .single_spawn(self.routine.get().unwrap()(
+                    stop_source.token(),
+                    last_timestamp_us,
+                    now,
+                ))
                .await
            {
-                Ok(Some(Err(err))) => {
-                    // If the last execution errored out then we should pass that error up
-                    self.last_timestamp_us.store(now, Ordering::Release);
-                    return Err(err);
+                // Single future ran this tick
+                Ok(Some(ret)) => {
+                    // Set new timer
+                    self.last_timestamp_us.store(now, Ordering::Release);
+                    // Save new stopper
+                    *opt_stop_source = Some(stop_source);
+                    ret
                }
+                // Single future did not run this tick
                Ok(None) | Err(()) => {
                    // If the execution didn't happen this time because it was already running
                    // then we should try again the next tick and not reset the timestamp so we try as soon as possible
+                    Ok(())
                }
-                _ => {
-                    // Execution happened, next execution attempt should happen only after tick period
-                    self.last_timestamp_us.store(now, Ordering::Release);
-                }
            }
-        }
-
-        Ok(())
+        } else {
+            // It's not time yet
+            Ok(())
+        }
    }
 }
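An end-to-end sketch of the reworked TickTask API (names from this diff; wiring hypothetical): the routine now receives a StopToken, tick() runs it at most once per period, and stop() replaces cancel() by dropping the StopSource and then joining the in-flight run before returning.

async fn demo() -> Result<(), String> {
    let task = TickTask::new(1); // once per second
    task.set_routine(|stop_token, _last_ts, _cur_ts| {
        Box::pin(async move {
            let _ = stop_token; // honor this at long-running await points
            Ok(())
        })
    });
    task.tick().await?; // call from the main loop
    task.stop().await?; // cooperative shutdown: cancel, then join
    Ok(())
}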