diff --git a/CHANGELOG.md b/CHANGELOG.md index de92298f..97df1a33 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- GUI + CLI + ASB: The Monero RPC pool now caches TCP and Tor streams + ## [3.0.0-beta.5] - 2025-08-04 - GUI + CLI: Fixed a potential race condition where if the user closed the app while the Bitcoin was in the process of being published, manual recovery would be required to get to a recoverable state. diff --git a/Cargo.lock b/Cargo.lock index 6e4de5d4..7637187f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -789,14 +789,14 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "axum" -version = "0.7.9" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" dependencies = [ - "async-trait", "axum-core", "axum-macros", "bytes", + "form_urlencoded", "futures-util", "http 1.3.1", "http-body 1.0.1", @@ -824,13 +824,12 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.4.5" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" dependencies = [ - "async-trait", "bytes", - "futures-util", + "futures-core", "http 1.3.1", "http-body 1.0.1", "http-body-util", @@ -845,9 +844,9 @@ dependencies = [ [[package]] name = "axum-macros" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" dependencies = [ "proc-macro2", "quote", @@ -2079,6 +2078,19 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.15" @@ -5968,9 +5980,9 @@ checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" [[package]] name = "matchit" -version = "0.7.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" [[package]] name = "md-5" @@ -6243,6 +6255,7 @@ dependencies = [ "axum", "chrono", "clap 4.5.42", + "crossbeam", "futures", "http-body-util", "hyper 1.6.0", @@ -6252,15 +6265,17 @@ dependencies = [ "native-tls", "rand 0.8.5", "regex", + "reqwest 0.11.27", "serde", "serde_json", "sqlx", + "swap-serde", "tokio", "tokio-native-tls", "tokio-test", "tor-rtcompat", - "tower 0.4.13", - "tower-http 0.5.2", + "tower 0.5.2", + "tower-http 0.6.6", "tracing", "tracing-subscriber", "typeshare", @@ -12465,22 +12480,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "tower-http" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" -dependencies = [ - "bitflags 2.9.1", - "bytes", - "http 1.3.1", - "http-body 1.0.1", - "http-body-util", - "pin-project-lite", - "tower-layer", - "tower-service", -] - [[package]] name = "tower-http" version = "0.6.6" diff --git a/monero-rpc-pool/Cargo.toml b/monero-rpc-pool/Cargo.toml index 9900844b..f56f5c2d 100644 --- a/monero-rpc-pool/Cargo.toml +++ b/monero-rpc-pool/Cargo.toml @@ -8,36 +8,64 @@ edition = "2021" name = "monero-rpc-pool" path = "src/main.rs" +[[bin]] +name = "stress-test" +path = "src/bin/stress_test.rs" +required-features = ["stress-test"] + +[features] +stress-test = ["reqwest"] + [dependencies] +# Core utilities anyhow = { workspace = true } -axum = { version = "0.7", features = ["macros"] } -chrono = { version = "0.4", features = ["serde"] } -clap = { version = "4.0", features = ["derive"] } futures = { workspace = true } -monero = { workspace = true } -monero-rpc = { path = "../monero-rpc" } rand = { workspace = true } regex = "1.0" -serde = { workspace = true } -serde_json = { workspace = true } -sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "sqlite", "chrono", "migrate"] } -tokio = { workspace = true, features = ["full"] } -tower = "0.4" -tower-http = { version = "0.5", features = ["cors"] } -tracing = { workspace = true } -tracing-subscriber = { workspace = true } -typeshare = { workspace = true } url = "2.0" uuid = { workspace = true } -arti-client = { workspace = true, features = ["tokio"] } -tor-rtcompat = { workspace = true, features = ["tokio", "rustls"] } +# CLI and logging +clap = { version = "4.0", features = ["derive"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +# Async runtime +crossbeam = "0.8.4" +tokio = { workspace = true, features = ["full"] } + +# Serialization +chrono = { version = "0.4", features = ["serde"] } +serde = { workspace = true } +serde_json = { workspace = true } +typeshare = { workspace = true } + +# Database +sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "sqlite", "chrono", "migrate"] } + +# Web framework and HTTP +axum = { version = "0.8.4", features = ["macros"] } http-body-util = "0.1" hyper = { version = "1", features = ["full"] } hyper-util = { version = "0.1", features = ["full"] } +tower = "0.5.2" +tower-http = { version = "0.6.6", features = ["cors"] } + +# TLS/Security native-tls = "0.2" tokio-native-tls = "0.3" +# Tor networking +arti-client = { workspace = true, features = ["tokio"] } +tor-rtcompat = { workspace = true, features = ["tokio", "rustls"] } + +# Monero/Project specific +monero = { workspace = true } +monero-rpc = { path = "../monero-rpc" } +swap-serde = { path = "../swap-serde" } + +# Optional dependencies (for features) +reqwest = { version = "0.11", features = ["json"], optional = true } + [dev-dependencies] tokio-test = "0.4" diff --git a/monero-rpc-pool/src/bin/stress_test.rs b/monero-rpc-pool/src/bin/stress_test.rs new file mode 100644 index 00000000..69b4d265 --- /dev/null +++ b/monero-rpc-pool/src/bin/stress_test.rs @@ -0,0 +1,209 @@ +use arti_client::{TorClient, TorClientConfig}; +use clap::Parser; +use monero::Network; +use monero_rpc_pool::{config::Config, create_app_with_receiver, database::parse_network}; +use reqwest; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::time::sleep; +use tor_rtcompat::tokio::TokioRustlsRuntime; + +#[derive(Parser)] +#[command(name = "stress-test")] +#[command(about = "Stress test the Monero RPC Pool")] +#[command(version)] +struct Args { + #[arg(short, long, default_value = "60")] + #[arg(help = "Duration to run the test in seconds")] + duration: u64, + + #[arg(short, long, default_value = "10")] + #[arg(help = "Number of concurrent requests")] + concurrency: usize, + + #[arg(short, long, default_value = "mainnet")] + #[arg(help = "Network to use (mainnet, stagenet, testnet)")] + #[arg(value_parser = parse_network)] + network: Network, + + #[arg(long)] + #[arg(help = "Enable Tor routing")] + tor: bool, + + #[arg(short, long)] + #[arg(help = "Enable verbose logging")] + verbose: bool, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let args = Args::parse(); + + if args.verbose { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_target(false) + .init(); + } + + println!("Stress Testing Monero RPC Pool"); + println!(" Duration: {}s", args.duration); + println!(" Concurrency: {}", args.concurrency); + println!(" Network: {}", args.network); + println!(" Tor: {}", args.tor); + println!(); + + // Setup Tor client if requested + let tor_client = if args.tor { + println!("Setting up Tor client..."); + let config = TorClientConfig::default(); + let runtime = TokioRustlsRuntime::current().expect("We are always running with tokio"); + + let client = TorClient::with_runtime(runtime) + .config(config) + .create_unbootstrapped_async() + .await?; + + let client = std::sync::Arc::new(client); + + let client_clone = client.clone(); + client_clone + .bootstrap() + .await + .expect("Failed to bootstrap Tor client"); + + Some(client) + } else { + None + }; + + // Start the pool server + println!("Starting RPC pool server..."); + let config = + Config::new_random_port_with_tor_client(std::env::temp_dir(), tor_client, args.network); + let (app, _status_receiver, _background_handle) = create_app_with_receiver(config).await?; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let pool_url = format!("http://{}", addr); + + // Start the server in the background + tokio::spawn(async move { + if let Err(e) = axum::serve(listener, app).await { + eprintln!("Server error: {}", e); + } + }); + + // Give the server a moment to start + sleep(Duration::from_millis(500)).await; + + let client = reqwest::Client::new(); + + let start_time = Instant::now(); + let test_duration = Duration::from_secs(args.duration); + + // Use atomic counters shared between all workers + let success_count = Arc::new(AtomicU64::new(0)); + let error_count = Arc::new(AtomicU64::new(0)); + let total_response_time_nanos = Arc::new(AtomicU64::new(0)); + let should_stop = Arc::new(AtomicBool::new(false)); + + println!("Running for {} seconds...", args.duration); + + // Spawn workers that continuously make requests + let mut tasks = Vec::new(); + for _ in 0..args.concurrency { + let client = client.clone(); + let url = format!("{}/get_info", pool_url); + let success_count = success_count.clone(); + let error_count = error_count.clone(); + let total_response_time_nanos = total_response_time_nanos.clone(); + let should_stop = should_stop.clone(); + + tasks.push(tokio::spawn(async move { + while !should_stop.load(Ordering::Relaxed) { + let request_start = Instant::now(); + + match client.get(&url).send().await { + Ok(response) => { + if response.status().is_success() { + success_count.fetch_add(1, Ordering::Relaxed); + } else { + error_count.fetch_add(1, Ordering::Relaxed); + } + let elapsed_nanos = request_start.elapsed().as_nanos() as u64; + total_response_time_nanos.fetch_add(elapsed_nanos, Ordering::Relaxed); + } + Err(_) => { + error_count.fetch_add(1, Ordering::Relaxed); + } + } + + // Small delay to prevent overwhelming the server + sleep(Duration::from_millis(10)).await; + } + })); + } + + // Show progress while workers run and signal stop when duration is reached + let should_stop_clone = should_stop.clone(); + let progress_task = tokio::spawn(async move { + while start_time.elapsed() < test_duration { + let elapsed = start_time.elapsed().as_secs(); + let remaining = args.duration.saturating_sub(elapsed); + print!("\rRunning... {}s remaining", remaining); + std::io::Write::flush(&mut std::io::stdout()).unwrap(); + sleep(Duration::from_secs(1)).await; + } + // Signal all workers to stop + should_stop_clone.store(true, Ordering::Relaxed); + }); + + // Wait for the test duration to complete + let _ = progress_task.await; + + // Wait a moment for workers to see the stop signal and finish their current requests + sleep(Duration::from_millis(100)).await; + + // Cancel any remaining worker tasks + for task in &tasks { + task.abort(); + } + + // Wait for tasks to finish + for task in tasks { + let _ = task.await; + } + + // Final results + println!("\r "); // Clear progress line + println!(); + + let final_success_count = success_count.load(Ordering::Relaxed); + let final_error_count = error_count.load(Ordering::Relaxed); + let final_total_response_time_nanos = total_response_time_nanos.load(Ordering::Relaxed); + + println!("Stress Test Results:"); + println!(" Total successful requests: {}", final_success_count); + println!(" Total failed requests: {}", final_error_count); + println!( + " Total requests: {}", + final_success_count + final_error_count + ); + + let total_requests = final_success_count + final_error_count; + if total_requests > 0 { + let success_rate = (final_success_count as f64 / total_requests as f64) * 100.0; + println!(" Success rate: {:.2}%", success_rate); + + let avg_response_time_nanos = final_total_response_time_nanos / total_requests; + let avg_response_time = Duration::from_nanos(avg_response_time_nanos); + println!(" Average response time: {:?}", avg_response_time); + + let requests_per_second = total_requests as f64 / args.duration as f64; + println!(" Requests per second: {:.2}", requests_per_second); + } + + Ok(()) +} diff --git a/monero-rpc-pool/src/config.rs b/monero-rpc-pool/src/config.rs index e1375162..2580bc3c 100644 --- a/monero-rpc-pool/src/config.rs +++ b/monero-rpc-pool/src/config.rs @@ -1,3 +1,4 @@ +use monero::Network; use std::path::PathBuf; use crate::TorClientArc; @@ -8,6 +9,7 @@ pub struct Config { pub port: u16, pub data_dir: PathBuf, pub tor_client: Option, + pub network: Network, } impl std::fmt::Debug for Config { @@ -17,13 +19,14 @@ impl std::fmt::Debug for Config { .field("port", &self.port) .field("data_dir", &self.data_dir) .field("tor_client", &self.tor_client.is_some()) + .field("network", &self.network) .finish() } } impl Config { - pub fn new_with_port(host: String, port: u16, data_dir: PathBuf) -> Self { - Self::new_with_port_and_tor_client(host, port, data_dir, None) + pub fn new_with_port(host: String, port: u16, data_dir: PathBuf, network: Network) -> Self { + Self::new_with_port_and_tor_client(host, port, data_dir, None, network) } pub fn new_with_port_and_tor_client( @@ -31,23 +34,32 @@ impl Config { port: u16, data_dir: PathBuf, tor_client: impl Into>, + network: Network, ) -> Self { Self { host, port, data_dir, tor_client: tor_client.into(), + network, } } - pub fn new_random_port(data_dir: PathBuf) -> Self { - Self::new_random_port_with_tor_client(data_dir, None) + pub fn new_random_port(data_dir: PathBuf, network: Network) -> Self { + Self::new_random_port_with_tor_client(data_dir, None, network) } pub fn new_random_port_with_tor_client( data_dir: PathBuf, tor_client: impl Into>, + network: Network, ) -> Self { - Self::new_with_port_and_tor_client("127.0.0.1".to_string(), 0, data_dir, tor_client) + Self::new_with_port_and_tor_client( + "127.0.0.1".to_string(), + 0, + data_dir, + tor_client, + network, + ) } } diff --git a/monero-rpc-pool/src/connection_pool.rs b/monero-rpc-pool/src/connection_pool.rs new file mode 100644 index 00000000..0be53d70 --- /dev/null +++ b/monero-rpc-pool/src/connection_pool.rs @@ -0,0 +1,178 @@ +//! Very small HTTP/1 connection pool for both clearnet (TCP) and Tor streams. +//! +//! After investigation we learned that pooling **raw** sockets is not useful +//! because once Hyper finishes a `Connection` the socket is closed. The correct +//! thing to cache is the HTTP client pair returned by +//! `hyper::client::conn::http1::handshake` – specifically the +//! `SendRequest` handle. +//! +//! A `SendRequest` can serve multiple sequential requests as long as the +//! `Connection` future that Hyper gives us keeps running in the background. +//! Therefore `ConnectionPool` stores those senders and a separate background +//! task drives the corresponding `Connection` until the peer closes it. When +//! that happens any future `send_request` will error and we will drop that entry +//! from the pool automatically. +//! +//! The internal data-structure: +//! +//! ```text +//! Arc>>>>>>> +//! ``` +//! +//! Locking strategy +//! ---------------- +//! * **Outer `RwLock`** – protects the HashMap (rare contention). +//! * **Per-host `RwLock`** – protects the Vec for that host. +//! * **`Mutex` around each `SendRequest`** – guarantees only one request at a +//! time per connection. +//! +//! The `GuardedSender` returned by `ConnectionPool::get()` derefs to +//! `SendRequest`. Once the guard is dropped the mutex unlocks and the +//! connection is again available. + +use std::collections::HashMap; +use std::sync::Arc; + +use axum::body::Body; +use tokio::sync::{Mutex, OwnedMutexGuard, RwLock}; + +/// Key for the map – `(scheme, host, port, via_tor)`. +pub type StreamKey = (String, String, i64, bool); + +/// Alias for hyper's HTTP/1 sender. +pub type HttpSender = hyper::client::conn::http1::SendRequest; + +/// Connection pool. +#[derive(Clone, Default)] +pub struct ConnectionPool { + inner: Arc>>>>>>>, +} + +/// Guard returned by `get()`. Derefs to the underlying `SendRequest` so callers +/// can invoke `send_request()` directly. +pub struct GuardedSender { + guard: OwnedMutexGuard, + pool: ConnectionPool, + key: StreamKey, + sender_arc: Arc>, +} + +impl std::ops::Deref for GuardedSender { + type Target = HttpSender; + fn deref(&self) -> &Self::Target { + &self.guard + } +} +impl std::ops::DerefMut for GuardedSender { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.guard + } +} + +impl GuardedSender { + /// Mark this sender as failed and remove it from the pool. + pub async fn mark_failed(self) { + // Dropping the guard releases the mutex, then we remove from pool + drop(self.guard); + self.pool.remove_sender(&self.key, &self.sender_arc).await; + } +} + +impl ConnectionPool { + pub fn new() -> Self { + Self { + inner: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Try to fetch an idle connection. Returns `None` if all are busy or the + /// host has no pool yet. + pub async fn try_get(&self, key: &StreamKey) -> Option { + let map = self.inner.read().await; + let vec_lock = map.get(key)?.clone(); + drop(map); + + let vec = vec_lock.write().await; + let total_connections = vec.len(); + let mut busy_connections = 0; + + for sender_mutex in vec.iter() { + if let Ok(guard) = sender_mutex.clone().try_lock_owned() { + tracing::debug!( + "Reusing connection for {}://{}:{} (via_tor={}). Pool stats: {}/{} connections available", + key.0, key.1, key.2, key.3, total_connections - busy_connections, total_connections + ); + return Some(GuardedSender { + guard, + pool: self.clone(), + key: key.clone(), + sender_arc: sender_mutex.clone(), + }); + } else { + busy_connections += 1; + } + } + + tracing::debug!( + "No idle connections for {}://{}:{} (via_tor={}). Pool stats: 0/{} connections available", + key.0, key.1, key.2, key.3, total_connections + ); + None + } + + /// Insert `sender` into the pool and return an *exclusive* handle ready to + /// send the first request. + pub async fn insert_and_lock(&self, key: StreamKey, sender: HttpSender) -> GuardedSender { + let sender_mutex = Arc::new(Mutex::new(sender)); + let key_clone = key.clone(); + let sender_mutex_clone = sender_mutex.clone(); + + { + let mut map = self.inner.write().await; + let vec_lock = map + .entry(key) + .or_insert_with(|| Arc::new(RwLock::new(Vec::new()))) + .clone(); + let mut vec = vec_lock.write().await; + vec.push(sender_mutex.clone()); + } + + let guard = sender_mutex.lock_owned().await; + + // Log the new connection count after insertion + let map_read = self.inner.read().await; + if let Some(vec_lock) = map_read.get(&key_clone) { + let vec = vec_lock.read().await; + tracing::debug!( + "Created new connection for {}://{}:{} (via_tor={}). Pool stats: 1/{} connections available", + key_clone.0, key_clone.1, key_clone.2, key_clone.3, vec.len() + ); + } + drop(map_read); + + GuardedSender { + guard, + pool: self.clone(), + key: key_clone, + sender_arc: sender_mutex_clone, + } + } + + /// Remove a specific sender from the pool (used when connection fails). + pub async fn remove_sender(&self, key: &StreamKey, sender_arc: &Arc>) { + if let Some(vec_lock) = self.inner.read().await.get(key).cloned() { + let mut vec = vec_lock.write().await; + let old_count = vec.len(); + vec.retain(|arc_mutex| !Arc::ptr_eq(arc_mutex, sender_arc)); + let new_count = vec.len(); + + if old_count != new_count { + tracing::debug!( + "Removed failed connection for {}://{}:{} (via_tor={}). Pool stats: {}/{} connections remaining", + key.0, key.1, key.2, key.3, new_count, new_count + ); + } + } + } +} diff --git a/monero-rpc-pool/src/database.rs b/monero-rpc-pool/src/database.rs index 81a54c56..0d08260a 100644 --- a/monero-rpc-pool/src/database.rs +++ b/monero-rpc-pool/src/database.rs @@ -2,9 +2,32 @@ use std::path::PathBuf; use crate::types::{NodeAddress, NodeHealthStats, NodeMetadata, NodeRecord}; use anyhow::Result; +use monero::Network; use sqlx::SqlitePool; use tracing::{info, warn}; +/// Convert a string to a Network enum +pub fn parse_network(s: &str) -> Result { + match s.to_lowercase().as_str() { + "mainnet" => Ok(Network::Mainnet), + "stagenet" => Ok(Network::Stagenet), + "testnet" => Ok(Network::Testnet), + _ => anyhow::bail!( + "Invalid network: {}. Must be mainnet, stagenet, or testnet", + s + ), + } +} + +/// Convert a Network enum to a string for database storage +pub fn network_to_string(network: &Network) -> &'static str { + match network { + Network::Mainnet => "mainnet", + Network::Stagenet => "stagenet", + Network::Testnet => "testnet", + } +} + #[derive(Clone)] pub struct Database { pub pool: SqlitePool, @@ -134,7 +157,8 @@ impl Database { .parse() .unwrap_or_else(|_| chrono::Utc::now()); - let metadata = NodeMetadata::new(row.id, row.network, first_seen_at); + let network = parse_network(&row.network).unwrap_or(Network::Mainnet); + let metadata = NodeMetadata::new(row.id, network, first_seen_at); let health = NodeHealthStats { success_count: row.success_count, failure_count: row.failure_count, diff --git a/monero-rpc-pool/src/lib.rs b/monero-rpc-pool/src/lib.rs index 82e15649..1dc39e54 100644 --- a/monero-rpc-pool/src/lib.rs +++ b/monero-rpc-pool/src/lib.rs @@ -6,7 +6,6 @@ use axum::{ routing::{any, get}, Router, }; -use monero::Network; use tokio::task::JoinHandle; use tor_rtcompat::tokio::TokioRustlsRuntime; @@ -16,21 +15,8 @@ use tracing::{error, info}; /// Type alias for the Tor client used throughout the crate pub type TorClientArc = Arc>; -pub trait ToNetworkString { - fn to_network_string(&self) -> String; -} - -impl ToNetworkString for Network { - fn to_network_string(&self) -> String { - match self { - Network::Mainnet => "mainnet".to_string(), - Network::Stagenet => "stagenet".to_string(), - Network::Testnet => "testnet".to_string(), - } - } -} - pub mod config; +pub mod connection_pool; pub mod database; pub mod pool; pub mod proxy; @@ -45,6 +31,7 @@ use proxy::{proxy_handler, stats_handler}; pub struct AppState { pub node_pool: Arc, pub tor_client: Option, + pub connection_pool: crate::connection_pool::ConnectionPool, } /// Manages background tasks for the RPC pool @@ -71,9 +58,8 @@ impl Into for ServerInfo { } } -async fn create_app_with_receiver( +pub async fn create_app_with_receiver( config: Config, - network: Network, ) -> Result<( Router, tokio::sync::broadcast::Receiver, @@ -82,9 +68,8 @@ async fn create_app_with_receiver( // Initialize database let db = Database::new(config.data_dir.clone()).await?; - // Initialize node pool with network - let network_str = network.to_network_string(); - let (node_pool, status_receiver) = NodePool::new(db.clone(), network_str.clone()); + // Initialize node pool with network from config + let (node_pool, status_receiver) = NodePool::new(db.clone(), config.network.clone()); let node_pool = Arc::new(node_pool); // Publish initial status immediately to ensure first event is sent @@ -93,7 +78,7 @@ async fn create_app_with_receiver( } // Send status updates every 10 seconds - let mut interval = tokio::time::interval(std::time::Duration::from_secs(10)); + let mut interval = tokio::time::interval(std::time::Duration::from_secs(2)); let node_pool_for_health_check = node_pool.clone(); let status_update_handle = tokio::spawn(async move { loop { @@ -112,6 +97,7 @@ async fn create_app_with_receiver( let app_state = AppState { node_pool, tor_client: config.tor_client, + connection_pool: crate::connection_pool::ConnectionPool::new(), }; // Build the app @@ -124,25 +110,13 @@ async fn create_app_with_receiver( Ok((app, status_receiver, pool_handle)) } -pub async fn create_app(config: Config, network: Network) -> Result { - let (app, _, _pool_handle) = create_app_with_receiver(config, network).await?; - // Note: pool_handle is dropped here, so tasks will be aborted when this function returns - // This is intentional for the simple create_app use case +pub async fn create_app(config: Config) -> Result { + let (app, _, _pool_handle) = create_app_with_receiver(config).await?; Ok(app) } -/// Create an app with a custom data directory for the database -pub async fn create_app_with_data_dir( - config: Config, - network: Network, - data_dir: std::path::PathBuf, -) -> Result { - let config_with_data_dir = Config::new_with_port(config.host, config.port, data_dir); - create_app(config_with_data_dir, network).await -} - -pub async fn run_server(config: Config, network: Network) -> Result<()> { - let app = create_app(config.clone(), network).await?; +pub async fn run_server(config: Config) -> Result<()> { + let app = create_app(config.clone()).await?; let bind_address = format!("{}:{}", config.host, config.port); info!("Starting server on {}", bind_address); @@ -155,27 +129,23 @@ pub async fn run_server(config: Config, network: Network) -> Result<()> { } /// Run a server with a custom data directory -pub async fn run_server_with_data_dir( - config: Config, - network: Network, - data_dir: std::path::PathBuf, -) -> Result<()> { - let config_with_data_dir = Config::new_with_port(config.host, config.port, data_dir); - run_server(config_with_data_dir, network).await +pub async fn run_server_with_data_dir(config: Config, data_dir: std::path::PathBuf) -> Result<()> { + let config_with_data_dir = + Config::new_with_port(config.host, config.port, data_dir, config.network); + run_server(config_with_data_dir).await } /// Start a server with a random port for library usage /// Returns the server info with the actual port used, a receiver for pool status updates, and pool handle pub async fn start_server_with_random_port( config: Config, - network: Network, ) -> Result<( ServerInfo, tokio::sync::broadcast::Receiver, PoolHandle, )> { let host = config.host.clone(); - let (app, status_receiver, pool_handle) = create_app_with_receiver(config, network).await?; + let (app, status_receiver, pool_handle) = create_app_with_receiver(config).await?; // Bind to port 0 to get a random available port let listener = tokio::net::TcpListener::bind(format!("{}:0", host)).await?; diff --git a/monero-rpc-pool/src/main.rs b/monero-rpc-pool/src/main.rs index 1bf0e366..1e3befdf 100644 --- a/monero-rpc-pool/src/main.rs +++ b/monero-rpc-pool/src/main.rs @@ -1,32 +1,11 @@ use arti_client::{TorClient, TorClientConfig}; use clap::Parser; -use monero_rpc_pool::{config::Config, run_server}; +use monero_rpc_pool::{config::Config, database::parse_network, run_server}; use tracing::info; use tracing_subscriber::{self, EnvFilter}; use monero::Network; -fn parse_network(s: &str) -> Result { - match s.to_lowercase().as_str() { - "mainnet" => Ok(Network::Mainnet), - "stagenet" => Ok(Network::Stagenet), - "testnet" => Ok(Network::Testnet), - _ => Err(format!( - "Invalid network: {}. Must be mainnet, stagenet, or testnet", - s - )), - } -} - -// TODO: Replace with Display impl for Network -fn network_to_string(network: &Network) -> String { - match network { - Network::Mainnet => "mainnet".to_string(), - Network::Stagenet => "stagenet".to_string(), - Network::Testnet => "testnet".to_string(), - } -} - #[derive(Parser)] #[command(name = "monero-rpc-pool")] #[command(about = "A load-balancing HTTP proxy for Monero RPC nodes")] @@ -100,16 +79,17 @@ async fn main() -> Result<(), Box> { args.port, std::env::temp_dir().join("monero-rpc-pool"), tor_client, + args.network, ); info!( host = config.host, port = config.port, - network = network_to_string(&args.network), + network = ?args.network, "Starting Monero RPC Pool" ); - if let Err(e) = run_server(config, args.network).await { + if let Err(e) = run_server(config).await { eprintln!("Server error: {}", e); std::process::exit(1); } diff --git a/monero-rpc-pool/src/pool.rs b/monero-rpc-pool/src/pool.rs index c033ab38..67fd25b1 100644 --- a/monero-rpc-pool/src/pool.rs +++ b/monero-rpc-pool/src/pool.rs @@ -1,12 +1,12 @@ use anyhow::{Context, Result}; -use std::collections::VecDeque; -use std::sync::{Arc, Mutex}; +use crossbeam::deque::{Injector, Steal}; +use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::broadcast; use tracing::warn; use typeshare::typeshare; -use crate::database::Database; +use crate::database::{network_to_string, Database}; use crate::types::NodeAddress; #[derive(Debug, Clone, serde::Serialize)] @@ -30,16 +30,15 @@ pub struct ReliableNodeInfo { pub avg_latency_ms: Option, } -#[derive(Debug)] +#[derive(Debug, Clone)] struct BandwidthEntry { timestamp: Instant, bytes: u64, } #[derive(Debug)] -struct BandwidthTracker { - entries: VecDeque, - window_duration: Duration, +pub struct BandwidthTracker { + entries: Injector, } impl BandwidthTracker { @@ -47,38 +46,50 @@ impl BandwidthTracker { fn new() -> Self { Self { - entries: VecDeque::new(), - window_duration: Self::WINDOW_DURATION, + entries: Injector::new(), } } - fn record_bytes(&mut self, bytes: u64) { + pub fn record_bytes(&self, bytes: u64) { let now = Instant::now(); - self.entries.push_back(BandwidthEntry { + self.entries.push(BandwidthEntry { timestamp: now, bytes, }); - - // Clean up old entries - let cutoff = now - self.window_duration; - while let Some(front) = self.entries.front() { - if front.timestamp < cutoff { - self.entries.pop_front(); - } else { - break; - } - } } fn get_kb_per_sec(&self) -> f64 { - if self.entries.len() < 5 { + let now = Instant::now(); + let cutoff = now - Self::WINDOW_DURATION; + + // Collect valid entries from the injector + let mut valid_entries = Vec::new(); + let mut total_bytes = 0u64; + + // Drain all entries, keeping only recent ones + loop { + match self.entries.steal() { + Steal::Success(entry) => { + if entry.timestamp >= cutoff { + total_bytes += entry.bytes; + valid_entries.push(entry); + } + } + Steal::Empty | Steal::Retry => break, + } + } + + // Put back the valid entries + for entry in valid_entries.iter() { + self.entries.push(entry.clone()); + } + + if valid_entries.len() < 5 { return 0.0; } - let total_bytes: u64 = self.entries.iter().map(|e| e.bytes).sum(); - let now = Instant::now(); - let oldest_time = self.entries.front().unwrap().timestamp; - let duration_secs = (now - oldest_time).as_secs_f64(); + let oldest_time = valid_entries.iter().map(|e| e.timestamp).min().unwrap(); + let duration_secs = now.duration_since(oldest_time).as_secs_f64(); if duration_secs > 0.0 { (total_bytes as f64 / 1024.0) / duration_secs @@ -90,19 +101,19 @@ impl BandwidthTracker { pub struct NodePool { db: Database, - network: String, + network: monero::Network, status_sender: broadcast::Sender, - bandwidth_tracker: Arc>, + bandwidth_tracker: Arc, } impl NodePool { - pub fn new(db: Database, network: String) -> (Self, broadcast::Receiver) { + pub fn new(db: Database, network: monero::Network) -> (Self, broadcast::Receiver) { let (status_sender, status_receiver) = broadcast::channel(100); let pool = Self { db, network, status_sender, - bandwidth_tracker: Arc::new(Mutex::new(BandwidthTracker::new())), + bandwidth_tracker: Arc::new(BandwidthTracker::new()), }; (pool, status_receiver) } @@ -128,9 +139,11 @@ impl NodePool { } pub fn record_bandwidth(&self, bytes: u64) { - if let Ok(mut tracker) = self.bandwidth_tracker.lock() { - tracker.record_bytes(bytes); - } + self.bandwidth_tracker.record_bytes(bytes); + } + + pub fn get_bandwidth_tracker(&self) -> Arc { + self.bandwidth_tracker.clone() } pub async fn publish_status_update(&self) -> Result<()> { @@ -138,24 +151,19 @@ impl NodePool { if let Err(e) = self.status_sender.send(status.clone()) { warn!("Failed to send status update: {}", e); - } else { - tracing::debug!(?status, "Sent status update"); } Ok(()) } pub async fn get_current_status(&self) -> Result { - let (total, reachable, _reliable) = self.db.get_node_stats(&self.network).await?; - let reliable_nodes = self.db.get_reliable_nodes(&self.network).await?; + let network_str = network_to_string(&self.network); + let (total, reachable, _reliable) = self.db.get_node_stats(network_str).await?; + let reliable_nodes = self.db.get_reliable_nodes(network_str).await?; let (successful_checks, unsuccessful_checks) = - self.db.get_health_check_stats(&self.network).await?; + self.db.get_health_check_stats(network_str).await?; - let bandwidth_kb_per_sec = if let Ok(tracker) = self.bandwidth_tracker.lock() { - tracker.get_kb_per_sec() - } else { - 0.0 - }; + let bandwidth_kb_per_sec = self.bandwidth_tracker.get_kb_per_sec(); let top_reliable_nodes = reliable_nodes .into_iter() @@ -184,13 +192,13 @@ impl NodePool { tracing::debug!( "Getting top reliable nodes for network {} (target: {})", - self.network, + network_to_string(&self.network), limit ); let available_nodes = self .db - .get_top_nodes_by_recent_success(&self.network, limit as i64) + .get_top_nodes_by_recent_success(network_to_string(&self.network), limit as i64) .await .context("Failed to get top nodes by recent success")?; @@ -230,7 +238,7 @@ impl NodePool { tracing::debug!( "Pool size: {} nodes for network {} (target: {})", selected_nodes.len(), - self.network, + network_to_string(&self.network), limit ); diff --git a/monero-rpc-pool/src/proxy.rs b/monero-rpc-pool/src/proxy.rs index 7e29f483..45a35b4b 100644 --- a/monero-rpc-pool/src/proxy.rs +++ b/monero-rpc-pool/src/proxy.rs @@ -4,8 +4,11 @@ use axum::{ http::{request::Parts, response, StatusCode}, response::Response, }; +use futures::{stream::Stream, StreamExt}; use http_body_util::BodyExt; use hyper_util::rt::TokioIo; +use std::pin::Pin; +use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio_native_tls::native_tls::TlsConnector; @@ -99,7 +102,7 @@ async fn proxy_to_multiple_nodes( // Start timing the request let latency = std::time::Instant::now(); - let response = match proxy_to_single_node(request.clone(), &node, state.tor_client.clone()) + let response = match proxy_to_single_node(state, request.clone(), &node) .instrument(info_span!( "connection", node = node_uri, @@ -117,33 +120,36 @@ async fn proxy_to_multiple_nodes( // Calculate the latency let latency = latency.elapsed().as_millis() as f64; - // Convert response to cloneable to avoid consumption issues - let cloneable_response = CloneableResponse::from_response(response) - .await - .map_err(|e| { - HandlerError::CloneRequestError(format!("Failed to buffer response: {}", e)) - })?; + // Convert response to streamable to check first 1KB for errors + let streamable_response = StreamableResponse::from_response_with_tracking( + response, + Some(state.node_pool.clone()), + ) + .await + .map_err(|e| { + HandlerError::CloneRequestError(format!("Failed to buffer response: {}", e)) + })?; - let error = match cloneable_response.get_jsonrpc_error() { + let error = match streamable_response.get_jsonrpc_error() { Some(error) => { // Check if we have already got two previous JSON-RPC errors // If we did, we assume there is a reason for it - // We return the response as is. + // We return the response as is (streaming). if collected_errors .iter() .filter(|(_, error)| matches!(error, HandlerError::JsonRpcError(_))) .count() >= 2 { - return Ok(cloneable_response.into_response()); + return Ok(streamable_response.into_response()); } Some(HandlerError::JsonRpcError(error)) } - None if cloneable_response.status().is_client_error() - || cloneable_response.status().is_server_error() => + None if streamable_response.status().is_client_error() + || streamable_response.status().is_server_error() => { - Some(HandlerError::HttpError(cloneable_response.status())) + Some(HandlerError::HttpError(streamable_response.status())) } _ => None, }; @@ -153,16 +159,11 @@ async fn proxy_to_multiple_nodes( push_error(&mut collected_errors, node, error); } None => { - let response_size_bytes = cloneable_response.body.len() as u64; - tracing::debug!( - "Proxy request to {} succeeded with size {}kb", - node_uri, - (response_size_bytes as f64 / 1024.0) + tracing::trace!( + "Proxy request to {} succeeded, streaming response", + node_uri ); - // Record bandwidth usage - state.node_pool.record_bandwidth(response_size_bytes); - // Only record errors if we have gotten a successful response // This helps prevent logging errors if its our likely our fault (no internet) for (node, _) in collected_errors.iter() { @@ -172,8 +173,8 @@ async fn proxy_to_multiple_nodes( // Record the success with actual latency record_success(&state, &node.0, &node.1, node.2, latency).await; - // Finally return the successful response - return Ok(cloneable_response.into_response()); + // Finally return the successful streaming response + return Ok(streamable_response.into_response()); } } } @@ -213,94 +214,93 @@ async fn maybe_wrap_with_tls( /// Important: Does NOT error if the response is a HTTP error or a JSON-RPC error /// The caller is responsible for checking the response status and body for errors async fn proxy_to_single_node( + state: &crate::AppState, request: CloneableRequest, node: &(String, String, i64), - tor_client: Option, ) -> Result { + use crate::connection_pool::GuardedSender; + if request.clearnet_whitelisted() { - tracing::debug!("Request is whitelisted, sending over clearnet"); + tracing::trace!("Request is whitelisted, sending over clearnet"); } - let response = match tor_client { - // If Tor client is ready for traffic, use it - Some(tor_client) - if tor_client.bootstrap_status().ready_for_traffic() - // If the request is whitelisted, we don't want to use Tor - && !request.clearnet_whitelisted() => + let use_tor = match &state.tor_client { + Some(tc) + if tc.bootstrap_status().ready_for_traffic() && !request.clearnet_whitelisted() => { + true + } + _ => false, + }; + + let key = (node.0.clone(), node.1.clone(), node.2, use_tor); + + // Try to reuse an idle HTTP connection first. + let mut guarded_sender: Option = state.connection_pool.try_get(&key).await; + + if guarded_sender.is_none() { + // Need to build a new TCP/Tor stream. + let boxed_stream = if use_tor { + let tor_client = state.tor_client.as_ref().ok_or_else(|| { + SingleRequestError::ConnectionError("Tor requested but client missing".into()) + })?; let stream = tor_client .connect(format!("{}:{}", node.1, node.2)) .await .map_err(|e| SingleRequestError::ConnectionError(e.to_string()))?; - - // Wrap with TLS if using HTTPS - let stream = maybe_wrap_with_tls(stream, &node.0, &node.1).await?; - - let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(stream)) - .await - .map_err(|e| SingleRequestError::ConnectionError(e.to_string()))?; - - tracing::debug!( - "Connected to node via Tor{}", - if node.0 == "https" { " with TLS" } else { "" } - ); - - tokio::task::spawn(async move { - if let Err(err) = conn.await { - println!("Connection failed: {:?}", err); - } - }); - - // Forward the request to the node - // No need to rewrite the URI because the request.uri() is relative - sender - .send_request(request.to_request()) - .await - .map_err(|e| SingleRequestError::SendRequestError(e.to_string()))? - } - // Otherwise send over clearnet - _ => { + maybe_wrap_with_tls(stream, &node.0, &node.1).await? + } else { let stream = TcpStream::connect(format!("{}:{}", node.1, node.2)) .await .map_err(|e| SingleRequestError::ConnectionError(e.to_string()))?; + maybe_wrap_with_tls(stream, &node.0, &node.1).await? + }; - // Wrap with TLS if using HTTPS - let stream = maybe_wrap_with_tls(stream, &node.0, &node.1).await?; + // Build an HTTP/1 connection over the stream. + let (sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(boxed_stream)) + .await + .map_err(|e| SingleRequestError::ConnectionError(e.to_string()))?; - let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(stream)) - .await - .map_err(|e| SingleRequestError::ConnectionError(e.to_string()))?; + // Drive the connection in the background. + tokio::spawn(async move { + let _ = conn.await; // Just drive the connection, errors handled per-request + }); - tracing::debug!( - "Connected to node via clearnet{}", - if node.0 == "https" { " with TLS" } else { "" } - ); + // Insert into pool and obtain exclusive access for this request. + guarded_sender = Some( + state + .connection_pool + .insert_and_lock(key.clone(), sender) + .await, + ); - tokio::task::spawn(async move { - if let Err(err) = conn.await { - println!("Connection failed: {:?}", err); - } - }); + tracing::trace!( + "Established new connection via {}{}", + if use_tor { "Tor" } else { "clearnet" }, + if node.0 == "https" { " with TLS" } else { "" } + ); + } - sender - .send_request(request.to_request()) - .await - .map_err(|e| SingleRequestError::SendRequestError(e.to_string()))? + let mut guarded_sender = guarded_sender.expect("sender must be set"); + + // Forward the request to the node. URI stays relative, so no rewrite. + let response = match guarded_sender.send_request(request.to_request()).await { + Ok(response) => response, + Err(e) => { + // Connection failed, remove it from the pool + guarded_sender.mark_failed().await; + return Err(SingleRequestError::SendRequestError(e.to_string())); } }; // Convert hyper Response to axum Response let (parts, body) = response.into_parts(); - let body_bytes = body - .collect() - .await - .map_err(|e| SingleRequestError::CollectResponseError(e.to_string()))? - .to_bytes(); - let axum_body = Body::from(body_bytes); + let stream = body + .into_data_stream() + .map(|result| result.map_err(|e| axum::Error::new(e))); + let axum_body = Body::from_stream(stream); - let response = Response::from_parts(parts, axum_body); - - Ok(response) + Ok(Response::from_parts(parts, axum_body)) } fn get_jsonrpc_error(body: &[u8]) -> Option { @@ -341,6 +341,49 @@ pub struct CloneableRequest { pub body: Vec, } +/// A response that buffers the first 1KB for error checking and keeps the rest as a stream +pub struct StreamableResponse { + parts: response::Parts, + first_chunk: Vec, + remaining_stream: Option, axum::Error>> + Send>>>, +} + +/// A wrapper stream that tracks bandwidth usage +struct BandwidthTrackingStream { + inner: S, + bandwidth_tracker: Arc, +} + +impl BandwidthTrackingStream { + fn new(inner: S, bandwidth_tracker: Arc) -> Self { + Self { + inner, + bandwidth_tracker, + } + } +} + +impl Stream for BandwidthTrackingStream +where + S: Stream, axum::Error>> + Unpin, +{ + type Item = Result, axum::Error>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let result = Pin::new(&mut self.inner).poll_next(cx); + + if let std::task::Poll::Ready(Some(Ok(ref chunk))) = result { + let chunk_size = chunk.len() as u64; + self.bandwidth_tracker.record_bytes(chunk_size); + } + + result + } +} + /// A cloneable response that buffers the body in memory #[derive(Clone)] pub struct CloneableResponse { @@ -388,6 +431,117 @@ impl CloneableRequest { } } +impl StreamableResponse { + const ERROR_CHECK_SIZE: usize = 1024; // 1KB + + /// Convert a streaming response with bandwidth tracking + pub async fn from_response_with_tracking( + response: Response, + node_pool: Option>, + ) -> Result { + let (parts, body) = response.into_parts(); + let mut body_stream = body.into_data_stream(); + + let mut first_chunk = Vec::new(); + let mut remaining_chunks = Vec::new(); + let mut total_read = 0; + + // Collect chunks until we have at least 1KB for error checking + while total_read < Self::ERROR_CHECK_SIZE { + match body_stream.next().await { + Some(Ok(chunk)) => { + let chunk_bytes = chunk.to_vec(); + let needed = Self::ERROR_CHECK_SIZE - total_read; + + if chunk_bytes.len() <= needed { + // Entire chunk goes to first_chunk + first_chunk.extend_from_slice(&chunk_bytes); + total_read += chunk_bytes.len(); + } else { + // Split the chunk + first_chunk.extend_from_slice(&chunk_bytes[..needed]); + remaining_chunks.push(chunk_bytes[needed..].to_vec()); + total_read += needed; + break; + } + } + Some(Err(e)) => return Err(e), + None => break, // End of stream + } + } + + // Track bandwidth for the first chunk if we have a node pool + if let Some(ref node_pool) = node_pool { + node_pool.record_bandwidth(first_chunk.len() as u64); + } + + // Create stream for remaining data + let remaining_stream = + if !remaining_chunks.is_empty() || total_read >= Self::ERROR_CHECK_SIZE { + let initial_chunks = remaining_chunks.into_iter().map(Ok); + let rest_stream = body_stream.map(|result| { + result + .map(|chunk| chunk.to_vec()) + .map_err(|e| axum::Error::new(e)) + }); + let combined_stream = futures::stream::iter(initial_chunks).chain(rest_stream); + + // Wrap with bandwidth tracking if we have a node pool + let final_stream: Pin, axum::Error>> + Send>> = + if let Some(node_pool) = node_pool.clone() { + let bandwidth_tracker = node_pool.get_bandwidth_tracker(); + Box::pin(BandwidthTrackingStream::new( + combined_stream, + bandwidth_tracker, + )) + } else { + Box::pin(combined_stream) + }; + + Some(final_stream) + } else { + None + }; + + Ok(StreamableResponse { + parts, + first_chunk, + remaining_stream, + }) + } + + /// Get the status code + pub fn status(&self) -> StatusCode { + self.parts.status + } + + /// Check for JSON-RPC errors in the first chunk + pub fn get_jsonrpc_error(&self) -> Option { + get_jsonrpc_error(&self.first_chunk) + } + + /// Convert to a streaming response + pub fn into_response(self) -> Response { + let body = if let Some(remaining_stream) = self.remaining_stream { + // Create a stream that starts with the first chunk, then continues with the rest + let first_chunk_stream = + futures::stream::once(futures::future::ready(Ok(self.first_chunk))); + let combined_stream = first_chunk_stream.chain(remaining_stream); + Body::from_stream(combined_stream) + } else { + // Only the first chunk exists + Body::from(self.first_chunk) + }; + + Response::from_parts(self.parts, body) + } + + /// Get the size of the response (first chunk only, for bandwidth tracking) + pub fn first_chunk_size(&self) -> usize { + self.first_chunk.len() + } +} + impl CloneableResponse { /// Convert a streaming response into a cloneable one by buffering the body pub async fn from_response(response: Response) -> Result { @@ -468,7 +622,6 @@ enum HandlerError { enum SingleRequestError { ConnectionError(String), SendRequestError(String), - CollectResponseError(String), } impl std::fmt::Display for HandlerError { @@ -500,9 +653,6 @@ impl std::fmt::Display for SingleRequestError { match self { SingleRequestError::ConnectionError(msg) => write!(f, "Connection error: {}", msg), SingleRequestError::SendRequestError(msg) => write!(f, "Send request error: {}", msg), - SingleRequestError::CollectResponseError(msg) => { - write!(f, "Collect response error: {}", msg) - } } } } diff --git a/monero-rpc-pool/src/types.rs b/monero-rpc-pool/src/types.rs index df536d16..b3b8c947 100644 --- a/monero-rpc-pool/src/types.rs +++ b/monero-rpc-pool/src/types.rs @@ -1,4 +1,5 @@ use chrono::{DateTime, Utc}; +use monero::Network; use serde::{Deserialize, Serialize}; use std::fmt; @@ -28,12 +29,13 @@ impl fmt::Display for NodeAddress { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeMetadata { pub id: i64, - pub network: String, // "mainnet", "stagenet", or "testnet" + #[serde(with = "swap_serde::monero::network")] + pub network: Network, pub first_seen_at: DateTime, } impl NodeMetadata { - pub fn new(id: i64, network: String, first_seen_at: DateTime) -> Self { + pub fn new(id: i64, network: Network, first_seen_at: DateTime) -> Self { Self { id, network, diff --git a/src-gui/src/renderer/components/pages/monero/components/WalletOverview.tsx b/src-gui/src/renderer/components/pages/monero/components/WalletOverview.tsx index 7430685d..f2ebd27f 100644 --- a/src-gui/src/renderer/components/pages/monero/components/WalletOverview.tsx +++ b/src-gui/src/renderer/components/pages/monero/components/WalletOverview.tsx @@ -3,17 +3,88 @@ import { useAppSelector } from "store/hooks"; import { PiconeroAmount } from "../../../other/Units"; import { FiatPiconeroAmount } from "../../../other/Units"; import StateIndicator from "./StateIndicator"; +import humanizeDuration from "humanize-duration"; +import { GetMoneroSyncProgressResponse } from "models/tauriModel"; + +interface TimeEstimationResult { + blocksLeft: number; + hasDirectKnowledge: boolean; + isStuck: boolean; + formattedTimeRemaining: string | null; +} + +const AVG_MONERO_BLOCK_SIZE_KB = 130; + +function useSyncTimeEstimation( + syncProgress: GetMoneroSyncProgressResponse | undefined, +): TimeEstimationResult | null { + const poolStatus = useAppSelector((state) => state.pool.status); + const restoreHeight = useAppSelector( + (state) => state.wallet.state.restoreHeight, + ); + + if (restoreHeight == null || poolStatus == null) { + return null; + } + + const currentBlock = syncProgress?.current_block ?? 0; + const targetBlock = syncProgress?.target_block ?? 0; + const restoreBlock = restoreHeight.height; + + // For blocks before the restore height we only need to download the header + const fastBlocksLeft = + currentBlock < restoreBlock + ? Math.max(0, Math.min(restoreBlock, targetBlock) - currentBlock) + : 0; + + // For blocks after (or equal to) the restore height we need the full block data + const fullBlocksLeft = Math.max( + 0, + targetBlock - Math.max(currentBlock, restoreBlock), + ); + + const blocksLeft = fastBlocksLeft + fullBlocksLeft; + + // Treat blocksLeft = 1 as if we have no direct knowledge + const hasDirectKnowledge = blocksLeft != null && blocksLeft > 1; + + const isStuck = + poolStatus?.bandwidth_kb_per_sec != null && + poolStatus.bandwidth_kb_per_sec < 1; + + // A full blocks is 130kb, we assume a header is 2% of that + const estimatedDownloadLeftSize = + fullBlocksLeft * AVG_MONERO_BLOCK_SIZE_KB + + (fastBlocksLeft * AVG_MONERO_BLOCK_SIZE_KB) / 50; + + const estimatedTimeRemaining = + hasDirectKnowledge && + poolStatus?.bandwidth_kb_per_sec != null && + poolStatus.bandwidth_kb_per_sec > 0 + ? estimatedDownloadLeftSize / poolStatus.bandwidth_kb_per_sec + : null; + + const formattedTimeRemaining = estimatedTimeRemaining + ? humanizeDuration(estimatedTimeRemaining * 1000, { + round: true, + largest: 1, + }) + : null; + + return { + blocksLeft, + hasDirectKnowledge, + isStuck, + formattedTimeRemaining, + }; +} interface WalletOverviewProps { balance?: { unlocked_balance: string; total_balance: string; }; - syncProgress?: { - current_block: number; - target_block: number; - progress_percentage: number; - }; + syncProgress?: GetMoneroSyncProgressResponse; } // Component for displaying wallet address and balance @@ -26,15 +97,12 @@ export default function WalletOverview({ ); const poolStatus = useAppSelector((state) => state.pool.status); + const timeEstimation = useSyncTimeEstimation(syncProgress); const pendingBalance = parseFloat(balance.total_balance) - parseFloat(balance.unlocked_balance); const isSyncing = syncProgress && syncProgress.progress_percentage < 100; - const blocksLeft = syncProgress?.target_block - syncProgress?.current_block; - - // Treat blocksLeft = 1 as if we have no direct knowledge - const hasDirectKnowledge = blocksLeft != null && blocksLeft > 1; // syncProgress.progress_percentage is not good to display // assuming we have an old wallet, eventually we will always only use the last few cm of the progress bar @@ -61,36 +129,23 @@ export default function WalletOverview({ ), ); - const isStuck = - poolStatus?.bandwidth_kb_per_sec != null && - poolStatus.bandwidth_kb_per_sec < 0.01; - - // Calculate estimated time remaining for sync - const formatTimeRemaining = (seconds: number): string => { - if (seconds < 60) return `${Math.round(seconds)}s`; - if (seconds < 3600) return `${Math.round(seconds / 60)}m`; - if (seconds < 86400) return `${Math.round(seconds / 3600)}h`; - return `${Math.round(seconds / 86400)}d`; - }; - - const estimatedTimeRemaining = - hasDirectKnowledge && - poolStatus?.bandwidth_kb_per_sec != null && - poolStatus.bandwidth_kb_per_sec > 0 - ? (blocksLeft * 130) / poolStatus.bandwidth_kb_per_sec // blocks * 130kb / kb_per_sec = seconds - : null; - return ( {syncProgress && syncProgress.progress_percentage < 100 && ( @@ -174,7 +229,8 @@ export default function WalletOverview({ display: "flex", flexDirection: "column", alignItems: "flex-end", - gap: 2, + justifyContent: "space-between", + minHeight: "100%", }} > - {isSyncing && hasDirectKnowledge && ( + {isSyncing && timeEstimation?.hasDirectKnowledge && ( - {blocksLeft?.toLocaleString()} blocks left + {timeEstimation.blocksLeft?.toLocaleString()} blocks left )} - {poolStatus && isSyncing && !isStuck && ( + {poolStatus && isSyncing && !timeEstimation?.isStuck && ( <> - {estimatedTimeRemaining && !isStuck && ( - <>{formatTimeRemaining(estimatedTimeRemaining)} left - )}{" "} - / {poolStatus.bandwidth_kb_per_sec?.toFixed(1) ?? "0.0"} KB/s + {timeEstimation?.formattedTimeRemaining && + !timeEstimation?.isStuck && ( + <> + {timeEstimation.formattedTimeRemaining} left /{" "} + {poolStatus.bandwidth_kb_per_sec?.toFixed(1) ?? "0.0"}{" "} + KB/s + + )} )} diff --git a/src-gui/src/renderer/components/pages/swap/swap/done/BitcoinRefundedPage.tsx b/src-gui/src/renderer/components/pages/swap/swap/done/BitcoinRefundedPage.tsx index 0c82e522..866a6fcc 100644 --- a/src-gui/src/renderer/components/pages/swap/swap/done/BitcoinRefundedPage.tsx +++ b/src-gui/src/renderer/components/pages/swap/swap/done/BitcoinRefundedPage.tsx @@ -69,8 +69,9 @@ function MultiBitcoinRefundedPage({ <> Unfortunately, the swap was not successful. However, rest assured that - all your Bitcoin has been refunded to the specified address. The swap - process is now complete, and you are free to exit the application. + all your Bitcoin has been refunded to the specified address.{" "} + {btc_refund_finalized && + "The swap process is now complete, and you are free to exit the application."} } export async function getRestoreHeight(): Promise { - return await invokeNoArgs("get_restore_height"); + const restoreHeight = + await invokeNoArgs("get_restore_height"); + store.dispatch(setRestoreHeight(restoreHeight)); + return restoreHeight; } export async function setMoneroRestoreHeight( @@ -489,25 +493,31 @@ export async function getMoneroSyncProgress(): Promise { + // Returns the wallet's seed phrase as a single string. Backend must expose the `get_monero_seed` command. + return await invokeNoArgs("get_monero_seed"); +} + // Wallet management functions that handle Redux dispatching export async function initializeMoneroWallet() { try { - const [ - addressResponse, - balanceResponse, - syncProgressResponse, - historyResponse, - ] = await Promise.all([ - getMoneroMainAddress(), - getMoneroBalance(), - getMoneroSyncProgress(), - getMoneroHistory(), + await Promise.all([ + getMoneroMainAddress().then((response) => { + store.dispatch(setMainAddress(response.address)); + }), + getMoneroBalance().then((response) => { + store.dispatch(setBalance(response)); + }), + getMoneroSyncProgress().then((response) => { + store.dispatch(setSyncProgress(response)); + }), + getMoneroHistory().then((response) => { + store.dispatch(setHistory(response)); + }), + getRestoreHeight().then((response) => { + store.dispatch(setRestoreHeight(response)); + }), ]); - - store.dispatch(setMainAddress(addressResponse.address)); - store.dispatch(setBalance(balanceResponse)); - store.dispatch(setSyncProgress(syncProgressResponse)); - store.dispatch(setHistory(historyResponse)); } catch (err) { console.error("Failed to fetch Monero wallet data:", err); } @@ -527,13 +537,12 @@ export async function sendMoneroTransaction( }) .catch((refreshErr) => { console.error("Failed to refresh wallet data after send:", refreshErr); - // Could emit a toast notification here }); return response; } catch (err) { console.error("Failed to send Monero:", err); - throw err; // ✅ Re-throw so caller can handle appropriately + throw err; } } diff --git a/src-gui/src/store/features/walletSlice.ts b/src-gui/src/store/features/walletSlice.ts index 71178b73..73be6319 100644 --- a/src-gui/src/store/features/walletSlice.ts +++ b/src-gui/src/store/features/walletSlice.ts @@ -3,15 +3,16 @@ import { GetMoneroBalanceResponse, GetMoneroHistoryResponse, GetMoneroSyncProgressResponse, + GetRestoreHeightResponse, } from "models/tauriModel"; interface WalletState { - // Wallet data mainAddress: string | null; balance: GetMoneroBalanceResponse | null; syncProgress: GetMoneroSyncProgressResponse | null; history: GetMoneroHistoryResponse | null; lowestCurrentBlock: number | null; + restoreHeight: GetRestoreHeightResponse | null; } export interface WalletSlice { @@ -20,12 +21,12 @@ export interface WalletSlice { const initialState: WalletSlice = { state: { - // Wallet data mainAddress: null, balance: null, syncProgress: null, history: null, lowestCurrentBlock: null, + restoreHeight: null, }, }; @@ -59,6 +60,9 @@ export const walletSlice = createSlice({ setHistory(slice, action: PayloadAction) { slice.state.history = action.payload; }, + setRestoreHeight(slice, action: PayloadAction) { + slice.state.restoreHeight = action.payload; + }, // Reset actions resetWalletState(slice) { slice.state = initialState.state; @@ -72,6 +76,7 @@ export const { setSyncProgress, setHistory, resetWalletState, + setRestoreHeight, } = walletSlice.actions; export default walletSlice.reducer; diff --git a/swap/src/bin/asb.rs b/swap/src/bin/asb.rs index ec3fd61d..b611c19c 100644 --- a/swap/src/bin/asb.rs +++ b/swap/src/bin/asb.rs @@ -499,8 +499,8 @@ async fn init_monero_wallet( monero_rpc_pool::start_server_with_random_port( monero_rpc_pool::config::Config::new_random_port( config.data.dir.join("monero-rpc-pool"), + env_config.monero_network, ), - env_config.monero_network, ) .await .context("Failed to start Monero RPC Pool for ASB")?; diff --git a/swap/src/cli/api.rs b/swap/src/cli/api.rs index 174d0173..843eac4a 100644 --- a/swap/src/cli/api.rs +++ b/swap/src/cli/api.rs @@ -345,11 +345,11 @@ impl ContextBuilder { } else { None }, + match self.is_testnet { + true => monero::Network::Stagenet, + false => monero::Network::Mainnet, + }, ), - match self.is_testnet { - true => crate::monero::Network::Stagenet, - false => crate::monero::Network::Mainnet, - }, ) .await?;