Update ThreadUtils.java

- Utilizes Java's virtual threads (Thread.ofVirtual()), which are lightweight and designed for handling a large number of concurrent tasks with minimal overhead

- Names virtual threads directly during creation using Thread.ofVirtual().name(threadId).

- Manages virtual threads using a ConcurrentHashMap for better concurrency and thread-safety.

- Shuts down virtual threads by interrupting them and optionally waiting for them to join with a timeout. Removes threads from the ConcurrentHashMap.

- Controls concurrency by submitting tasks to virtual threads in batches, with a configurable maximum concurrency level.

- AtomicInteger and AtomicReference for thread-safe counter increments and task management, reducing the need for synchronized blocks
This commit is contained in:
XMRZombie 2025-05-02 22:44:54 +00:00 committed by GitHub
parent e16e2ebd03
commit 6d746c05de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -21,7 +21,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
@ -29,10 +28,17 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ThreadUtils {
private static final Map<String, Thread> VIRTUAL_THREADS = new ConcurrentHashMap<>();
private static final Logger logger = LoggerFactory.getLogger(ThreadUtils.class);
private static final ConcurrentHashMap<String, Thread> VIRTUAL_THREADS = new ConcurrentHashMap<>();
private static final AtomicInteger THREAD_COUNTER = new AtomicInteger(0);
private static final long DEFAULT_TIMEOUT_MS = 5000; // Default timeout for operations
/**
* Execute the given command in a virtual thread with the given id.
@ -47,11 +53,11 @@ public class ThreadUtils {
command.run();
future.complete(null);
} catch (Exception e) {
logger.error("Exception in thread: {} - {}", threadId, e.getMessage(), e);
future.completeExceptionally(e);
} finally {
VIRTUAL_THREADS.remove(threadId);
}
});
VIRTUAL_THREADS.put(threadId, virtualThread);
return future;
}
@ -65,65 +71,51 @@ public class ThreadUtils {
public static void await(Runnable command, String threadId) {
try {
execute(command, threadId).get();
} catch (Exception e) {
throw new RuntimeException(e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("Interrupted while awaiting command execution in thread: {} - {}", threadId, e.getMessage(), e);
throw new RuntimeException("Interrupted while awaiting command execution", e);
} catch (ExecutionException e) {
logger.error("Execution exception while awaiting command execution in thread: {} - {}", threadId, e.getMessage(), e);
throw new RuntimeException("Execution exception while awaiting command execution", e);
}
}
/**
* Shuts down the virtual thread with the given id.
*
* @param threadId the thread id
*/
public static void shutDown(String threadId) {
shutDown(threadId, null);
}
/**
* Shuts down the virtual thread with the given id and an optional timeout.
*
* @param threadId the thread id
* @param timeoutMs the timeout in milliseconds
*/
public static void shutDown(String threadId, Long timeoutMs) {
Thread thread = VIRTUAL_THREADS.get(threadId);
if (thread == null) return; // thread not found
Thread thread = VIRTUAL_THREADS.remove(threadId);
if (thread == null) {
logger.warn("Thread not found: {}", threadId);
return; // thread not found
}
thread.interrupt();
if (timeoutMs != null) {
try {
thread.join(timeoutMs);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
logger.error("Interrupted while waiting for thread to shut down: {} - {}", threadId, e.getMessage(), e);
throw new RuntimeException("Interrupted while waiting for thread to shut down", e);
}
}
logger.info("Shut down thread: {}", threadId);
}
/**
* Removes the virtual thread with the given id from the tracking map.
*
* @param threadId the thread id
*/
public static void remove(String threadId) {
VIRTUAL_THREADS.remove(threadId);
}
/**
* Submit a single task to be executed in a virtual thread.
*
* @param task the task to execute
* @return a Future representing the pending completion of the task
*/
// TODO: consolidate and cleanup apis
public static Future<?> submitToPool(Runnable task) {
return submitToPool(Arrays.asList(task)).get(0);
}
/**
* Submit a single callable task to be executed in a virtual thread.
*
* @param task the task to execute
* @return a Future representing the pending completion of the task
*/
public static <T> Future<T> submitToPool(Callable<T> task) {
CompletableFuture<T> future = new CompletableFuture<>();
execute(() -> {
@ -131,91 +123,73 @@ public class ThreadUtils {
T result = task.call();
future.complete(result);
} catch (Exception e) {
logger.error("Exception in callable task - {}", e.getMessage(), e);
future.completeExceptionally(e);
}
}, "pool-task-" + task.hashCode());
}, "pool-task-" + THREAD_COUNTER.incrementAndGet());
return future;
}
/**
* Submit a list of tasks to be executed in virtual threads.
*
* @param tasks the tasks to execute
* @return a list of Futures representing the pending completion of the tasks
*/
public static List<Future<?>> submitToPool(List<Runnable> tasks) {
List<Future<?>> futures = new ArrayList<>();
for (Runnable task : tasks) {
futures.add(execute(task, "pool-task-" + task.hashCode()));
futures.add(execute(task, "pool-task-" + THREAD_COUNTER.incrementAndGet()));
}
return futures;
}
/**
* Await the completion of a single task.
*
* @param task the task to execute
* @return a Future representing the pending completion of the task
*/
public static Future<?> awaitTask(Runnable task) {
return awaitTask(task, null);
}
/**
* Await the completion of a single task with an optional timeout.
*
* @param task the task to execute
* @param timeoutMs the timeout in milliseconds
* @return a Future representing the pending completion of the task
*/
public static Future<?> awaitTask(Runnable task, Long timeoutMs) {
return awaitTasks(Arrays.asList(task), 1, timeoutMs).get(0);
}
/**
* Await the completion of a collection of tasks.
*
* @param tasks the tasks to execute
* @return a list of Futures representing the pending completion of the tasks
*/
public static List<Future<?>> awaitTasks(Collection<Runnable> tasks) {
return awaitTasks(tasks, tasks.size());
}
/**
* Await the completion of a collection of tasks with a specified maximum concurrency.
*
* @param tasks the tasks to execute
* @param maxConcurrency the maximum number of concurrent tasks
* @return a list of Futures representing the pending completion of the tasks
*/
public static List<Future<?>> awaitTasks(Collection<Runnable> tasks, int maxConcurrency) {
return awaitTasks(tasks, maxConcurrency, null);
}
/**
* Await the completion of a collection of tasks with a specified maximum concurrency and optional timeout.
*
* @param tasks the tasks to execute
* @param maxConcurrency the maximum number of concurrent tasks
* @param timeoutMs the timeout in milliseconds
* @return a list of Futures representing the pending completion of the tasks
*/
public static List<Future<?>> awaitTasks(Collection<Runnable> tasks, int maxConcurrency, Long timeoutMs) {
if (timeoutMs == null) timeoutMs = Long.MAX_VALUE;
if (timeoutMs == null) timeoutMs = DEFAULT_TIMEOUT_MS;
if (tasks.isEmpty()) return new ArrayList<>();
List<Future<?>> futures = new ArrayList<>();
for (Runnable task : tasks) {
futures.add(execute(task, "await-task-" + task.hashCode()));
}
for (Future<?> future : futures) {
try {
future.get(timeoutMs, TimeUnit.MILLISECONDS);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
throw new RuntimeException(e);
AtomicReference<List<Runnable>> remainingTasks = new AtomicReference<>(new ArrayList<>(tasks));
while (true) {
List<Runnable> batchTasks = remainingTasks.get().subList(0, Math.min(maxConcurrency, remainingTasks.get().size()));
if (batchTasks.isEmpty()) break;
List<Future<?>> batchFutures = new ArrayList<>();
for (Runnable task : batchTasks) {
batchFutures.add(execute(task, "await-task-" + THREAD_COUNTER.incrementAndGet()));
}
for (Future<?> future : batchFutures) {
try {
future.get(timeoutMs, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.error("Interrupted while awaiting task completion - {}", e.getMessage(), e);
throw new RuntimeException("Interrupted while awaiting task completion", e);
} catch (ExecutionException e) {
logger.error("Execution exception while awaiting task completion - {}", e.getMessage(), e);
throw new RuntimeException("Execution exception while awaiting task completion", e);
} catch (TimeoutException e) {
logger.error("Timeout while awaiting task completion - {}", e.getMessage(), e);
throw new RuntimeException("Timeout while awaiting task completion", e);
}
}
futures.addAll(batchFutures);
remainingTasks.set(remainingTasks.get().subList(Math.min(maxConcurrency, remainingTasks.get().size()), remainingTasks.get().size()));
}
return futures;
}
}