diff --git a/common/src/main/java/haveno/common/ThreadUtils.java b/common/src/main/java/haveno/common/ThreadUtils.java index e463e83154..f5defe1119 100644 --- a/common/src/main/java/haveno/common/ThreadUtils.java +++ b/common/src/main/java/haveno/common/ThreadUtils.java @@ -16,41 +16,44 @@ */ package haveno.common; + import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; public class ThreadUtils { - - private static final Map EXECUTORS = new HashMap<>(); - private static final Map THREADS = new HashMap<>(); - private static final int POOL_SIZE = 10; - private static final ExecutorService POOL = Executors.newFixedThreadPool(POOL_SIZE); + + private static final Map VIRTUAL_THREADS = new ConcurrentHashMap<>(); /** - * Execute the given command in a thread with the given id. - * + * Execute the given command in a virtual thread with the given id. + * * @param command the command to execute * @param threadId the thread id */ public static Future execute(Runnable command, String threadId) { - synchronized (EXECUTORS) { - if (!EXECUTORS.containsKey(threadId)) EXECUTORS.put(threadId, Executors.newFixedThreadPool(1)); - return EXECUTORS.get(threadId).submit(() -> { - synchronized (THREADS) { - THREADS.put(threadId, Thread.currentThread()); - } - Thread.currentThread().setName(threadId); + CompletableFuture future = new CompletableFuture<>(); + Thread virtualThread = Thread.ofVirtual().name(threadId).start(() -> { + try { command.run(); - }); - } + future.complete(null); + } catch (Exception e) { + future.completeExceptionally(e); + } finally { + VIRTUAL_THREADS.remove(threadId); + } + }); + VIRTUAL_THREADS.put(threadId, virtualThread); + return future; } /** @@ -67,85 +70,152 @@ public class ThreadUtils { } } + /** + * Shuts down the virtual thread with the given id. + * + * @param threadId the thread id + */ public static void shutDown(String threadId) { shutDown(threadId, null); } + /** + * Shuts down the virtual thread with the given id and an optional timeout. + * + * @param threadId the thread id + * @param timeoutMs the timeout in milliseconds + */ public static void shutDown(String threadId, Long timeoutMs) { - if (timeoutMs == null) timeoutMs = Long.MAX_VALUE; - ExecutorService pool = null; - synchronized (EXECUTORS) { - pool = EXECUTORS.get(threadId); - } - if (pool == null) return; // thread not found - pool.shutdown(); - try { - if (!pool.awaitTermination(timeoutMs, TimeUnit.MILLISECONDS)) pool.shutdownNow(); - } catch (InterruptedException e) { - pool.shutdownNow(); - throw new RuntimeException(e); - } finally { - remove(threadId); + Thread thread = VIRTUAL_THREADS.get(threadId); + if (thread == null) return; // thread not found + thread.interrupt(); + if (timeoutMs != null) { + try { + thread.join(timeoutMs); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } } } + /** + * Removes the virtual thread with the given id from the tracking map. + * + * @param threadId the thread id + */ public static void remove(String threadId) { - synchronized (EXECUTORS) { - EXECUTORS.remove(threadId); - } - synchronized (THREADS) { - THREADS.remove(threadId); - } + VIRTUAL_THREADS.remove(threadId); } - // TODO: consolidate and cleanup apis - + /** + * Submit a single task to be executed in a virtual thread. + * + * @param task the task to execute + * @return a Future representing the pending completion of the task + */ public static Future submitToPool(Runnable task) { return submitToPool(Arrays.asList(task)).get(0); } + /** + * Submit a single callable task to be executed in a virtual thread. + * + * @param task the task to execute + * @return a Future representing the pending completion of the task + */ + public static Future submitToPool(Callable task) { + CompletableFuture future = new CompletableFuture<>(); + execute(() -> { + try { + T result = task.call(); + future.complete(result); + } catch (Exception e) { + future.completeExceptionally(e); + } + }, "pool-task-" + task.hashCode()); + return future; + } + + /** + * Submit a list of tasks to be executed in virtual threads. + * + * @param tasks the tasks to execute + * @return a list of Futures representing the pending completion of the tasks + */ public static List> submitToPool(List tasks) { List> futures = new ArrayList<>(); - for (Runnable task : tasks) futures.add(POOL.submit(task)); + for (Runnable task : tasks) { + futures.add(execute(task, "pool-task-" + task.hashCode())); + } return futures; } + /** + * Await the completion of a single task. + * + * @param task the task to execute + * @return a Future representing the pending completion of the task + */ public static Future awaitTask(Runnable task) { return awaitTask(task, null); } + /** + * Await the completion of a single task with an optional timeout. + * + * @param task the task to execute + * @param timeoutMs the timeout in milliseconds + * @return a Future representing the pending completion of the task + */ public static Future awaitTask(Runnable task, Long timeoutMs) { return awaitTasks(Arrays.asList(task), 1, timeoutMs).get(0); } + /** + * Await the completion of a collection of tasks. + * + * @param tasks the tasks to execute + * @return a list of Futures representing the pending completion of the tasks + */ public static List> awaitTasks(Collection tasks) { return awaitTasks(tasks, tasks.size()); } + /** + * Await the completion of a collection of tasks with a specified maximum concurrency. + * + * @param tasks the tasks to execute + * @param maxConcurrency the maximum number of concurrent tasks + * @return a list of Futures representing the pending completion of the tasks + */ public static List> awaitTasks(Collection tasks, int maxConcurrency) { return awaitTasks(tasks, maxConcurrency, null); } + /** + * Await the completion of a collection of tasks with a specified maximum concurrency and optional timeout. + * + * @param tasks the tasks to execute + * @param maxConcurrency the maximum number of concurrent tasks + * @param timeoutMs the timeout in milliseconds + * @return a list of Futures representing the pending completion of the tasks + */ public static List> awaitTasks(Collection tasks, int maxConcurrency, Long timeoutMs) { if (timeoutMs == null) timeoutMs = Long.MAX_VALUE; if (tasks.isEmpty()) return new ArrayList<>(); - ExecutorService executorService = Executors.newFixedThreadPool(tasks.size()); - try { - List> futures = new ArrayList<>(); - for (Runnable task : tasks) futures.add(executorService.submit(task, null)); - for (Future future : futures) future.get(timeoutMs, TimeUnit.MILLISECONDS); - return futures; - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - executorService.shutdownNow(); - } - } - private static boolean isCurrentThread(Thread thread, String threadId) { - synchronized (THREADS) { - if (!THREADS.containsKey(threadId)) return false; - return thread == THREADS.get(threadId); + List> futures = new ArrayList<>(); + for (Runnable task : tasks) { + futures.add(execute(task, "await-task-" + task.hashCode())); } + for (Future future : futures) { + try { + future.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + throw new RuntimeException(e); + } + } + return futures; } }