From 9b4d50b2141cb567201e934b15206a7e6e04a10f Mon Sep 17 00:00:00 2001 From: KennyTV Date: Sun, 21 Mar 2021 20:25:52 +0100 Subject: [PATCH] Tidy up concurrent mapping loading --- .../api/protocol/ProtocolRegistry.java | 109 ++++++++++++------ 1 file changed, 75 insertions(+), 34 deletions(-) diff --git a/common/src/main/java/us/myles/ViaVersion/api/protocol/ProtocolRegistry.java b/common/src/main/java/us/myles/ViaVersion/api/protocol/ProtocolRegistry.java index f1b0633ed..716422b0e 100644 --- a/common/src/main/java/us/myles/ViaVersion/api/protocol/ProtocolRegistry.java +++ b/common/src/main/java/us/myles/ViaVersion/api/protocol/ProtocolRegistry.java @@ -1,5 +1,6 @@ package us.myles.ViaVersion.api.protocol; +import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import com.google.common.collect.Range; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -57,6 +58,9 @@ import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Function; public class ProtocolRegistry { public static final Protocol BASE_PROTOCOL = new BaseProtocol(); @@ -70,7 +74,7 @@ public class ProtocolRegistry { private static final List, Protocol>> baseProtocols = Lists.newCopyOnWriteArrayList(); private static final List registerList = new ArrayList<>(); - private static final Object MAPPING_LOADER_LOCK = new Object(); + private static final ReadWriteLock MAPPING_LOADER_LOCK = new ReentrantReadWriteLock(); private static Map, CompletableFuture> mappingLoaderFutures = new HashMap<>(); private static ThreadPoolExecutor mappingLoaderExecutor; private static boolean mappingsLoaded; @@ -353,9 +357,10 @@ public class ProtocolRegistry { if (mappingsLoaded) return; CompletableFuture future = getMappingLoaderFuture(protocolClass); - if (future == null) return; - - future.get(); + if (future != null) { + // Wait for completion + future.get(); + } } /** @@ -364,7 +369,8 @@ public class ProtocolRegistry { * @return true if the executor has now been shut down */ public static boolean checkForMappingCompletion() { - synchronized (MAPPING_LOADER_LOCK) { + MAPPING_LOADER_LOCK.readLock().lock(); + try { if (mappingsLoaded) return false; for (CompletableFuture future : mappingLoaderFutures.values()) { @@ -376,10 +382,68 @@ public class ProtocolRegistry { shutdownLoaderExecutor(); return true; + } finally { + MAPPING_LOADER_LOCK.readLock().unlock(); + } + } + + /** + * Executes the given runnable asynchronously, adding a {@link CompletableFuture} + * to the list of data to load bound to their protocols. + * + * @param protocolClass protocol class + * @param runnable runnable to be executed asynchronously + */ + public static void addMappingLoaderFuture(Class protocolClass, Runnable runnable) { + CompletableFuture future = CompletableFuture.runAsync(runnable, mappingLoaderExecutor).exceptionally(mappingLoaderThrowable(protocolClass)); + + MAPPING_LOADER_LOCK.writeLock().lock(); + try { + mappingLoaderFutures.put(protocolClass, future); + } finally { + MAPPING_LOADER_LOCK.writeLock().unlock(); + } + } + + /** + * Executes the given runnable asynchronously after the other protocol has finished its data loading, + * adding a {@link CompletableFuture} to the list of data to load bound to their protocols. + * + * @param protocolClass protocol class + * @param runnable runnable to be executed asynchronously + */ + public static void addMappingLoaderFuture(Class protocolClass, Class dependsOn, Runnable runnable) { + CompletableFuture future = getMappingLoaderFuture(dependsOn) + .whenCompleteAsync((v, throwable) -> runnable.run(), mappingLoaderExecutor).exceptionally(mappingLoaderThrowable(protocolClass)); + + MAPPING_LOADER_LOCK.writeLock().lock(); + try { + mappingLoaderFutures.put(protocolClass, future); + } finally { + MAPPING_LOADER_LOCK.writeLock().unlock(); + } + } + + /** + * Returns the data loading future bound to the protocol, or null if all loading is complete. + * The future may or may not have already been completed. + * + * @param protocolClass protocol class + * @return data loading future bound to the protocol, or null if all loading is complete + */ + @Nullable + public static CompletableFuture getMappingLoaderFuture(Class protocolClass) { + MAPPING_LOADER_LOCK.readLock().lock(); + try { + return mappingsLoaded ? null : mappingLoaderFutures.get(protocolClass); + } finally { + MAPPING_LOADER_LOCK.readLock().unlock(); } } private static void shutdownLoaderExecutor() { + Preconditions.checkArgument(!mappingsLoaded); + Via.getPlatform().getLogger().info("Finished mapping loading, shutting down loader executor!"); mappingsLoaded = true; mappingLoaderExecutor.shutdown(); @@ -391,34 +455,11 @@ public class ProtocolRegistry { } } - public static void addMappingLoaderFuture(Class protocolClass, Runnable runnable) { - synchronized (MAPPING_LOADER_LOCK) { - CompletableFuture future = CompletableFuture.runAsync(runnable, mappingLoaderExecutor).exceptionally(throwable -> { - Via.getPlatform().getLogger().severe("Error during mapping loading of " + protocolClass.getSimpleName()); - throwable.printStackTrace(); - return null; - }); - mappingLoaderFutures.put(protocolClass, future); - } - } - - public static void addMappingLoaderFuture(Class protocolClass, Class dependsOn, Runnable runnable) { - synchronized (MAPPING_LOADER_LOCK) { - CompletableFuture future = getMappingLoaderFuture(dependsOn) - .whenCompleteAsync((v, throwable) -> runnable.run(), mappingLoaderExecutor).exceptionally(throwable -> { - Via.getPlatform().getLogger().severe("Error during mapping loading of " + protocolClass.getSimpleName()); - throwable.printStackTrace(); - return null; - }); - mappingLoaderFutures.put(protocolClass, future); - } - } - - @Nullable - public static CompletableFuture getMappingLoaderFuture(Class protocolClass) { - synchronized (MAPPING_LOADER_LOCK) { - if (mappingsLoaded) return null; - return mappingLoaderFutures.get(protocolClass); - } + private static Function mappingLoaderThrowable(Class protocolClass) { + return throwable -> { + Via.getPlatform().getLogger().severe("Error during mapping loading of " + protocolClass.getSimpleName()); + throwable.printStackTrace(); + return null; + }; } }