diff --git a/proxy/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java b/proxy/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java index 58cf30d19..fd5c268b2 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java @@ -1,23 +1,15 @@ package com.velocitypowered.proxy.protocol; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_10; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_11; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_11_1; +import static com.google.common.collect.Iterables.getLast; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_12; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_12_1; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_12_2; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_13; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_13_1; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_13_2; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_14; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_14_1; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_14_2; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_8; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_9; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_9_1; -import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_9_2; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_9_4; import static com.velocitypowered.api.network.ProtocolVersion.MINIMUM_VERSION; +import static com.velocitypowered.api.network.ProtocolVersion.SUPPORTED_VERSIONS; import static com.velocitypowered.proxy.protocol.ProtocolUtils.Direction; import com.velocitypowered.api.network.ProtocolVersion; @@ -53,7 +45,6 @@ import io.netty.util.collection.IntObjectMap; import it.unimi.dsi.fastutil.objects.Object2IntMap; import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; -import java.util.Collection; import java.util.Collections; import java.util.EnumMap; import java.util.EnumSet; @@ -68,20 +59,20 @@ public enum StateRegistry { HANDSHAKE { { serverbound.register(Handshake.class, Handshake::new, - genericMappings(0x00)); + map(0x00, MINECRAFT_1_8, false)); } }, STATUS { { serverbound.register(StatusRequest.class, () -> StatusRequest.INSTANCE, - genericMappings(0x00)); + map(0x00, MINECRAFT_1_8, false)); serverbound.register(StatusPing.class, StatusPing::new, - genericMappings(0x01)); + map(0x01, MINECRAFT_1_8, false)); clientbound.register(StatusResponse.class, StatusResponse::new, - genericMappings(0x00)); + map(0x00, MINECRAFT_1_8, false)); clientbound.register(StatusPing.class, StatusPing::new, - genericMappings(0x01)); + map(0x01, MINECRAFT_1_8, false)); } }, PLAY { @@ -101,14 +92,12 @@ public enum StateRegistry { map(0x02, MINECRAFT_1_9, false), map(0x03, MINECRAFT_1_12, false), map(0x02, MINECRAFT_1_12_1, false), - map(0x02, MINECRAFT_1_13, false), map(0x03, MINECRAFT_1_14, false)); serverbound.register(ClientSettings.class, ClientSettings::new, map(0x15, MINECRAFT_1_8, false), map(0x04, MINECRAFT_1_9, false), map(0x05, MINECRAFT_1_12, false), map(0x04, MINECRAFT_1_12_1, false), - map(0x04, MINECRAFT_1_13, false), map(0x05, MINECRAFT_1_14, false)); serverbound.register(PluginMessage.class, PluginMessage::new, map(0x17, MINECRAFT_1_8, false), @@ -132,47 +121,35 @@ public enum StateRegistry { map(0x1F, MINECRAFT_1_14, false)); clientbound.register(BossBar.class, BossBar::new, - map(0x0C, MINECRAFT_1_9, false), - map(0x0C, MINECRAFT_1_12, false), - map(0x0C, MINECRAFT_1_13, false), - map(0x0C, MINECRAFT_1_14, false)); + map(0x0C, MINECRAFT_1_9, false)); clientbound.register(Chat.class, Chat::new, map(0x02, MINECRAFT_1_8, true), map(0x0F, MINECRAFT_1_9, true), - map(0x0F, MINECRAFT_1_12, true), - map(0x0E, MINECRAFT_1_13, true), - map(0x0E, MINECRAFT_1_14, false)); + map(0x0E, MINECRAFT_1_13, true)); clientbound.register(TabCompleteResponse.class, TabCompleteResponse::new, map(0x3A, MINECRAFT_1_8, false), map(0x0E, MINECRAFT_1_9, false), - map(0x0E, MINECRAFT_1_12, false), - map(0x10, MINECRAFT_1_13, false), - map(0x10, MINECRAFT_1_14, false)); + map(0x10, MINECRAFT_1_13, false)); clientbound.register(AvailableCommands.class, AvailableCommands::new, - map(0x11, MINECRAFT_1_13, false), - map(0x11, MINECRAFT_1_14, false)); + map(0x11, MINECRAFT_1_13, false)); clientbound.register(PluginMessage.class, PluginMessage::new, map(0x3F, MINECRAFT_1_8, false), map(0x18, MINECRAFT_1_9, false), - map(0x18, MINECRAFT_1_12, false), map(0x19, MINECRAFT_1_13, false), map(0x18, MINECRAFT_1_14, false)); clientbound.register(Disconnect.class, Disconnect::new, map(0x40, MINECRAFT_1_8, false), map(0x1A, MINECRAFT_1_9, false), - map(0x1A, MINECRAFT_1_12, false), map(0x1B, MINECRAFT_1_13, false), map(0x1A, MINECRAFT_1_14, false)); clientbound.register(KeepAlive.class, KeepAlive::new, map(0x00, MINECRAFT_1_8, false), map(0x1F, MINECRAFT_1_9, false), - map(0x1F, MINECRAFT_1_12, false), map(0x21, MINECRAFT_1_13, false), map(0x20, MINECRAFT_1_14, false)); clientbound.register(JoinGame.class, JoinGame::new, map(0x01, MINECRAFT_1_8, false), map(0x23, MINECRAFT_1_9, false), - map(0x23, MINECRAFT_1_12, false), map(0x25, MINECRAFT_1_13, false), map(0x25, MINECRAFT_1_14, false)); clientbound.register(Respawn.class, Respawn::new, @@ -207,7 +184,6 @@ public enum StateRegistry { clientbound.register(PlayerListItem.class, PlayerListItem::new, map(0x38, MINECRAFT_1_8, false), map(0x2D, MINECRAFT_1_9, false), - map(0x2D, MINECRAFT_1_12, false), map(0x2E, MINECRAFT_1_12_1, false), map(0x30, MINECRAFT_1_13, false), map(0x33, MINECRAFT_1_14, false)); @@ -216,24 +192,21 @@ public enum StateRegistry { LOGIN { { serverbound.register(ServerLogin.class, ServerLogin::new, - genericMappings(0x00)); + map(0x00, MINECRAFT_1_8, false)); serverbound.register(EncryptionResponse.class, EncryptionResponse::new, - genericMappings(0x01)); + map(0x01, MINECRAFT_1_8, false)); serverbound.register(LoginPluginResponse.class, LoginPluginResponse::new, - map(0x02, MINECRAFT_1_13, false), - map(0x02, MINECRAFT_1_14, false)); - + map(0x02, MINECRAFT_1_13, false)); clientbound.register(Disconnect.class, Disconnect::new, - genericMappings(0x00)); + map(0x00, MINECRAFT_1_8, false)); clientbound.register(EncryptionRequest.class, EncryptionRequest::new, - genericMappings(0x01)); + map(0x01, MINECRAFT_1_8, false)); clientbound.register(ServerLoginSuccess.class, ServerLoginSuccess::new, - genericMappings(0x02)); + map(0x02, MINECRAFT_1_8, false)); clientbound.register(SetCompression.class, SetCompression::new, - genericMappings(0x03)); + map(0x03, MINECRAFT_1_8, false)); clientbound.register(LoginPluginMessage.class, LoginPluginMessage::new, - map(0x04, MINECRAFT_1_13, false), - map(0x04, MINECRAFT_1_14, false)); + map(0x04, MINECRAFT_1_13, false)); } }; @@ -244,20 +217,6 @@ public enum StateRegistry { public static class PacketRegistry { - private static final Map> LINKED_PROTOCOL_VERSIONS - = new EnumMap<>(ProtocolVersion.class); - - static { - LINKED_PROTOCOL_VERSIONS.put(MINECRAFT_1_9, EnumSet.of(MINECRAFT_1_9_1, MINECRAFT_1_9_2, - MINECRAFT_1_9_4)); - LINKED_PROTOCOL_VERSIONS.put(MINECRAFT_1_9_4, EnumSet.of(MINECRAFT_1_10, MINECRAFT_1_11, - MINECRAFT_1_11_1)); - LINKED_PROTOCOL_VERSIONS.put(MINECRAFT_1_12, EnumSet.of(MINECRAFT_1_12_1)); - LINKED_PROTOCOL_VERSIONS.put(MINECRAFT_1_12_1, EnumSet.of(MINECRAFT_1_12_2)); - LINKED_PROTOCOL_VERSIONS.put(MINECRAFT_1_13, EnumSet.of(MINECRAFT_1_13_1, MINECRAFT_1_13_2)); - LINKED_PROTOCOL_VERSIONS.put(MINECRAFT_1_14, EnumSet.of(MINECRAFT_1_14_1, MINECRAFT_1_14_2)); - } - private final Direction direction; private final Map versions; private boolean fallback = true; @@ -292,28 +251,42 @@ public enum StateRegistry { throw new IllegalArgumentException("At least one mapping must be provided."); } - for (final PacketMapping mapping : mappings) { - ProtocolRegistry registry = this.versions.get(mapping.protocolVersion); - if (registry == null) { - throw new IllegalArgumentException("Unknown protocol version " + mapping.protocolVersion); - } - if (!mapping.encodeOnly) { - registry.packetIdToSupplier.put(mapping.id, packetSupplier); - } - registry.packetClassToId.put(clazz, mapping.id); + for (int i = 0; i < mappings.length; i++) { + PacketMapping current = mappings[i]; + PacketMapping next = (i + 1 < mappings.length) ? mappings[i + 1] : current; + ProtocolVersion from = current.protocolVersion; + ProtocolVersion to = current == next ? getLast(SUPPORTED_VERSIONS) : next.protocolVersion; - Collection linked = LINKED_PROTOCOL_VERSIONS.get(mapping.protocolVersion); - if (linked != null) { - links: - for (ProtocolVersion linkedVersion : linked) { - // Make sure that later mappings override this one. - for (PacketMapping m : mappings) { - if (linkedVersion == m.protocolVersion) { - continue links; - } - } - register(clazz, packetSupplier, map(mapping.id, linkedVersion, mapping.encodeOnly)); + if (from.compareTo(to) >= 0 && from != getLast(SUPPORTED_VERSIONS)) { + throw new IllegalArgumentException(String.format( + "Next mapping version (%s) should be lower then current (%s)", to, from)); + } + + for (ProtocolVersion protocol : EnumSet.range(from, to)) { + if (protocol == to && next != current) { + break; } + ProtocolRegistry registry = this.versions.get(protocol); + if (registry == null) { + throw new IllegalArgumentException("Unknown protocol version " + + current.protocolVersion); + } + + if (registry.packetIdToSupplier.containsKey(current.id)) { + throw new IllegalArgumentException("Can not register class " + clazz.getSimpleName() + + " with id " + current.id + " for " + registry.version + + " because another packet is already registered"); + } + + if (registry.packetClassToId.containsKey(clazz)) { + throw new IllegalArgumentException(clazz.getSimpleName() + + " is already registered for version " + registry.version); + } + + if (!current.encodeOnly) { + registry.packetIdToSupplier.put(current.id, packetSupplier); + } + registry.packetClassToId.put(clazz, current.id); } } } @@ -409,8 +382,8 @@ public enum StateRegistry { /** * Creates a PacketMapping using the provided arguments. * - * @param id Packet Id - * @param version Protocol version + * @param id Packet Id + * @param version Protocol version * @param encodeOnly When true packet decoding will be disabled * @return PacketMapping with the provided arguments */ @@ -418,13 +391,4 @@ public enum StateRegistry { return new PacketMapping(id, version, encodeOnly); } - private static PacketMapping[] genericMappings(int id) { - return new PacketMapping[] { - map(id, MINECRAFT_1_8, false), - map(id, MINECRAFT_1_9, false), - map(id, MINECRAFT_1_12, false), - map(id, MINECRAFT_1_13, false), - map(id, MINECRAFT_1_14, false) - }; - } } diff --git a/proxy/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java b/proxy/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java index b943bd31a..e1c40091a 100644 --- a/proxy/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java +++ b/proxy/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java @@ -1,14 +1,23 @@ package com.velocitypowered.proxy.protocol; +import static com.google.common.collect.Iterables.getLast; +import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_11; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_12; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_12_1; import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_12_2; +import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_13; +import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_14_2; +import static com.velocitypowered.api.network.ProtocolVersion.MINECRAFT_1_8; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import com.velocitypowered.api.network.ProtocolVersion; import com.velocitypowered.proxy.protocol.packet.Handshake; +import com.velocitypowered.proxy.protocol.packet.StatusPing; + import org.junit.jupiter.api.Test; class PacketRegistryTest { @@ -17,6 +26,7 @@ class PacketRegistryTest { StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry( ProtocolUtils.Direction.CLIENTBOUND); registry.register(Handshake.class, Handshake::new, + new StateRegistry.PacketMapping(0x01, MINECRAFT_1_8, false), new StateRegistry.PacketMapping(0x00, MINECRAFT_1_12, false)); return registry; } @@ -40,6 +50,12 @@ class PacketRegistryTest { assertEquals(Handshake.class, packet.getClass(), "Registry returned wrong class"); assertEquals(0, registry.getProtocolRegistry(MINECRAFT_1_12_1).getPacketId(packet), "Registry did not return the correct packet ID"); + assertEquals(0, registry.getProtocolRegistry(MINECRAFT_1_14_2).getPacketId(packet), + "Registry did not return the correct packet ID"); + assertEquals(1, registry.getProtocolRegistry(MINECRAFT_1_11).getPacketId(packet), + "Registry did not return the correct packet ID"); + assertNull(registry.getProtocolRegistry(MINECRAFT_1_14_2).createPacket(0x01), + "Registry should return a null"); } @Test @@ -52,18 +68,61 @@ class PacketRegistryTest { () -> registry.getProtocolRegistry(ProtocolVersion.UNKNOWN).getPacketId(new Handshake())); } + @Test + void failOnWrongOrder() { + StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry( + ProtocolUtils.Direction.CLIENTBOUND); + assertThrows(IllegalArgumentException.class, + () -> registry.register(Handshake.class, Handshake::new, + new StateRegistry.PacketMapping(0x01, MINECRAFT_1_13, false), + new StateRegistry.PacketMapping(0x00, MINECRAFT_1_8, false))); + assertThrows(IllegalArgumentException.class, + () -> registry.register(Handshake.class, Handshake::new, + new StateRegistry.PacketMapping(0x01, MINECRAFT_1_13, false), + new StateRegistry.PacketMapping(0x01, MINECRAFT_1_13, false))); + } + + @Test + void failOnDuplicate() { + StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry( + ProtocolUtils.Direction.CLIENTBOUND); + registry.register(Handshake.class, Handshake::new, + new StateRegistry.PacketMapping(0x00, MINECRAFT_1_8, false)); + assertThrows(IllegalArgumentException.class, + () -> registry.register(Handshake.class, Handshake::new, + new StateRegistry.PacketMapping(0x01, MINECRAFT_1_12, false))); + assertThrows(IllegalArgumentException.class, + () -> registry.register(StatusPing.class, StatusPing::new, + new StateRegistry.PacketMapping(0x00, MINECRAFT_1_13, false))); + } + + @Test + void shouldNotFailWhenRegisterLatestProtocolVersion() { + StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry( + ProtocolUtils.Direction.CLIENTBOUND); + assertDoesNotThrow(() -> registry.register(Handshake.class, Handshake::new, + new StateRegistry.PacketMapping(0x00, MINECRAFT_1_8, false), + new StateRegistry.PacketMapping(0x01, getLast(ProtocolVersion.SUPPORTED_VERSIONS), + false))); + } + @Test void registrySuppliesCorrectPacketsByProtocol() { StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry( ProtocolUtils.Direction.CLIENTBOUND); registry.register(Handshake.class, Handshake::new, new StateRegistry.PacketMapping(0x00, MINECRAFT_1_12, false), - new StateRegistry.PacketMapping(0x01, MINECRAFT_1_12_1, false)); + new StateRegistry.PacketMapping(0x01, MINECRAFT_1_12_1, false), + new StateRegistry.PacketMapping(0x02, MINECRAFT_1_13, false)); assertEquals(Handshake.class, registry.getProtocolRegistry(MINECRAFT_1_12).createPacket(0x00).getClass()); assertEquals(Handshake.class, registry.getProtocolRegistry(MINECRAFT_1_12_1).createPacket(0x01).getClass()); assertEquals(Handshake.class, registry.getProtocolRegistry(MINECRAFT_1_12_2).createPacket(0x01).getClass()); + assertEquals(Handshake.class, + registry.getProtocolRegistry(MINECRAFT_1_13).createPacket(0x02).getClass()); + assertEquals(Handshake.class, + registry.getProtocolRegistry(MINECRAFT_1_14_2).createPacket(0x02).getClass()); } } \ No newline at end of file