diff --git a/common/src/main/java/haveno/common/ThreadUtils.java b/common/src/main/java/haveno/common/ThreadUtils.java index f5defe1119..a945014bc1 100644 --- a/common/src/main/java/haveno/common/ThreadUtils.java +++ b/common/src/main/java/haveno/common/ThreadUtils.java @@ -21,7 +21,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; -import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -29,10 +28,17 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; - +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + public class ThreadUtils { - - private static final Map VIRTUAL_THREADS = new ConcurrentHashMap<>(); + + private static final Logger logger = LoggerFactory.getLogger(ThreadUtils.class); + private static final ConcurrentHashMap VIRTUAL_THREADS = new ConcurrentHashMap<>(); + private static final AtomicInteger THREAD_COUNTER = new AtomicInteger(0); + private static final long DEFAULT_TIMEOUT_MS = 5000; // Default timeout for operations /** * Execute the given command in a virtual thread with the given id. @@ -47,11 +53,11 @@ public class ThreadUtils { command.run(); future.complete(null); } catch (Exception e) { + logger.error("Exception in thread: {} - {}", threadId, e.getMessage(), e); future.completeExceptionally(e); - } finally { - VIRTUAL_THREADS.remove(threadId); } }); + VIRTUAL_THREADS.put(threadId, virtualThread); return future; } @@ -65,65 +71,51 @@ public class ThreadUtils { public static void await(Runnable command, String threadId) { try { execute(command, threadId).get(); - } catch (Exception e) { - throw new RuntimeException(e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.error("Interrupted while awaiting command execution in thread: {} - {}", threadId, e.getMessage(), e); + throw new RuntimeException("Interrupted while awaiting command execution", e); + } catch (ExecutionException e) { + logger.error("Execution exception while awaiting command execution in thread: {} - {}", threadId, e.getMessage(), e); + throw new RuntimeException("Execution exception while awaiting command execution", e); } } - /** - * Shuts down the virtual thread with the given id. - * - * @param threadId the thread id - */ + public static void shutDown(String threadId) { shutDown(threadId, null); } - /** - * Shuts down the virtual thread with the given id and an optional timeout. - * - * @param threadId the thread id - * @param timeoutMs the timeout in milliseconds - */ public static void shutDown(String threadId, Long timeoutMs) { - Thread thread = VIRTUAL_THREADS.get(threadId); - if (thread == null) return; // thread not found + Thread thread = VIRTUAL_THREADS.remove(threadId); + if (thread == null) { + logger.warn("Thread not found: {}", threadId); + return; // thread not found + } thread.interrupt(); if (timeoutMs != null) { try { thread.join(timeoutMs); } catch (InterruptedException e) { Thread.currentThread().interrupt(); - throw new RuntimeException(e); + logger.error("Interrupted while waiting for thread to shut down: {} - {}", threadId, e.getMessage(), e); + throw new RuntimeException("Interrupted while waiting for thread to shut down", e); } } + logger.info("Shut down thread: {}", threadId); } - /** - * Removes the virtual thread with the given id from the tracking map. - * - * @param threadId the thread id - */ + public static void remove(String threadId) { VIRTUAL_THREADS.remove(threadId); } - /** - * Submit a single task to be executed in a virtual thread. - * - * @param task the task to execute - * @return a Future representing the pending completion of the task - */ + // TODO: consolidate and cleanup apis + public static Future submitToPool(Runnable task) { return submitToPool(Arrays.asList(task)).get(0); } - /** - * Submit a single callable task to be executed in a virtual thread. - * - * @param task the task to execute - * @return a Future representing the pending completion of the task - */ public static Future submitToPool(Callable task) { CompletableFuture future = new CompletableFuture<>(); execute(() -> { @@ -131,91 +123,73 @@ public class ThreadUtils { T result = task.call(); future.complete(result); } catch (Exception e) { + logger.error("Exception in callable task - {}", e.getMessage(), e); future.completeExceptionally(e); } - }, "pool-task-" + task.hashCode()); + }, "pool-task-" + THREAD_COUNTER.incrementAndGet()); return future; } - /** - * Submit a list of tasks to be executed in virtual threads. - * - * @param tasks the tasks to execute - * @return a list of Futures representing the pending completion of the tasks - */ public static List> submitToPool(List tasks) { List> futures = new ArrayList<>(); for (Runnable task : tasks) { - futures.add(execute(task, "pool-task-" + task.hashCode())); + futures.add(execute(task, "pool-task-" + THREAD_COUNTER.incrementAndGet())); } return futures; } - /** - * Await the completion of a single task. - * - * @param task the task to execute - * @return a Future representing the pending completion of the task - */ public static Future awaitTask(Runnable task) { return awaitTask(task, null); } - /** - * Await the completion of a single task with an optional timeout. - * - * @param task the task to execute - * @param timeoutMs the timeout in milliseconds - * @return a Future representing the pending completion of the task - */ public static Future awaitTask(Runnable task, Long timeoutMs) { return awaitTasks(Arrays.asList(task), 1, timeoutMs).get(0); } - /** - * Await the completion of a collection of tasks. - * - * @param tasks the tasks to execute - * @return a list of Futures representing the pending completion of the tasks - */ public static List> awaitTasks(Collection tasks) { return awaitTasks(tasks, tasks.size()); } - /** - * Await the completion of a collection of tasks with a specified maximum concurrency. - * - * @param tasks the tasks to execute - * @param maxConcurrency the maximum number of concurrent tasks - * @return a list of Futures representing the pending completion of the tasks - */ public static List> awaitTasks(Collection tasks, int maxConcurrency) { return awaitTasks(tasks, maxConcurrency, null); } - /** - * Await the completion of a collection of tasks with a specified maximum concurrency and optional timeout. - * - * @param tasks the tasks to execute - * @param maxConcurrency the maximum number of concurrent tasks - * @param timeoutMs the timeout in milliseconds - * @return a list of Futures representing the pending completion of the tasks - */ public static List> awaitTasks(Collection tasks, int maxConcurrency, Long timeoutMs) { - if (timeoutMs == null) timeoutMs = Long.MAX_VALUE; + if (timeoutMs == null) timeoutMs = DEFAULT_TIMEOUT_MS; if (tasks.isEmpty()) return new ArrayList<>(); List> futures = new ArrayList<>(); - for (Runnable task : tasks) { - futures.add(execute(task, "await-task-" + task.hashCode())); - } - for (Future future : futures) { - try { - future.get(timeoutMs, TimeUnit.MILLISECONDS); - } catch (InterruptedException | ExecutionException | TimeoutException e) { - throw new RuntimeException(e); + AtomicReference> remainingTasks = new AtomicReference<>(new ArrayList<>(tasks)); + + while (true) { + List batchTasks = remainingTasks.get().subList(0, Math.min(maxConcurrency, remainingTasks.get().size())); + if (batchTasks.isEmpty()) break; + + List> batchFutures = new ArrayList<>(); + for (Runnable task : batchTasks) { + batchFutures.add(execute(task, "await-task-" + THREAD_COUNTER.incrementAndGet())); } + + for (Future future : batchFutures) { + try { + future.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.error("Interrupted while awaiting task completion - {}", e.getMessage(), e); + throw new RuntimeException("Interrupted while awaiting task completion", e); + } catch (ExecutionException e) { + logger.error("Execution exception while awaiting task completion - {}", e.getMessage(), e); + throw new RuntimeException("Execution exception while awaiting task completion", e); + } catch (TimeoutException e) { + logger.error("Timeout while awaiting task completion - {}", e.getMessage(), e); + throw new RuntimeException("Timeout while awaiting task completion", e); + } + } + + futures.addAll(batchFutures); + remainingTasks.set(remainingTasks.get().subList(Math.min(maxConcurrency, remainingTasks.get().size()), remainingTasks.get().size())); } + return futures; } }