diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/BackendPlaySessionHandler.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/BackendPlaySessionHandler.java index dad35321a..a2afbae6a 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/BackendPlaySessionHandler.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/BackendPlaySessionHandler.java @@ -86,10 +86,19 @@ public class BackendPlaySessionHandler implements MinecraftSessionHandler { @Override public boolean handle(PluginMessage packet) { - if (!canForwardPluginMessage(packet)) { + if (!serverConn.getPlayer().canForwardPluginMessage(serverConn.ensureConnected() + .getProtocolVersion(), packet)) { return true; } + // We need to specially handle REGISTER and UNREGISTER packets. Later on, we'll write them to + // the client. + if (PluginMessageUtil.isRegister(packet)) { + serverConn.getPlayer().getKnownChannels().addAll(PluginMessageUtil.getChannels(packet)); + } else if (PluginMessageUtil.isUnregister(packet)) { + serverConn.getPlayer().getKnownChannels().removeAll(PluginMessageUtil.getChannels(packet)); + } + if (PluginMessageUtil.isMcBrand(packet)) { PluginMessage rewritten = PluginMessageUtil.rewriteMinecraftBrand(packet, server.getVersion()); @@ -177,22 +186,4 @@ public class BackendPlaySessionHandler implements MinecraftSessionHandler { serverConn.getPlayer().disconnect(ConnectionMessages.UNEXPECTED_DISCONNECT); } } - - private boolean canForwardPluginMessage(PluginMessage message) { - MinecraftConnection mc = serverConn.getConnection(); - if (mc == null) { - return false; - } - boolean minecraftOrFmlMessage; - if (mc.getProtocolVersion().compareTo(ProtocolVersion.MINECRAFT_1_12_2) <= 0) { - String channel = message.getChannel(); - minecraftOrFmlMessage = channel.startsWith("MC|") || channel - .startsWith(LegacyForgeConstants.FORGE_LEGACY_HANDSHAKE_CHANNEL); - } else { - minecraftOrFmlMessage = message.getChannel().startsWith("minecraft:"); - } - return minecraftOrFmlMessage - || playerSessionHandler.getKnownChannels().contains(message.getChannel()) - || server.getChannelRegistrar().registered(message.getChannel()); - } } diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/TransitionSessionHandler.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/TransitionSessionHandler.java index 5ba01172e..b38d96ff8 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/TransitionSessionHandler.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/backend/TransitionSessionHandler.java @@ -19,6 +19,7 @@ import com.velocitypowered.proxy.protocol.packet.Disconnect; import com.velocitypowered.proxy.protocol.packet.JoinGame; import com.velocitypowered.proxy.protocol.packet.KeepAlive; import com.velocitypowered.proxy.protocol.packet.PluginMessage; +import com.velocitypowered.proxy.protocol.util.PluginMessageUtil; import java.io.IOException; import java.util.Collection; import java.util.concurrent.CompletableFuture; @@ -130,10 +131,17 @@ public class TransitionSessionHandler implements MinecraftSessionHandler { @Override public boolean handle(PluginMessage packet) { - if (!canForwardPluginMessage(packet)) { + if (!serverConn.getPlayer().canForwardPluginMessage(serverConn.ensureConnected() + .getProtocolVersion(), packet)) { return true; } + if (PluginMessageUtil.isRegister(packet)) { + serverConn.getPlayer().getKnownChannels().addAll(PluginMessageUtil.getChannels(packet)); + } else if (PluginMessageUtil.isUnregister(packet)) { + serverConn.getPlayer().getKnownChannels().removeAll(PluginMessageUtil.getChannels(packet)); + } + // We always need to handle plugin messages, for Forge compatibility. if (serverConn.getPhase().handle(serverConn, serverConn.getPlayer(), packet)) { // Handled, but check the server connection phase. @@ -160,35 +168,4 @@ public class TransitionSessionHandler implements MinecraftSessionHandler { resultFuture .completeExceptionally(new IOException("Unexpectedly disconnected from remote server")); } - - private Collection getClientKnownPluginChannels() { - MinecraftSessionHandler handler = serverConn.getPlayer().getMinecraftConnection() - .getSessionHandler(); - - if (handler instanceof InitialConnectSessionHandler) { - return ((InitialConnectSessionHandler) handler).getKnownChannels(); - } else if (handler instanceof ClientPlaySessionHandler) { - return ((ClientPlaySessionHandler) handler).getKnownChannels(); - } else { - return ImmutableList.of(); - } - } - - private boolean canForwardPluginMessage(PluginMessage message) { - MinecraftConnection mc = serverConn.getConnection(); - if (mc == null) { - return false; - } - boolean minecraftOrFmlMessage; - if (mc.getProtocolVersion().compareTo(ProtocolVersion.MINECRAFT_1_12_2) <= 0) { - String channel = message.getChannel(); - minecraftOrFmlMessage = channel.startsWith("MC|") || channel - .startsWith(LegacyForgeConstants.FORGE_LEGACY_HANDSHAKE_CHANNEL); - } else { - minecraftOrFmlMessage = message.getChannel().startsWith("minecraft:"); - } - return minecraftOrFmlMessage - || server.getChannelRegistrar().registered(message.getChannel()) - || getClientKnownPluginChannels().contains(message.getChannel()); - } } diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ClientPlaySessionHandler.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ClientPlaySessionHandler.java index 45b8674ee..0fbc0cbc6 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ClientPlaySessionHandler.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ClientPlaySessionHandler.java @@ -50,12 +50,10 @@ import org.checkerframework.checker.nullness.qual.Nullable; public class ClientPlaySessionHandler implements MinecraftSessionHandler { private static final Logger logger = LogManager.getLogger(ClientPlaySessionHandler.class); - static final int MAX_PLUGIN_CHANNELS = 1024; private final ConnectedPlayer player; private boolean spawned = false; private final List serverBossBars = new ArrayList<>(); - private final Set knownChannels = new HashSet<>(); private final Queue loginPluginMessages = new ArrayDeque<>(); private final VelocityServer server; private @Nullable TabCompleteRequest legacyCommandTabComplete; @@ -68,19 +66,16 @@ public class ClientPlaySessionHandler implements MinecraftSessionHandler { public ClientPlaySessionHandler(VelocityServer server, ConnectedPlayer player) { this.player = player; this.server = server; - - if (player.getMinecraftConnection().getSessionHandler() - instanceof InitialConnectSessionHandler) { - this.knownChannels.addAll(((InitialConnectSessionHandler) player.getMinecraftConnection() - .getSessionHandler()).getKnownChannels()); - } } @Override public void activated() { + Collection channels = server.getChannelRegistrar().getChannelsForProtocol(player + .getProtocolVersion()); PluginMessage register = PluginMessageUtil.constructChannelsPacket(player.getProtocolVersion(), - server.getChannelRegistrar().getChannelsForProtocol(player.getProtocolVersion())); + channels); player.getMinecraftConnection().write(register); + player.getKnownChannels().addAll(channels); } @Override @@ -215,25 +210,10 @@ public class ClientPlaySessionHandler implements MinecraftSessionHandler { logger.warn("A plugin message was received while the backend server was not " + "ready. Channel: {}. Packet discarded.", packet.getChannel()); } else if (PluginMessageUtil.isRegister(packet)) { - List actuallyRegistered = new ArrayList<>(); - List channels = PluginMessageUtil.getChannels(packet); - for (String channel : channels) { - if (knownChannels.size() >= MAX_PLUGIN_CHANNELS && !knownChannels.contains(channel)) { - throw new IllegalStateException("Too many plugin message channels registered"); - } - if (knownChannels.add(channel)) { - actuallyRegistered.add(channel); - } - } - - if (!actuallyRegistered.isEmpty()) { - PluginMessage newRegisterPacket = PluginMessageUtil.constructChannelsPacket(backendConn - .getProtocolVersion(), actuallyRegistered); - backendConn.write(newRegisterPacket); - } + player.getKnownChannels().addAll(PluginMessageUtil.getChannels(packet)); + backendConn.write(packet); } else if (PluginMessageUtil.isUnregister(packet)) { - List channels = PluginMessageUtil.getChannels(packet); - knownChannels.removeAll(channels); + player.getKnownChannels().removeAll(PluginMessageUtil.getChannels(packet)); backendConn.write(packet); } else if (PluginMessageUtil.isMcBrand(packet)) { backendConn.write(PluginMessageUtil.rewriteMinecraftBrand(packet, server.getVersion())); @@ -385,14 +365,9 @@ public class ClientPlaySessionHandler implements MinecraftSessionHandler { // Tell the server about this client's plugin message channels. ProtocolVersion serverVersion = serverMc.getProtocolVersion(); - Collection toRegister = new HashSet<>(knownChannels); - if (serverVersion.compareTo(MINECRAFT_1_13) >= 0) { - toRegister.addAll(server.getChannelRegistrar().getModernChannelIds()); - } else { - toRegister.addAll(server.getChannelRegistrar().getIdsForLegacyConnections()); - } - if (!toRegister.isEmpty()) { - serverMc.delayedWrite(PluginMessageUtil.constructChannelsPacket(serverVersion, toRegister)); + if (!player.getKnownChannels().isEmpty()) { + serverMc.delayedWrite(PluginMessageUtil.constructChannelsPacket(serverVersion, + player.getKnownChannels())); } // If we had plugin messages queued during login/FML handshake, send them now. @@ -415,10 +390,6 @@ public class ClientPlaySessionHandler implements MinecraftSessionHandler { return serverBossBars; } - public Set getKnownChannels() { - return knownChannels; - } - /** * Handles additional tab complete for 1.12 and lower clients. * diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ConnectedPlayer.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ConnectedPlayer.java index 68bfd0798..d8c7ee290 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ConnectedPlayer.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/ConnectedPlayer.java @@ -31,6 +31,7 @@ import com.velocitypowered.proxy.VelocityServer; import com.velocitypowered.proxy.connection.MinecraftConnection; import com.velocitypowered.proxy.connection.MinecraftConnectionAssociation; import com.velocitypowered.proxy.connection.backend.VelocityServerConnection; +import com.velocitypowered.proxy.connection.forge.legacy.LegacyForgeConstants; import com.velocitypowered.proxy.connection.util.ConnectionMessages; import com.velocitypowered.proxy.connection.util.ConnectionRequestResults; import com.velocitypowered.proxy.connection.util.ConnectionRequestResults.Impl; @@ -42,11 +43,14 @@ import com.velocitypowered.proxy.protocol.packet.KeepAlive; import com.velocitypowered.proxy.protocol.packet.PluginMessage; import com.velocitypowered.proxy.protocol.packet.ResourcePackRequest; import com.velocitypowered.proxy.protocol.packet.TitlePacket; +import com.velocitypowered.proxy.protocol.util.PluginMessageUtil; import com.velocitypowered.proxy.server.VelocityRegisteredServer; import com.velocitypowered.proxy.tablist.VelocityTabList; import com.velocitypowered.proxy.util.VelocityMessages; +import com.velocitypowered.proxy.util.collect.CappedSet; import io.netty.buffer.ByteBufUtil; import java.net.InetSocketAddress; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -68,12 +72,16 @@ import org.checkerframework.checker.nullness.qual.Nullable; public class ConnectedPlayer implements MinecraftConnectionAssociation, Player { + private static final int MAX_PLUGIN_CHANNELS = 1024; private static final PlainComponentSerializer PASS_THRU_TRANSLATE = new PlainComponentSerializer( c -> "", TranslatableComponent::key); static final PermissionProvider DEFAULT_PERMISSIONS = s -> PermissionFunction.ALWAYS_UNDEFINED; private static final Logger logger = LogManager.getLogger(ConnectedPlayer.class); + /** + * The actual Minecraft connection. This is actually a wrapper object around the Netty channel. + */ private final MinecraftConnection minecraftConnection; private final @Nullable InetSocketAddress virtualHost; private GameProfile profile; @@ -87,6 +95,7 @@ public class ConnectedPlayer implements MinecraftConnectionAssociation, Player { private final VelocityTabList tabList; private final VelocityServer server; private ClientConnectionPhase connectionPhase; + private final Collection knownChannels; private @MonotonicNonNull List serversToTry = null; @@ -99,6 +108,7 @@ public class ConnectedPlayer implements MinecraftConnectionAssociation, Player { this.virtualHost = virtualHost; this.permissionFunction = PermissionFunction.ALWAYS_UNDEFINED; this.connectionPhase = minecraftConnection.getType().getInitialClientPhase(); + this.knownChannels = CappedSet.create(MAX_PLUGIN_CHANNELS); } @Override @@ -631,6 +641,44 @@ public class ConnectedPlayer implements MinecraftConnectionAssociation, Player { this.connectionPhase = connectionPhase; } + /** + * Return all the plugin message channels "known" to the client. + * @return the channels + */ + public Collection getKnownChannels() { + return knownChannels; + } + + /** + * Determines whether or not we can forward a plugin message onto the client. + * @param version the Minecraft protocol version + * @param message the plugin message to forward to the client + * @return {@code true} if the message can be forwarded, {@code false} otherwise + */ + public boolean canForwardPluginMessage(ProtocolVersion version, PluginMessage message) { + boolean minecraftOrFmlMessage; + + // We should _always_ pass on new channels the server wishes to register (or unregister) with + // us. + if (PluginMessageUtil.isRegister(message) || PluginMessageUtil.isUnregister(message)) { + return true; + } + + // By default, all internal Minecraft and Forge channels are forwarded from the server. + if (version.compareTo(ProtocolVersion.MINECRAFT_1_12_2) <= 0) { + String channel = message.getChannel(); + minecraftOrFmlMessage = channel.startsWith("MC|") || channel + .startsWith(LegacyForgeConstants.FORGE_LEGACY_HANDSHAKE_CHANNEL); + } else { + minecraftOrFmlMessage = message.getChannel().startsWith("minecraft:"); + } + + // Otherwise, we need to see if the player already knows this channel or it's known by the + // proxy. + return minecraftOrFmlMessage || knownChannels.contains(message.getChannel()) + || server.getChannelRegistrar().registered(message.getChannel()); + } + private class ConnectionRequestBuilderImpl implements ConnectionRequestBuilder { private final RegisteredServer toConnect; diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/InitialConnectSessionHandler.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/InitialConnectSessionHandler.java index c681f5d2b..59c2e9255 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/client/InitialConnectSessionHandler.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/client/InitialConnectSessionHandler.java @@ -1,24 +1,16 @@ package com.velocitypowered.proxy.connection.client; -import static com.velocitypowered.proxy.connection.client.ClientPlaySessionHandler.MAX_PLUGIN_CHANNELS; - import com.velocitypowered.proxy.connection.MinecraftSessionHandler; import com.velocitypowered.proxy.connection.backend.VelocityServerConnection; import com.velocitypowered.proxy.protocol.packet.PluginMessage; import com.velocitypowered.proxy.protocol.util.PluginMessageUtil; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; public class InitialConnectSessionHandler implements MinecraftSessionHandler { private final ConnectedPlayer player; - private final Set knownChannels; InitialConnectSessionHandler(ConnectedPlayer player) { this.player = player; - this.knownChannels = new HashSet<>(); } @Override @@ -30,29 +22,11 @@ public class InitialConnectSessionHandler implements MinecraftSessionHandler { } if (PluginMessageUtil.isRegister(packet)) { - List actuallyRegistered = new ArrayList<>(); - List channels = PluginMessageUtil.getChannels(packet); - for (String channel : channels) { - if (knownChannels.size() >= MAX_PLUGIN_CHANNELS && !knownChannels.contains(channel)) { - throw new IllegalStateException("Too many plugin message channels registered"); - } - if (knownChannels.add(channel)) { - actuallyRegistered.add(channel); - } - } - - if (!actuallyRegistered.isEmpty()) { - PluginMessage newRegisterPacket = PluginMessageUtil.constructChannelsPacket(serverConn - .ensureConnected().getProtocolVersion(), actuallyRegistered); - serverConn.ensureConnected().write(newRegisterPacket); - } + player.getKnownChannels().addAll(PluginMessageUtil.getChannels(packet)); } else if (PluginMessageUtil.isUnregister(packet)) { - List channels = PluginMessageUtil.getChannels(packet); - knownChannels.removeAll(channels); - serverConn.ensureConnected().write(packet); - } else { - serverConn.ensureConnected().write(packet); + player.getKnownChannels().removeAll(PluginMessageUtil.getChannels(packet)); } + serverConn.ensureConnected().write(packet); } return true; } @@ -62,8 +36,4 @@ public class InitialConnectSessionHandler implements MinecraftSessionHandler { // the user cancelled the login process player.teardown(); } - - public Set getKnownChannels() { - return knownChannels; - } } diff --git a/proxy/src/main/java/com/velocitypowered/proxy/network/http/NettyHttpClient.java b/proxy/src/main/java/com/velocitypowered/proxy/network/http/NettyHttpClient.java index fa66b533d..f64639957 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/network/http/NettyHttpClient.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/network/http/NettyHttpClient.java @@ -15,7 +15,6 @@ import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslHandler; import java.net.InetSocketAddress; import java.net.URL; -import java.util.Objects; import java.util.concurrent.CompletableFuture; import javax.net.ssl.SSLEngine; diff --git a/proxy/src/main/java/com/velocitypowered/proxy/protocol/packet/KeepAlive.java b/proxy/src/main/java/com/velocitypowered/proxy/protocol/packet/KeepAlive.java index e4b581617..ac6810a65 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/protocol/packet/KeepAlive.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/protocol/packet/KeepAlive.java @@ -32,7 +32,7 @@ public class KeepAlive implements MinecraftPacket { } else { randomId = ProtocolUtils.readVarInt(buf); } -} + } @Override public void encode(ByteBuf buf, ProtocolUtils.Direction direction, ProtocolVersion version) { diff --git a/proxy/src/main/java/com/velocitypowered/proxy/util/collect/CappedSet.java b/proxy/src/main/java/com/velocitypowered/proxy/util/collect/CappedSet.java new file mode 100644 index 000000000..a0f775f9e --- /dev/null +++ b/proxy/src/main/java/com/velocitypowered/proxy/util/collect/CappedSet.java @@ -0,0 +1,53 @@ +package com.velocitypowered.proxy.util.collect; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ForwardingSet; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; + +/** + * An unsynchronized collection that puts an upper bound on the size of the collection. + */ +public class CappedSet extends ForwardingSet { + + private final Set delegate; + private final int upperSize; + + private CappedSet(Set delegate, int upperSize) { + this.delegate = delegate; + this.upperSize = upperSize; + } + + /** + * Creates a capped collection backed by a {@link HashSet}. + * @param maxSize the maximum size of the collection + * @param the type of elements in the collection + * @return the new collection + */ + public static Set create(int maxSize) { + return new CappedSet<>(new HashSet<>(), maxSize); + } + + @Override + protected Set delegate() { + return delegate; + } + + @Override + public boolean add(T element) { + if (this.delegate.size() >= upperSize) { + Preconditions.checkState(this.delegate.contains(element), + "collection is too large (%s >= %s)", + this.delegate.size(), this.upperSize); + return false; + } + return this.delegate.add(element); + } + + @Override + public boolean addAll(Collection collection) { + return this.standardAddAll(collection); + } +} diff --git a/proxy/src/test/java/com/velocitypowered/proxy/util/collect/CappedSetTest.java b/proxy/src/test/java/com/velocitypowered/proxy/util/collect/CappedSetTest.java new file mode 100644 index 000000000..11df2237d --- /dev/null +++ b/proxy/src/test/java/com/velocitypowered/proxy/util/collect/CappedSetTest.java @@ -0,0 +1,54 @@ +package com.velocitypowered.proxy.util.collect; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.common.collect.ImmutableSet; +import java.util.Collection; +import java.util.Set; +import org.junit.jupiter.api.Test; + +class CappedSetTest { + + @Test + void basicVerification() { + Collection coll = CappedSet.create(1); + assertTrue(coll.add("coffee"), "did not add single item"); + assertThrows(IllegalStateException.class, () -> coll.add("tea"), + "item was added to collection although it is too full"); + assertEquals(1, coll.size(), "collection grew in size unexpectedly"); + } + + @Test + void testAddAll() { + Set doesFill1 = ImmutableSet.of("coffee", "tea"); + Set doesFill2 = ImmutableSet.of("chocolate"); + Set overfill = ImmutableSet.of("Coke", "Pepsi"); + + Collection coll = CappedSet.create(3); + assertTrue(coll.addAll(doesFill1), "did not add items"); + assertTrue(coll.addAll(doesFill2), "did not add items"); + assertThrows(IllegalStateException.class, () -> coll.addAll(overfill), + "items added to collection although it is too full"); + assertEquals(3, coll.size(), "collection grew in size unexpectedly"); + } + + @Test + void handlesSetBehaviorCorrectly() { + Set doesFill1 = ImmutableSet.of("coffee", "tea"); + Set doesFill2 = ImmutableSet.of("coffee", "chocolate"); + Set overfill = ImmutableSet.of("coffee", "Coke", "Pepsi"); + + Collection coll = CappedSet.create(3); + assertTrue(coll.addAll(doesFill1), "did not add items"); + assertTrue(coll.addAll(doesFill2), "did not add items"); + assertThrows(IllegalStateException.class, () -> coll.addAll(overfill), + "items added to collection although it is too full"); + + assertFalse(coll.addAll(doesFill1), "added items?!?"); + + assertEquals(3, coll.size(), "collection grew in size unexpectedly"); + } +} \ No newline at end of file