Update ThreadUtils.java

Updating threading 'engine' for virtual threading
This commit is contained in:
XMRZombie 2025-04-29 00:54:34 +00:00 committed by GitHub
parent 81eaeb6df0
commit 03d32f4a7c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -16,41 +16,44 @@
*/
package haveno.common;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
public class ThreadUtils {
private static final Map<String, ExecutorService> EXECUTORS = new HashMap<>();
private static final Map<String, Thread> THREADS = new HashMap<>();
private static final int POOL_SIZE = 10;
private static final ExecutorService POOL = Executors.newFixedThreadPool(POOL_SIZE);
private static final Map<String, Thread> VIRTUAL_THREADS = new ConcurrentHashMap<>();
/**
* Execute the given command in a thread with the given id.
*
* Execute the given command in a virtual thread with the given id.
*
* @param command the command to execute
* @param threadId the thread id
*/
public static Future<?> execute(Runnable command, String threadId) {
synchronized (EXECUTORS) {
if (!EXECUTORS.containsKey(threadId)) EXECUTORS.put(threadId, Executors.newFixedThreadPool(1));
return EXECUTORS.get(threadId).submit(() -> {
synchronized (THREADS) {
THREADS.put(threadId, Thread.currentThread());
}
Thread.currentThread().setName(threadId);
CompletableFuture<Void> future = new CompletableFuture<>();
Thread virtualThread = Thread.ofVirtual().name(threadId).start(() -> {
try {
command.run();
});
}
future.complete(null);
} catch (Exception e) {
future.completeExceptionally(e);
} finally {
VIRTUAL_THREADS.remove(threadId);
}
});
VIRTUAL_THREADS.put(threadId, virtualThread);
return future;
}
/**
@ -67,85 +70,152 @@ public class ThreadUtils {
}
}
/**
* Shuts down the virtual thread with the given id.
*
* @param threadId the thread id
*/
public static void shutDown(String threadId) {
shutDown(threadId, null);
}
/**
* Shuts down the virtual thread with the given id and an optional timeout.
*
* @param threadId the thread id
* @param timeoutMs the timeout in milliseconds
*/
public static void shutDown(String threadId, Long timeoutMs) {
if (timeoutMs == null) timeoutMs = Long.MAX_VALUE;
ExecutorService pool = null;
synchronized (EXECUTORS) {
pool = EXECUTORS.get(threadId);
}
if (pool == null) return; // thread not found
pool.shutdown();
try {
if (!pool.awaitTermination(timeoutMs, TimeUnit.MILLISECONDS)) pool.shutdownNow();
} catch (InterruptedException e) {
pool.shutdownNow();
throw new RuntimeException(e);
} finally {
remove(threadId);
Thread thread = VIRTUAL_THREADS.get(threadId);
if (thread == null) return; // thread not found
thread.interrupt();
if (timeoutMs != null) {
try {
thread.join(timeoutMs);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
}
}
/**
* Removes the virtual thread with the given id from the tracking map.
*
* @param threadId the thread id
*/
public static void remove(String threadId) {
synchronized (EXECUTORS) {
EXECUTORS.remove(threadId);
}
synchronized (THREADS) {
THREADS.remove(threadId);
}
VIRTUAL_THREADS.remove(threadId);
}
// TODO: consolidate and cleanup apis
/**
* Submit a single task to be executed in a virtual thread.
*
* @param task the task to execute
* @return a Future representing the pending completion of the task
*/
public static Future<?> submitToPool(Runnable task) {
return submitToPool(Arrays.asList(task)).get(0);
}
/**
* Submit a single callable task to be executed in a virtual thread.
*
* @param task the task to execute
* @return a Future representing the pending completion of the task
*/
public static <T> Future<T> submitToPool(Callable<T> task) {
CompletableFuture<T> future = new CompletableFuture<>();
execute(() -> {
try {
T result = task.call();
future.complete(result);
} catch (Exception e) {
future.completeExceptionally(e);
}
}, "pool-task-" + task.hashCode());
return future;
}
/**
* Submit a list of tasks to be executed in virtual threads.
*
* @param tasks the tasks to execute
* @return a list of Futures representing the pending completion of the tasks
*/
public static List<Future<?>> submitToPool(List<Runnable> tasks) {
List<Future<?>> futures = new ArrayList<>();
for (Runnable task : tasks) futures.add(POOL.submit(task));
for (Runnable task : tasks) {
futures.add(execute(task, "pool-task-" + task.hashCode()));
}
return futures;
}
/**
* Await the completion of a single task.
*
* @param task the task to execute
* @return a Future representing the pending completion of the task
*/
public static Future<?> awaitTask(Runnable task) {
return awaitTask(task, null);
}
/**
* Await the completion of a single task with an optional timeout.
*
* @param task the task to execute
* @param timeoutMs the timeout in milliseconds
* @return a Future representing the pending completion of the task
*/
public static Future<?> awaitTask(Runnable task, Long timeoutMs) {
return awaitTasks(Arrays.asList(task), 1, timeoutMs).get(0);
}
/**
* Await the completion of a collection of tasks.
*
* @param tasks the tasks to execute
* @return a list of Futures representing the pending completion of the tasks
*/
public static List<Future<?>> awaitTasks(Collection<Runnable> tasks) {
return awaitTasks(tasks, tasks.size());
}
/**
* Await the completion of a collection of tasks with a specified maximum concurrency.
*
* @param tasks the tasks to execute
* @param maxConcurrency the maximum number of concurrent tasks
* @return a list of Futures representing the pending completion of the tasks
*/
public static List<Future<?>> awaitTasks(Collection<Runnable> tasks, int maxConcurrency) {
return awaitTasks(tasks, maxConcurrency, null);
}
/**
* Await the completion of a collection of tasks with a specified maximum concurrency and optional timeout.
*
* @param tasks the tasks to execute
* @param maxConcurrency the maximum number of concurrent tasks
* @param timeoutMs the timeout in milliseconds
* @return a list of Futures representing the pending completion of the tasks
*/
public static List<Future<?>> awaitTasks(Collection<Runnable> tasks, int maxConcurrency, Long timeoutMs) {
if (timeoutMs == null) timeoutMs = Long.MAX_VALUE;
if (tasks.isEmpty()) return new ArrayList<>();
ExecutorService executorService = Executors.newFixedThreadPool(tasks.size());
try {
List<Future<?>> futures = new ArrayList<>();
for (Runnable task : tasks) futures.add(executorService.submit(task, null));
for (Future<?> future : futures) future.get(timeoutMs, TimeUnit.MILLISECONDS);
return futures;
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
executorService.shutdownNow();
}
}
private static boolean isCurrentThread(Thread thread, String threadId) {
synchronized (THREADS) {
if (!THREADS.containsKey(threadId)) return false;
return thread == THREADS.get(threadId);
List<Future<?>> futures = new ArrayList<>();
for (Runnable task : tasks) {
futures.add(execute(task, "await-task-" + task.hashCode()));
}
for (Future<?> future : futures) {
try {
future.get(timeoutMs, TimeUnit.MILLISECONDS);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
throw new RuntimeException(e);
}
}
return futures;
}
}