From 784806848d486fd3c06c3f63a8c54483361b9241 Mon Sep 17 00:00:00 2001 From: Andrew Steinborn Date: Thu, 5 Sep 2024 00:00:34 -0400 Subject: [PATCH] Do more speculative VarInt reading optimizations (#1418) --- .../proxy/protocol/ProtocolUtils.java | 75 ++++------ .../netty/MinecraftVarintFrameDecoder.java | 130 +++++++++++++----- .../protocol/netty/VarintByteDecoder.java | 68 --------- .../packet/LoginPluginMessagePacket.java | 2 +- .../proxy/protocol/ProtocolUtilsTest.java | 4 +- 5 files changed, 127 insertions(+), 152 deletions(-) delete mode 100644 proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/VarintByteDecoder.java diff --git a/proxy/src/main/java/com/velocitypowered/proxy/protocol/ProtocolUtils.java b/proxy/src/main/java/com/velocitypowered/proxy/protocol/ProtocolUtils.java index 053dcd65f..797a58223 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/protocol/ProtocolUtils.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/protocol/ProtocolUtils.java @@ -104,6 +104,7 @@ public enum ProtocolUtils { .build(); public static final int DEFAULT_MAX_STRING_SIZE = 65536; // 64KiB + private static final int MAXIMUM_VARINT_SIZE = 5; private static final BinaryTagType[] BINARY_TAG_TYPES = new BinaryTagType[] { BinaryTagTypes.END, BinaryTagTypes.BYTE, BinaryTagTypes.SHORT, BinaryTagTypes.INT, BinaryTagTypes.LONG, BinaryTagTypes.FLOAT, BinaryTagTypes.DOUBLE, @@ -111,13 +112,18 @@ public enum ProtocolUtils { BinaryTagTypes.COMPOUND, BinaryTagTypes.INT_ARRAY, BinaryTagTypes.LONG_ARRAY}; private static final QuietDecoderException BAD_VARINT_CACHED = new QuietDecoderException("Bad VarInt decoded"); - private static final int[] VARINT_EXACT_BYTE_LENGTHS = new int[33]; + private static final int[] VAR_INT_LENGTHS = new int[65]; static { for (int i = 0; i <= 32; ++i) { - VARINT_EXACT_BYTE_LENGTHS[i] = (int) Math.ceil((31d - (i - 1)) / 7d); + VAR_INT_LENGTHS[i] = (int) Math.ceil((31d - (i - 1)) / 7d); } - VARINT_EXACT_BYTE_LENGTHS[32] = 1; // Special case for the number 0. + VAR_INT_LENGTHS[32] = 1; // Special case for the number 0. + } + + private static DecoderException badVarint() { + return MinecraftDecoder.DEBUG ? new CorruptedFrameException("Bad VarInt decoded") + : BAD_VARINT_CACHED; } /** @@ -127,56 +133,29 @@ public enum ProtocolUtils { * @return the decoded VarInt */ public static int readVarInt(ByteBuf buf) { - int read = readVarIntSafely(buf); - if (read == Integer.MIN_VALUE) { - throw MinecraftDecoder.DEBUG ? new CorruptedFrameException("Bad VarInt decoded") - : BAD_VARINT_CACHED; + int readable = buf.readableBytes(); + if (readable == 0) { + // special case for empty buffer + throw badVarint(); } - return read; - } - /** - * Reads a Minecraft-style VarInt from the specified {@code buf}. The difference between this - * method and {@link #readVarInt(ByteBuf)} is that this function returns a sentinel value if the - * varint is invalid. - * - * @param buf the buffer to read from - * @return the decoded VarInt, or {@code Integer.MIN_VALUE} if the varint is invalid - */ - public static int readVarIntSafely(ByteBuf buf) { - int i = 0; - int maxRead = Math.min(5, buf.readableBytes()); - for (int j = 0; j < maxRead; j++) { - int k = buf.readByte(); + // we can read at least one byte, and this should be a common case + int k = buf.readByte(); + if ((k & 0x80) != 128) { + return k; + } + + // in case decoding one byte was not enough, use a loop to decode up to the next 4 bytes + int maxRead = Math.min(MAXIMUM_VARINT_SIZE, readable); + int i = k & 0x7F; + for (int j = 1; j < maxRead; j++) { + k = buf.readByte(); i |= (k & 0x7F) << j * 7; if ((k & 0x80) != 128) { return i; } } - return Integer.MIN_VALUE; - } - - /** - * Reads a Minecraft-style VarInt from the specified {@code buf}. The difference between this - * method and {@link #readVarInt(ByteBuf)} is that this function returns a sentinel value if the - * varint is invalid. - * - * @param buf the buffer to read from - * @return the decoded VarInt - * @throws DecoderException if the varint is invalid - */ - public static int readVarIntSafelyOrThrow(ByteBuf buf) { - int i = 0; - int maxRead = Math.min(5, buf.readableBytes()); - for (int j = 0; j < maxRead; j++) { - int k = buf.readByte(); - i |= (k & 0x7F) << j * 7; - if ((k & 0x80) != 128) { - return i; - } - } - throw MinecraftDecoder.DEBUG ? new CorruptedFrameException("Bad VarInt decoded") - : BAD_VARINT_CACHED; + throw badVarint(); } /** @@ -186,7 +165,7 @@ public enum ProtocolUtils { * @return the byte size of {@code value} if encoded as a VarInt */ public static int varIntBytes(int value) { - return VARINT_EXACT_BYTE_LENGTHS[Integer.numberOfLeadingZeros(value)]; + return VAR_INT_LENGTHS[Integer.numberOfLeadingZeros(value)]; } /** @@ -210,6 +189,8 @@ public enum ProtocolUtils { private static void writeVarIntFull(ByteBuf buf, int value) { // See https://steinborn.me/posts/performance/how-fast-can-you-write-a-varint/ + + // This essentially is an unrolled version of the "traditional" VarInt encoding. if ((value & (0xFFFFFFFF << 7)) == 0) { buf.writeByte(value); } else if ((value & (0xFFFFFFFF << 14)) == 0) { diff --git a/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftVarintFrameDecoder.java b/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftVarintFrameDecoder.java index 94baa2ffe..7bf7563ea 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftVarintFrameDecoder.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftVarintFrameDecoder.java @@ -17,7 +17,8 @@ package com.velocitypowered.proxy.protocol.netty; -import com.velocitypowered.proxy.protocol.netty.VarintByteDecoder.DecodeResult; +import static io.netty.util.ByteProcessor.FIND_NON_NUL; + import com.velocitypowered.proxy.util.except.QuietDecoderException; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; @@ -29,53 +30,114 @@ import java.util.List; */ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder { - private static final QuietDecoderException BAD_LENGTH_CACHED = + private static final QuietDecoderException BAD_PACKET_LENGTH = new QuietDecoderException("Bad packet length"); - private static final QuietDecoderException VARINT_BIG_CACHED = + private static final QuietDecoderException VARINT_TOO_BIG = new QuietDecoderException("VarInt too big"); @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) + throws Exception { if (!ctx.channel().isActive()) { in.clear(); return; } - final VarintByteDecoder reader = new VarintByteDecoder(); - - int varintEnd = in.forEachByte(reader); - if (varintEnd == -1) { - // We tried to go beyond the end of the buffer. This is probably a good sign that the - // buffer was too short to hold a proper varint. - if (reader.getResult() == DecodeResult.RUN_OF_ZEROES) { - // Special case where the entire packet is just a run of zeroes. We ignore them all. - in.clear(); - } + // skip any runs of 0x00 we might find + int packetStart = in.forEachByte(FIND_NON_NUL); + if (packetStart == -1) { return; } + in.readerIndex(packetStart); - if (reader.getResult() == DecodeResult.RUN_OF_ZEROES) { - // this will return to the point where the next varint starts - in.readerIndex(varintEnd); - } else if (reader.getResult() == DecodeResult.SUCCESS) { - int readVarint = reader.getReadVarint(); - int bytesRead = reader.getBytesRead(); - if (readVarint < 0) { - in.clear(); - throw BAD_LENGTH_CACHED; - } else if (readVarint == 0) { - // skip over the empty packet(s) and ignore it - in.readerIndex(varintEnd + 1); + // try to read the length of the packet + in.markReaderIndex(); + int preIndex = in.readerIndex(); + int length = readRawVarInt21(in); + if (preIndex == in.readerIndex()) { + return; + } + if (length < 0) { + throw BAD_PACKET_LENGTH; + } + + // note that zero-length packets are ignored + if (length > 0) { + if (in.readableBytes() < length) { + in.resetReaderIndex(); } else { - int minimumRead = bytesRead + readVarint; - if (in.isReadable(minimumRead)) { - out.add(in.retainedSlice(varintEnd + 1, readVarint)); - in.skipBytes(minimumRead); - } + out.add(in.readRetainedSlice(length)); } - } else if (reader.getResult() == DecodeResult.TOO_BIG) { - in.clear(); - throw VARINT_BIG_CACHED; } } + + /** + * Reads a VarInt from the buffer of up to 21 bits in size. + * + * @param buffer the buffer to read from + * @return the VarInt decoded, {@code 0} if no varint could be read + * @throws QuietDecoderException if the VarInt is too big to be decoded + */ + private static int readRawVarInt21(ByteBuf buffer) { + if (buffer.readableBytes() < 4) { + // we don't have enough that we can read a potentially full varint, so fall back to + // the slow path. + return readRawVarintSmallBuf(buffer); + } + int wholeOrMore = buffer.getIntLE(buffer.readerIndex()); + + // take the last three bytes and check if any of them have the high bit set + int atStop = ~wholeOrMore & 0x808080; + if (atStop == 0) { + // all bytes have the high bit set, so the varint we are trying to decode is too wide + throw VARINT_TOO_BIG; + } + + int bitsToKeep = Integer.numberOfTrailingZeros(atStop) + 1; + buffer.skipBytes(bitsToKeep >> 3); + + // remove all bits we don't need to keep, a trick from + // https://github.com/netty/netty/pull/14050#issuecomment-2107750734: + // + // > The idea is that thisVarintMask has 0s above the first one of firstOneOnStop, and 1s at + // > and below it. For example if firstOneOnStop is 0x800080 (where the last 0x80 is the only + // > one that matters), then thisVarintMask is 0xFF. + // + // this is also documented in Hacker's Delight, section 2-1 "Manipulating Rightmost Bits" + int preservedBytes = wholeOrMore & (atStop ^ (atStop - 1)); + + // merge together using this trick: https://github.com/netty/netty/pull/14050#discussion_r1597896639 + preservedBytes = (preservedBytes & 0x007F007F) | ((preservedBytes & 0x00007F00) >> 1); + preservedBytes = (preservedBytes & 0x00003FFF) | ((preservedBytes & 0x3FFF0000) >> 2); + return preservedBytes; + } + + private static int readRawVarintSmallBuf(ByteBuf buffer) { + if (!buffer.isReadable()) { + return 0; + } + buffer.markReaderIndex(); + + byte tmp = buffer.readByte(); + if (tmp >= 0) { + return tmp; + } + int result = tmp & 0x7F; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + if ((tmp = buffer.readByte()) >= 0) { + return result | tmp << 7; + } + result |= (tmp & 0x7F) << 7; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + if ((tmp = buffer.readByte()) >= 0) { + return result | tmp << 14; + } + return result | (tmp & 0x7F) << 14; + } } diff --git a/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/VarintByteDecoder.java b/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/VarintByteDecoder.java deleted file mode 100644 index 06cec7350..000000000 --- a/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/VarintByteDecoder.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (C) 2020-2021 Velocity Contributors - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package com.velocitypowered.proxy.protocol.netty; - -import io.netty.util.ByteProcessor; - -class VarintByteDecoder implements ByteProcessor { - - private int readVarint; - private int bytesRead; - private DecodeResult result = DecodeResult.TOO_SHORT; - - @Override - public boolean process(byte k) { - if (k == 0 && bytesRead == 0) { - // tentatively say it's invalid, but there's a possibility of redemption - result = DecodeResult.RUN_OF_ZEROES; - return true; - } - if (result == DecodeResult.RUN_OF_ZEROES) { - return false; - } - readVarint |= (k & 0x7F) << bytesRead++ * 7; - if (bytesRead > 3) { - result = DecodeResult.TOO_BIG; - return false; - } - if ((k & 0x80) != 128) { - result = DecodeResult.SUCCESS; - return false; - } - return true; - } - - public int getReadVarint() { - return readVarint; - } - - public int getBytesRead() { - return bytesRead; - } - - public DecodeResult getResult() { - return result; - } - - public enum DecodeResult { - SUCCESS, - TOO_SHORT, - TOO_BIG, - RUN_OF_ZEROES - } -} diff --git a/proxy/src/main/java/com/velocitypowered/proxy/protocol/packet/LoginPluginMessagePacket.java b/proxy/src/main/java/com/velocitypowered/proxy/protocol/packet/LoginPluginMessagePacket.java index 213492cbf..682785eb2 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/protocol/packet/LoginPluginMessagePacket.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/protocol/packet/LoginPluginMessagePacket.java @@ -63,7 +63,7 @@ public class LoginPluginMessagePacket extends DeferredByteBufHolder implements M @Override public void decode(ByteBuf buf, ProtocolUtils.Direction direction, ProtocolVersion version) { - this.id = ProtocolUtils.readVarIntSafelyOrThrow(buf); + this.id = ProtocolUtils.readVarInt(buf); this.channel = ProtocolUtils.readString(buf); if (buf.isReadable()) { this.replace(buf.readRetainedSlice(buf.readableBytes())); diff --git a/proxy/src/test/java/com/velocitypowered/proxy/protocol/ProtocolUtilsTest.java b/proxy/src/test/java/com/velocitypowered/proxy/protocol/ProtocolUtilsTest.java index f476de67c..1ed59cd83 100644 --- a/proxy/src/test/java/com/velocitypowered/proxy/protocol/ProtocolUtilsTest.java +++ b/proxy/src/test/java/com/velocitypowered/proxy/protocol/ProtocolUtilsTest.java @@ -70,7 +70,7 @@ public class ProtocolUtilsTest { private void writeReadTestOld(ByteBuf buf, int test) { buf.clear(); writeVarIntOld(buf, test); - assertEquals(test, ProtocolUtils.readVarIntSafely(buf)); + assertEquals(test, ProtocolUtils.readVarInt(buf)); } @Test @@ -103,7 +103,7 @@ public class ProtocolUtilsTest { "Encoding of " + i + " was invalid"); assertEquals(i, oldReadVarIntSafely(varintNew)); - assertEquals(i, ProtocolUtils.readVarIntSafely(varintOld)); + assertEquals(i, ProtocolUtils.readVarInt(varintOld)); varintNew.clear(); varintOld.clear();