From b95f076562db4792744264e4f9aa77fa586afefa Mon Sep 17 00:00:00 2001 From: kashike Date: Fri, 27 Jul 2018 20:13:23 -0700 Subject: [PATCH] Don't search through protocol versions all the time --- .../network/ConnectionManager.java | 4 +- .../connection/backend/ServerConnection.java | 4 +- .../proxy/protocol/ProtocolConstants.java | 4 +- .../proxy/protocol/StateRegistry.java | 130 +++++++++--------- .../protocol/netty/MinecraftDecoder.java | 11 +- .../protocol/netty/MinecraftEncoder.java | 15 +- .../proxy/protocol/packets/Chat.java | 4 +- .../proxy/protocol/PacketRegistryTest.java | 30 ++-- 8 files changed, 100 insertions(+), 102 deletions(-) diff --git a/src/main/java/com/velocitypowered/network/ConnectionManager.java b/src/main/java/com/velocitypowered/network/ConnectionManager.java index 5814eb428..9b8676db8 100644 --- a/src/main/java/com/velocitypowered/network/ConnectionManager.java +++ b/src/main/java/com/velocitypowered/network/ConnectionManager.java @@ -96,8 +96,8 @@ public final class ConnectionManager { .addLast(FRAME_DECODER, new MinecraftVarintFrameDecoder()) .addLast(LEGACY_PING_ENCODER, LegacyPingEncoder.INSTANCE) .addLast(FRAME_ENCODER, MinecraftVarintLengthEncoder.INSTANCE) - .addLast(MINECRAFT_DECODER, new MinecraftDecoder(ProtocolConstants.Direction.TO_SERVER)) - .addLast(MINECRAFT_ENCODER, new MinecraftEncoder(ProtocolConstants.Direction.TO_CLIENT)); + .addLast(MINECRAFT_DECODER, new MinecraftDecoder(ProtocolConstants.Direction.SERVERBOUND)) + .addLast(MINECRAFT_ENCODER, new MinecraftEncoder(ProtocolConstants.Direction.CLIENTBOUND)); final MinecraftConnection connection = new MinecraftConnection(ch); connection.setState(StateRegistry.HANDSHAKE); diff --git a/src/main/java/com/velocitypowered/proxy/connection/backend/ServerConnection.java b/src/main/java/com/velocitypowered/proxy/connection/backend/ServerConnection.java index 687cc1539..3a5517948 100644 --- a/src/main/java/com/velocitypowered/proxy/connection/backend/ServerConnection.java +++ b/src/main/java/com/velocitypowered/proxy/connection/backend/ServerConnection.java @@ -48,8 +48,8 @@ public class ServerConnection implements MinecraftConnectionAssociation { .addLast(READ_TIMEOUT, new ReadTimeoutHandler(SERVER_READ_TIMEOUT_SECONDS, TimeUnit.SECONDS)) .addLast(FRAME_DECODER, new MinecraftVarintFrameDecoder()) .addLast(FRAME_ENCODER, MinecraftVarintLengthEncoder.INSTANCE) - .addLast(MINECRAFT_DECODER, new MinecraftDecoder(ProtocolConstants.Direction.TO_CLIENT)) - .addLast(MINECRAFT_ENCODER, new MinecraftEncoder(ProtocolConstants.Direction.TO_SERVER)); + .addLast(MINECRAFT_DECODER, new MinecraftDecoder(ProtocolConstants.Direction.CLIENTBOUND)) + .addLast(MINECRAFT_ENCODER, new MinecraftEncoder(ProtocolConstants.Direction.SERVERBOUND)); MinecraftConnection connection = new MinecraftConnection(ch); connection.setState(StateRegistry.HANDSHAKE); diff --git a/src/main/java/com/velocitypowered/proxy/protocol/ProtocolConstants.java b/src/main/java/com/velocitypowered/proxy/protocol/ProtocolConstants.java index 3201ac14d..768c17fb0 100644 --- a/src/main/java/com/velocitypowered/proxy/protocol/ProtocolConstants.java +++ b/src/main/java/com/velocitypowered/proxy/protocol/ProtocolConstants.java @@ -14,7 +14,7 @@ public enum ProtocolConstants { ; } public enum Direction { - TO_SERVER, - TO_CLIENT + SERVERBOUND, + CLIENTBOUND } } diff --git a/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java b/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java index 44bc06f64..d4e52fe95 100644 --- a/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java +++ b/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java @@ -1,13 +1,10 @@ package com.velocitypowered.proxy.protocol; -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import com.velocitypowered.proxy.protocol.packets.*; import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.Supplier; @@ -17,124 +14,127 @@ import static com.velocitypowered.proxy.protocol.ProtocolConstants.MINECRAFT_1_1 public enum StateRegistry { HANDSHAKE { { - TO_SERVER.register(Handshake.class, Handshake::new, + SERVERBOUND.register(Handshake.class, Handshake::new, generic(0x00)); } }, STATUS { { - TO_SERVER.register(StatusRequest.class, StatusRequest::new, + SERVERBOUND.register(StatusRequest.class, StatusRequest::new, generic(0x00)); - TO_SERVER.register(Ping.class, Ping::new, + SERVERBOUND.register(Ping.class, Ping::new, generic(0x01)); - TO_CLIENT.register(StatusResponse.class, StatusResponse::new, + CLIENTBOUND.register(StatusResponse.class, StatusResponse::new, generic(0x00)); - TO_CLIENT.register(Ping.class, Ping::new, + CLIENTBOUND.register(Ping.class, Ping::new, generic(0x01)); } }, PLAY { { - TO_SERVER.register(Chat.class, Chat::new, + SERVERBOUND.register(Chat.class, Chat::new, map(0x02, MINECRAFT_1_12)); - TO_SERVER.register(Ping.class, Ping::new, + SERVERBOUND.register(Ping.class, Ping::new, map(0x0b, MINECRAFT_1_12)); - TO_CLIENT.register(Chat.class, Chat::new, + CLIENTBOUND.register(Chat.class, Chat::new, map(0x0F, MINECRAFT_1_12)); - TO_CLIENT.register(Disconnect.class, Disconnect::new, + CLIENTBOUND.register(Disconnect.class, Disconnect::new, map(0x1A, MINECRAFT_1_12)); - TO_CLIENT.register(Ping.class, Ping::new, + CLIENTBOUND.register(Ping.class, Ping::new, map(0x1F, MINECRAFT_1_12)); - TO_CLIENT.register(JoinGame.class, JoinGame::new, + CLIENTBOUND.register(JoinGame.class, JoinGame::new, map(0x23, MINECRAFT_1_12)); - TO_CLIENT.register(Respawn.class, Respawn::new, + CLIENTBOUND.register(Respawn.class, Respawn::new, map(0x35, MINECRAFT_1_12)); } }, LOGIN { { - TO_SERVER.register(ServerLogin.class, ServerLogin::new, + SERVERBOUND.register(ServerLogin.class, ServerLogin::new, generic(0x00)); - TO_SERVER.register(EncryptionResponse.class, EncryptionResponse::new, + SERVERBOUND.register(EncryptionResponse.class, EncryptionResponse::new, generic(0x01)); - TO_CLIENT.register(Disconnect.class, Disconnect::new, + CLIENTBOUND.register(Disconnect.class, Disconnect::new, generic(0x00)); - TO_CLIENT.register(EncryptionRequest.class, EncryptionRequest::new, + CLIENTBOUND.register(EncryptionRequest.class, EncryptionRequest::new, generic(0x01)); - TO_CLIENT.register(ServerLoginSuccess.class, ServerLoginSuccess::new, + CLIENTBOUND.register(ServerLoginSuccess.class, ServerLoginSuccess::new, generic(0x02)); - TO_CLIENT.register(SetCompression.class, SetCompression::new, + CLIENTBOUND.register(SetCompression.class, SetCompression::new, generic(0x03)); } }; - public final PacketRegistry TO_CLIENT = new PacketRegistry(ProtocolConstants.Direction.TO_CLIENT, this); - public final PacketRegistry TO_SERVER = new PacketRegistry(ProtocolConstants.Direction.TO_SERVER, this); + public final PacketRegistry CLIENTBOUND = new PacketRegistry(ProtocolConstants.Direction.CLIENTBOUND); + public final PacketRegistry SERVERBOUND = new PacketRegistry(ProtocolConstants.Direction.SERVERBOUND); public static class PacketRegistry { private final ProtocolConstants.Direction direction; - private final StateRegistry state; - private final IntObjectMap>> byProtocolVersionToProtocolIds = new IntObjectHashMap<>(); - private final Map, List> idMappers = new HashMap<>(); + private final IntObjectMap versions = new IntObjectHashMap<>(); - public PacketRegistry(ProtocolConstants.Direction direction, StateRegistry state) { + public PacketRegistry(ProtocolConstants.Direction direction) { this.direction = direction; - this.state = state; + } + + public ProtocolVersion getVersion(final int version) { + ProtocolVersion result = null; + for (final IntObjectMap.PrimitiveEntry entry : this.versions.entries()) { + if (entry.key() <= version) { + result = entry.value(); + } + } + if (result == null) { + throw new IllegalArgumentException("Could not find data for protocol version " + version); + } + return result; } public

void register(Class

clazz, Supplier

packetSupplier, PacketMapping... mappings) { if (mappings.length == 0) { throw new IllegalArgumentException("At least one mapping must be provided."); } - for (PacketMapping mapping : mappings) { - IntObjectMap> ids = byProtocolVersionToProtocolIds.get(mapping.protocolVersion); - if (ids == null) { - byProtocolVersionToProtocolIds.put(mapping.protocolVersion, ids = new IntObjectHashMap<>()); + + for (final PacketMapping mapping : mappings) { + ProtocolVersion version = this.versions.get(mapping.protocolVersion); + if (version == null) { + version = new ProtocolVersion(mapping.protocolVersion); + this.versions.put(mapping.protocolVersion, version); } - ids.put(mapping.id, packetSupplier); + version.packetIdToSupplier.put(mapping.id, packetSupplier); + version.packetClassToId.put(clazz, mapping.id); } - idMappers.put(clazz, ImmutableList.copyOf(mappings)); } - public MinecraftPacket createPacket(int id, int protocolVersion) { - IntObjectMap> bestLookup = null; - for (IntObjectMap.PrimitiveEntry>> entry : byProtocolVersionToProtocolIds.entries()) { - if (entry.key() <= protocolVersion) { - bestLookup = entry.value(); - } - } - if (bestLookup == null) { - return null; - } - Supplier supplier = bestLookup.get(id); - if (supplier == null) { - return null; - } - return supplier.get(); - } + public class ProtocolVersion { + public final int id; + final IntObjectMap> packetIdToSupplier = new IntObjectHashMap<>(); + final Map, Integer> packetClassToId = new HashMap<>(); - public int getId(MinecraftPacket packet, int protocolVersion) { - Preconditions.checkNotNull(packet, "packet"); + ProtocolVersion(final int id) { + this.id = id; + } - List mappings = idMappers.get(packet.getClass()); - if (mappings == null || mappings.isEmpty()) { - throw new IllegalArgumentException("Supplied packet " + packet.getClass().getName() + - " doesn't have any mappings. Direction " + direction + " State " + state); - } - int useId = -1; - for (PacketMapping mapping : mappings) { - if (mapping.protocolVersion <= protocolVersion) { - useId = mapping.id; + public MinecraftPacket createPacket(final int id) { + final Supplier supplier = this.packetIdToSupplier.get(id); + if (supplier == null) { + return null; } + return supplier.get(); } - if (useId == -1) { - throw new IllegalArgumentException("Unable to find a mapping for " + packet.getClass().getName() - + " Version " + protocolVersion + " Direction " + direction + " State " + state); + + public int getPacketId(final MinecraftPacket packet) { + final Integer id = this.packetClassToId.get(packet.getClass()); + if (id == null) { + throw new IllegalArgumentException(String.format( + "Unable to find id for packet of type %s in %s protocol %s", + packet.getClass().getName(), PacketRegistry.this.direction, this.id + )); + } + return id; } - return useId; } } diff --git a/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java b/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java index 8f0f1dfaf..df493a771 100644 --- a/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java +++ b/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java @@ -12,7 +12,7 @@ import java.util.List; public class MinecraftDecoder extends MessageToMessageDecoder { private StateRegistry state; private final ProtocolConstants.Direction direction; - private int protocolVersion; + private StateRegistry.PacketRegistry.ProtocolVersion protocolVersion; public MinecraftDecoder(ProtocolConstants.Direction direction) { this.state = StateRegistry.HANDSHAKE; @@ -28,14 +28,13 @@ public class MinecraftDecoder extends MessageToMessageDecoder { ByteBuf slice = msg.slice().retain(); int packetId = ProtocolUtils.readVarInt(msg); - StateRegistry.PacketRegistry mappings = direction == ProtocolConstants.Direction.TO_CLIENT ? state.TO_CLIENT : state.TO_SERVER; - MinecraftPacket packet = mappings.createPacket(packetId, protocolVersion); + MinecraftPacket packet = this.protocolVersion.createPacket(packetId); if (packet == null) { msg.skipBytes(msg.readableBytes()); out.add(new PacketWrapper(null, slice)); } else { try { - packet.decode(msg, direction, protocolVersion); + packet.decode(msg, direction, protocolVersion.id); } catch (Exception e) { throw new CorruptedFrameException("Error decoding " + packet.getClass() + " Direction " + direction + " Protocol " + protocolVersion + " State " + state + " ID " + Integer.toHexString(packetId), e); @@ -44,12 +43,12 @@ public class MinecraftDecoder extends MessageToMessageDecoder { } } - public int getProtocolVersion() { + public StateRegistry.PacketRegistry.ProtocolVersion getProtocolVersion() { return protocolVersion; } public void setProtocolVersion(int protocolVersion) { - this.protocolVersion = protocolVersion; + this.protocolVersion = (this.direction == ProtocolConstants.Direction.CLIENTBOUND ? this.state.CLIENTBOUND : this.state.SERVERBOUND).getVersion(protocolVersion); } public StateRegistry getState() { diff --git a/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftEncoder.java b/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftEncoder.java index f90662a1a..7cc6e346b 100644 --- a/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftEncoder.java +++ b/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftEncoder.java @@ -12,7 +12,7 @@ import io.netty.handler.codec.MessageToByteEncoder; public class MinecraftEncoder extends MessageToByteEncoder { private StateRegistry state; private final ProtocolConstants.Direction direction; - private int protocolVersion; + private StateRegistry.PacketRegistry.ProtocolVersion protocolVersion; public MinecraftEncoder(ProtocolConstants.Direction direction) { this.state = StateRegistry.HANDSHAKE; @@ -20,19 +20,18 @@ public class MinecraftEncoder extends MessageToByteEncoder { } @Override - protected void encode(ChannelHandlerContext ctx, MinecraftPacket msg, ByteBuf out) throws Exception { - StateRegistry.PacketRegistry mappings = direction == ProtocolConstants.Direction.TO_CLIENT ? state.TO_CLIENT : state.TO_SERVER; - int packetId = mappings.getId(msg, protocolVersion); + protected void encode(ChannelHandlerContext ctx, MinecraftPacket msg, ByteBuf out) { + int packetId = this.protocolVersion.getPacketId(msg); ProtocolUtils.writeVarInt(out, packetId); - msg.encode(out, direction, protocolVersion); + msg.encode(out, direction, protocolVersion.id); } - public int getProtocolVersion() { + public StateRegistry.PacketRegistry.ProtocolVersion getProtocolVersion() { return protocolVersion; } - public void setProtocolVersion(int protocolVersion) { - this.protocolVersion = protocolVersion; + public void setProtocolVersion(final int protocolVersion) { + this.protocolVersion = (this.direction == ProtocolConstants.Direction.CLIENTBOUND ? this.state.CLIENTBOUND : this.state.SERVERBOUND).getVersion(protocolVersion); } public StateRegistry getState() { diff --git a/src/main/java/com/velocitypowered/proxy/protocol/packets/Chat.java b/src/main/java/com/velocitypowered/proxy/protocol/packets/Chat.java index 0ccc949ed..6f4c049ba 100644 --- a/src/main/java/com/velocitypowered/proxy/protocol/packets/Chat.java +++ b/src/main/java/com/velocitypowered/proxy/protocol/packets/Chat.java @@ -47,7 +47,7 @@ public class Chat implements MinecraftPacket { @Override public void decode(ByteBuf buf, ProtocolConstants.Direction direction, int protocolVersion) { message = ProtocolUtils.readString(buf); - if (direction == ProtocolConstants.Direction.TO_CLIENT) { + if (direction == ProtocolConstants.Direction.CLIENTBOUND) { position = buf.readByte(); } } @@ -55,7 +55,7 @@ public class Chat implements MinecraftPacket { @Override public void encode(ByteBuf buf, ProtocolConstants.Direction direction, int protocolVersion) { ProtocolUtils.writeString(buf, message); - if (direction == ProtocolConstants.Direction.TO_CLIENT) { + if (direction == ProtocolConstants.Direction.CLIENTBOUND) { buf.writeByte(position); } } diff --git a/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java b/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java index c5705221c..a6c8b0c0f 100644 --- a/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java +++ b/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java @@ -8,7 +8,7 @@ import static org.junit.jupiter.api.Assertions.*; class PacketRegistryTest { private StateRegistry.PacketRegistry setupRegistry() { - StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry(ProtocolConstants.Direction.TO_CLIENT, StateRegistry.HANDSHAKE); + StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry(ProtocolConstants.Direction.CLIENTBOUND); registry.register(Handshake.class, Handshake::new, new StateRegistry.PacketMapping(0x00, 1)); registry.register(Ping.class, Ping::new, new StateRegistry.PacketMapping(0x01, 1), new StateRegistry.PacketMapping(0x02, 5)); @@ -18,47 +18,47 @@ class PacketRegistryTest { @Test void packetRegistryWorks() { StateRegistry.PacketRegistry registry = setupRegistry(); - MinecraftPacket packet = registry.createPacket(0, 1); + MinecraftPacket packet = registry.getVersion(1).createPacket(0); assertNotNull(packet, "Packet was not found in registry"); assertEquals(Handshake.class, packet.getClass(), "Registry returned wrong class"); - assertEquals(0, registry.getId(packet, 1), "Registry did not return the correct packet ID"); + assertEquals(0, registry.getVersion(1).getPacketId(packet), "Registry did not return the correct packet ID"); } @Test void packetRegistryRevertsToBestOldVersion() { StateRegistry.PacketRegistry registry = setupRegistry(); - MinecraftPacket packet = registry.createPacket(0, 2); + MinecraftPacket packet = registry.getVersion(2).createPacket(0); assertNotNull(packet, "Packet was not found in registry"); assertEquals(Handshake.class, packet.getClass(), "Registry returned wrong class"); - assertEquals(0, registry.getId(packet, 2), "Registry did not return the correct packet ID"); + assertEquals(0, registry.getVersion(2).getPacketId(packet), "Registry did not return the correct packet ID"); } @Test void packetRegistryDoesntProvideNewPacketsForOld() { StateRegistry.PacketRegistry registry = setupRegistry(); - assertNull(registry.createPacket(0, 0), "Packet was found in registry despite being too new"); + assertNull(registry.getVersion(0).createPacket(0), "Packet was found in registry despite being too new"); - assertThrows(IllegalArgumentException.class, () -> registry.getId(new Handshake(), 0), "Registry provided new packets for an old protocol version"); + assertThrows(IllegalArgumentException.class, () -> registry.getVersion(0).getPacketId(new Handshake()), "Registry provided new packets for an old protocol version"); } @Test void failOnNoMappings() { - StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry(ProtocolConstants.Direction.TO_CLIENT, StateRegistry.HANDSHAKE); + StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry(ProtocolConstants.Direction.CLIENTBOUND); assertThrows(IllegalArgumentException.class, () -> registry.register(Handshake.class, Handshake::new)); - assertThrows(IllegalArgumentException.class, () -> registry.getId(new Handshake(), 0)); + assertThrows(IllegalArgumentException.class, () -> registry.getVersion(0).getPacketId(new Handshake())); } @Test void packetRegistryProvidesCorrectVersionsForMultipleMappings() { StateRegistry.PacketRegistry registry = setupRegistry(); - assertNotNull(registry.createPacket(1, 1), "Packet was not found in registry despite being being registered with ID 1 and version 1"); - assertNotNull(registry.createPacket(1, 2), "Packet was not found in registry despite being being registered with ID 1 and version 1 (we are looking up version 2)"); - assertNotNull(registry.createPacket(2, 5), "Packet was not found in registry despite being being registered with ID 2 and version 5"); - assertNotNull(registry.createPacket(2, 6), "Packet was not found in registry despite being being registered with ID 2 and version 5 (we are looking up version 6)"); + assertNotNull(registry.getVersion(1).createPacket(1), "Packet was not found in registry despite being being registered with ID 1 and version 1"); + assertNotNull(registry.getVersion(2).createPacket(1), "Packet was not found in registry despite being being registered with ID 1 and version 1 (we are looking up version 2)"); + assertNotNull(registry.getVersion(5).createPacket(2), "Packet was not found in registry despite being being registered with ID 2 and version 5"); + assertNotNull(registry.getVersion(6).createPacket(2), "Packet was not found in registry despite being being registered with ID 2 and version 5 (we are looking up version 6)"); - assertEquals(1, registry.getId(new Ping(), 1), "Wrong ID provided from registry"); - assertEquals(2, registry.getId(new Ping(), 5), "Wrong ID provided from registry"); + assertEquals(1, registry.getVersion(1).getPacketId(new Ping()), "Wrong ID provided from registry"); + assertEquals(2, registry.getVersion(5).getPacketId(new Ping()), "Wrong ID provided from registry"); } } \ No newline at end of file