diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/BackendPlaySessionHandler.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/BackendPlaySessionHandler.java index dad35321a..e6d75c1e6 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/BackendPlaySessionHandler.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/BackendPlaySessionHandler.java @@ -86,7 +86,7 @@ public class BackendPlaySessionHandler implements MinecraftSessionHandler { @Override public boolean handle(PluginMessage packet) { - if (!canForwardPluginMessage(packet)) { + if (!serverConn.getPlayer().canForwardPluginMessage(packet)) { return true; } @@ -177,22 +177,4 @@ public class BackendPlaySessionHandler implements MinecraftSessionHandler { serverConn.getPlayer().disconnect(ConnectionMessages.UNEXPECTED_DISCONNECT); } } - - private boolean canForwardPluginMessage(PluginMessage message) { - MinecraftConnection mc = serverConn.getConnection(); - if (mc == null) { - return false; - } - boolean minecraftOrFmlMessage; - if (mc.getProtocolVersion().compareTo(ProtocolVersion.MINECRAFT_1_12_2) <= 0) { - String channel = message.getChannel(); - minecraftOrFmlMessage = channel.startsWith("MC|") || channel - .startsWith(LegacyForgeConstants.FORGE_LEGACY_HANDSHAKE_CHANNEL); - } else { - minecraftOrFmlMessage = message.getChannel().startsWith("minecraft:"); - } - return minecraftOrFmlMessage - || playerSessionHandler.getKnownChannels().contains(message.getChannel()) - || server.getChannelRegistrar().registered(message.getChannel()); - } } diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/TransitionSessionHandler.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/TransitionSessionHandler.java index 5ba01172e..302c70449 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/TransitionSessionHandler.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/TransitionSessionHandler.java @@ -130,7 +130,7 @@ public class TransitionSessionHandler implements MinecraftSessionHandler { @Override public boolean handle(PluginMessage packet) { - if (!canForwardPluginMessage(packet)) { + if (!serverConn.getPlayer().canForwardPluginMessage(packet)) { return true; } @@ -160,35 +160,4 @@ public class TransitionSessionHandler implements MinecraftSessionHandler { resultFuture .completeExceptionally(new IOException("Unexpectedly disconnected from remote server")); } - - private Collection getClientKnownPluginChannels() { - MinecraftSessionHandler handler = serverConn.getPlayer().getMinecraftConnection() - .getSessionHandler(); - - if (handler instanceof InitialConnectSessionHandler) { - return ((InitialConnectSessionHandler) handler).getKnownChannels(); - } else if (handler instanceof ClientPlaySessionHandler) { - return ((ClientPlaySessionHandler) handler).getKnownChannels(); - } else { - return ImmutableList.of(); - } - } - - private boolean canForwardPluginMessage(PluginMessage message) { - MinecraftConnection mc = serverConn.getConnection(); - if (mc == null) { - return false; - } - boolean minecraftOrFmlMessage; - if (mc.getProtocolVersion().compareTo(ProtocolVersion.MINECRAFT_1_12_2) <= 0) { - String channel = message.getChannel(); - minecraftOrFmlMessage = channel.startsWith("MC|") || channel - .startsWith(LegacyForgeConstants.FORGE_LEGACY_HANDSHAKE_CHANNEL); - } else { - minecraftOrFmlMessage = message.getChannel().startsWith("minecraft:"); - } - return minecraftOrFmlMessage - || server.getChannelRegistrar().registered(message.getChannel()) - || getClientKnownPluginChannels().contains(message.getChannel()); - } } diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ClientPlaySessionHandler.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ClientPlaySessionHandler.java index 45b8674ee..0fbc0cbc6 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ClientPlaySessionHandler.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ClientPlaySessionHandler.java @@ -50,12 +50,10 @@ import org.checkerframework.checker.nullness.qual.Nullable; public class ClientPlaySessionHandler implements MinecraftSessionHandler { private static final Logger logger = LogManager.getLogger(ClientPlaySessionHandler.class); - static final int MAX_PLUGIN_CHANNELS = 1024; private final ConnectedPlayer player; private boolean spawned = false; private final List serverBossBars = new ArrayList<>(); - private final Set knownChannels = new HashSet<>(); private final Queue loginPluginMessages = new ArrayDeque<>(); private final VelocityServer server; private @Nullable TabCompleteRequest legacyCommandTabComplete; @@ -68,19 +66,16 @@ public class ClientPlaySessionHandler implements MinecraftSessionHandler { public ClientPlaySessionHandler(VelocityServer server, ConnectedPlayer player) { this.player = player; this.server = server; - - if (player.getMinecraftConnection().getSessionHandler() - instanceof InitialConnectSessionHandler) { - this.knownChannels.addAll(((InitialConnectSessionHandler) player.getMinecraftConnection() - .getSessionHandler()).getKnownChannels()); - } } @Override public void activated() { + Collection channels = server.getChannelRegistrar().getChannelsForProtocol(player + .getProtocolVersion()); PluginMessage register = PluginMessageUtil.constructChannelsPacket(player.getProtocolVersion(), - server.getChannelRegistrar().getChannelsForProtocol(player.getProtocolVersion())); + channels); player.getMinecraftConnection().write(register); + player.getKnownChannels().addAll(channels); } @Override @@ -215,25 +210,10 @@ public class ClientPlaySessionHandler implements MinecraftSessionHandler { logger.warn("A plugin message was received while the backend server was not " + "ready. Channel: {}. Packet discarded.", packet.getChannel()); } else if (PluginMessageUtil.isRegister(packet)) { - List actuallyRegistered = new ArrayList<>(); - List channels = PluginMessageUtil.getChannels(packet); - for (String channel : channels) { - if (knownChannels.size() >= MAX_PLUGIN_CHANNELS && !knownChannels.contains(channel)) { - throw new IllegalStateException("Too many plugin message channels registered"); - } - if (knownChannels.add(channel)) { - actuallyRegistered.add(channel); - } - } - - if (!actuallyRegistered.isEmpty()) { - PluginMessage newRegisterPacket = PluginMessageUtil.constructChannelsPacket(backendConn - .getProtocolVersion(), actuallyRegistered); - backendConn.write(newRegisterPacket); - } + player.getKnownChannels().addAll(PluginMessageUtil.getChannels(packet)); + backendConn.write(packet); } else if (PluginMessageUtil.isUnregister(packet)) { - List channels = PluginMessageUtil.getChannels(packet); - knownChannels.removeAll(channels); + player.getKnownChannels().removeAll(PluginMessageUtil.getChannels(packet)); backendConn.write(packet); } else if (PluginMessageUtil.isMcBrand(packet)) { backendConn.write(PluginMessageUtil.rewriteMinecraftBrand(packet, server.getVersion())); @@ -385,14 +365,9 @@ public class ClientPlaySessionHandler implements MinecraftSessionHandler { // Tell the server about this client's plugin message channels. ProtocolVersion serverVersion = serverMc.getProtocolVersion(); - Collection toRegister = new HashSet<>(knownChannels); - if (serverVersion.compareTo(MINECRAFT_1_13) >= 0) { - toRegister.addAll(server.getChannelRegistrar().getModernChannelIds()); - } else { - toRegister.addAll(server.getChannelRegistrar().getIdsForLegacyConnections()); - } - if (!toRegister.isEmpty()) { - serverMc.delayedWrite(PluginMessageUtil.constructChannelsPacket(serverVersion, toRegister)); + if (!player.getKnownChannels().isEmpty()) { + serverMc.delayedWrite(PluginMessageUtil.constructChannelsPacket(serverVersion, + player.getKnownChannels())); } // If we had plugin messages queued during login/FML handshake, send them now. @@ -415,10 +390,6 @@ public class ClientPlaySessionHandler implements MinecraftSessionHandler { return serverBossBars; } - public Set getKnownChannels() { - return knownChannels; - } - /** * Handles additional tab complete for 1.12 and lower clients. * diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ConnectedPlayer.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ConnectedPlayer.java index 68bfd0798..837217778 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ConnectedPlayer.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ConnectedPlayer.java @@ -31,6 +31,7 @@ import com.velocitypowered.proxy.VelocityServer; import com.velocitypowered.proxy.connection.MinecraftConnection; import com.velocitypowered.proxy.connection.MinecraftConnectionAssociation; import com.velocitypowered.proxy.connection.backend.VelocityServerConnection; +import com.velocitypowered.proxy.connection.forge.legacy.LegacyForgeConstants; import com.velocitypowered.proxy.connection.util.ConnectionMessages; import com.velocitypowered.proxy.connection.util.ConnectionRequestResults; import com.velocitypowered.proxy.connection.util.ConnectionRequestResults.Impl; @@ -45,12 +46,10 @@ import com.velocitypowered.proxy.protocol.packet.TitlePacket; import com.velocitypowered.proxy.server.VelocityRegisteredServer; import com.velocitypowered.proxy.tablist.VelocityTabList; import com.velocitypowered.proxy.util.VelocityMessages; +import com.velocitypowered.proxy.util.collect.CappedCollection; import io.netty.buffer.ByteBufUtil; import java.net.InetSocketAddress; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.UUID; +import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ThreadLocalRandom; @@ -68,12 +67,16 @@ import org.checkerframework.checker.nullness.qual.Nullable; public class ConnectedPlayer implements MinecraftConnectionAssociation, Player { + private static final int MAX_PLUGIN_CHANNELS = 1024; private static final PlainComponentSerializer PASS_THRU_TRANSLATE = new PlainComponentSerializer( c -> "", TranslatableComponent::key); static final PermissionProvider DEFAULT_PERMISSIONS = s -> PermissionFunction.ALWAYS_UNDEFINED; private static final Logger logger = LogManager.getLogger(ConnectedPlayer.class); + /** + * The actual Minecraft connection. This is actually a wrapper object around the Netty channel. + */ private final MinecraftConnection minecraftConnection; private final @Nullable InetSocketAddress virtualHost; private GameProfile profile; @@ -87,6 +90,7 @@ public class ConnectedPlayer implements MinecraftConnectionAssociation, Player { private final VelocityTabList tabList; private final VelocityServer server; private ClientConnectionPhase connectionPhase; + private final Collection knownChannels; private @MonotonicNonNull List serversToTry = null; @@ -99,6 +103,7 @@ public class ConnectedPlayer implements MinecraftConnectionAssociation, Player { this.virtualHost = virtualHost; this.permissionFunction = PermissionFunction.ALWAYS_UNDEFINED; this.connectionPhase = minecraftConnection.getType().getInitialClientPhase(); + this.knownChannels = CappedCollection.newCappedSet(MAX_PLUGIN_CHANNELS); } @Override @@ -631,6 +636,36 @@ public class ConnectedPlayer implements MinecraftConnectionAssociation, Player { this.connectionPhase = connectionPhase; } + /** + * Return all the plugin message channels "known" to the client. + * @return the channels + */ + public Collection getKnownChannels() { + return knownChannels; + } + + /** + * Determines whether or not we can forward a plugin message onto the client. + * @param message the plugin message to forward to the client + * @return {@code true} if the message can be forwarded, {@code false} otherwise + */ + public boolean canForwardPluginMessage(PluginMessage message) { + // If we're forwarding a plugin message onto the client, that implies that we have a backend connection + // already. + MinecraftConnection mc = ensureBackendConnection(); + + boolean minecraftOrFmlMessage; + if (mc.getProtocolVersion().compareTo(ProtocolVersion.MINECRAFT_1_12_2) <= 0) { + String channel = message.getChannel(); + minecraftOrFmlMessage = channel.startsWith("MC|") || channel + .startsWith(LegacyForgeConstants.FORGE_LEGACY_HANDSHAKE_CHANNEL); + } else { + minecraftOrFmlMessage = message.getChannel().startsWith("minecraft:"); + } + return minecraftOrFmlMessage || knownChannels.contains(message.getChannel()) + || server.getChannelRegistrar().registered(message.getChannel()); + } + private class ConnectionRequestBuilderImpl implements ConnectionRequestBuilder { private final RegisteredServer toConnect; diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/InitialConnectSessionHandler.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/InitialConnectSessionHandler.java index c681f5d2b..59c2e9255 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/InitialConnectSessionHandler.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/InitialConnectSessionHandler.java @@ -1,24 +1,16 @@ package com.velocitypowered.proxy.connection.client; -import static com.velocitypowered.proxy.connection.client.ClientPlaySessionHandler.MAX_PLUGIN_CHANNELS; - import com.velocitypowered.proxy.connection.MinecraftSessionHandler; import com.velocitypowered.proxy.connection.backend.VelocityServerConnection; import com.velocitypowered.proxy.protocol.packet.PluginMessage; import com.velocitypowered.proxy.protocol.util.PluginMessageUtil; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; public class InitialConnectSessionHandler implements MinecraftSessionHandler { private final ConnectedPlayer player; - private final Set knownChannels; InitialConnectSessionHandler(ConnectedPlayer player) { this.player = player; - this.knownChannels = new HashSet<>(); } @Override @@ -30,29 +22,11 @@ public class InitialConnectSessionHandler implements MinecraftSessionHandler { } if (PluginMessageUtil.isRegister(packet)) { - List actuallyRegistered = new ArrayList<>(); - List channels = PluginMessageUtil.getChannels(packet); - for (String channel : channels) { - if (knownChannels.size() >= MAX_PLUGIN_CHANNELS && !knownChannels.contains(channel)) { - throw new IllegalStateException("Too many plugin message channels registered"); - } - if (knownChannels.add(channel)) { - actuallyRegistered.add(channel); - } - } - - if (!actuallyRegistered.isEmpty()) { - PluginMessage newRegisterPacket = PluginMessageUtil.constructChannelsPacket(serverConn - .ensureConnected().getProtocolVersion(), actuallyRegistered); - serverConn.ensureConnected().write(newRegisterPacket); - } + player.getKnownChannels().addAll(PluginMessageUtil.getChannels(packet)); } else if (PluginMessageUtil.isUnregister(packet)) { - List channels = PluginMessageUtil.getChannels(packet); - knownChannels.removeAll(channels); - serverConn.ensureConnected().write(packet); - } else { - serverConn.ensureConnected().write(packet); + player.getKnownChannels().removeAll(PluginMessageUtil.getChannels(packet)); } + serverConn.ensureConnected().write(packet); } return true; } @@ -62,8 +36,4 @@ public class InitialConnectSessionHandler implements MinecraftSessionHandler { // the user cancelled the login process player.teardown(); } - - public Set getKnownChannels() { - return knownChannels; - } } diff --git a/proxy/src/main/java/com/velocitypowered/proxy/util/collect/CappedCollection.java b/proxy/src/main/java/com/velocitypowered/proxy/util/collect/CappedCollection.java new file mode 100644 index 000000000..64f5e92a5 --- /dev/null +++ b/proxy/src/main/java/com/velocitypowered/proxy/util/collect/CappedCollection.java @@ -0,0 +1,51 @@ +package com.velocitypowered.proxy.util.collect; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ForwardingCollection; + +import java.util.Collection; +import java.util.HashSet; + +/** + * An unsynchronized collection that puts an upper bound on the size of the collection. + */ +public class CappedCollection extends ForwardingCollection { + + private final Collection delegate; + private final int upperSize; + + private CappedCollection(Collection delegate, int upperSize) { + this.delegate = delegate; + this.upperSize = upperSize; + } + + /** + * Creates a capped collection backed by a {@link HashSet}. + * @param maxSize the maximum size of the collection + * @param the type of elements in the collection + * @return the new collection + */ + public static Collection newCappedSet(int maxSize) { + return new CappedCollection<>(new HashSet<>(), maxSize); + } + + @Override + protected Collection delegate() { + return delegate; + } + + @Override + public boolean add(T element) { + Preconditions.checkState(this.delegate.size() + 1 <= upperSize, "collection is too large (%s)", + this.delegate.size()); + return this.delegate.add(element); + } + + @Override + public boolean addAll(Collection collection) { + Preconditions.checkState(this.delegate.size() + collection.size() <= upperSize, + "collection would grow too large (%s + %s > %s)", + this.delegate.size(), collection.size(), upperSize); + return this.standardAddAll(collection); + } +} diff --git a/proxy/src/test/java/com/velocitypowered/proxy/util/collect/CappedCollectionTest.java b/proxy/src/test/java/com/velocitypowered/proxy/util/collect/CappedCollectionTest.java new file mode 100644 index 000000000..634cadf2c --- /dev/null +++ b/proxy/src/test/java/com/velocitypowered/proxy/util/collect/CappedCollectionTest.java @@ -0,0 +1,35 @@ +package com.velocitypowered.proxy.util.collect; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import java.util.Collection; +import java.util.Set; +import org.junit.jupiter.api.Test; + +class CappedCollectionTest { + + @Test + void basicVerification() { + Collection coll = CappedCollection.newCappedSet(1); + assertTrue(coll.add("coffee"), "did not add single item"); + assertThrows(IllegalStateException.class, () -> coll.add("tea"), + "item was added to collection although it is too full"); + assertEquals(1, coll.size(), "collection grew in size unexpectedly"); + } + + @Test + void testAddAll() { + Set doesFill1 = ImmutableSet.of("coffee", "tea"); + Set doesFill2 = ImmutableSet.of("chocolate"); + Set overfill = ImmutableSet.of("Coke", "Pepsi"); + + Collection coll = CappedCollection.newCappedSet(3); + assertTrue(coll.addAll(doesFill1), "did not add items"); + assertTrue(coll.addAll(doesFill2), "did not add items"); + assertThrows(IllegalStateException.class, () -> coll.addAll(overfill), + "items added to collection although it is too full"); + assertEquals(3, coll.size(), "collection grew in size unexpectedly"); + } +} \ No newline at end of file