diff --git a/common/src/main/java/com/viaversion/viaversion/protocol/RedirectProtocolVersion.java b/common/src/main/java/com/viaversion/viaversion/protocol/RedirectProtocolVersion.java index dc28c6b48..8bb41efa6 100644 --- a/common/src/main/java/com/viaversion/viaversion/protocol/RedirectProtocolVersion.java +++ b/common/src/main/java/com/viaversion/viaversion/protocol/RedirectProtocolVersion.java @@ -24,9 +24,10 @@ import java.util.Comparator; import org.checkerframework.checker.nullness.qual.Nullable; /** - * A {@link ProtocolVersion} with the version type {@link VersionType#SPECIAL} that compares equal to the given - * origin version. The origin version will also be used in {@link com.viaversion.viaversion.protocols.base.InitialBaseProtocol} - * to determine the correct base protocol. + * Intended API class for protocol versions with the version type {@link VersionType#SPECIAL}. + *

+ * Compares equal to the given origin version and allows base protocol determination via {@link #getBaseProtocolVersion()} + * which can be null for special cases where there is no base protocol. */ public class RedirectProtocolVersion extends ProtocolVersion { @@ -56,4 +57,11 @@ public class RedirectProtocolVersion extends ProtocolVersion { public ProtocolVersion getOrigin() { return origin; } + + /** + * @return the protocol version used to determine the base protocol, null in case there is no base protocol. + */ + public @Nullable ProtocolVersion getBaseProtocolVersion() { + return origin; + } } diff --git a/common/src/main/java/com/viaversion/viaversion/protocols/base/InitialBaseProtocol.java b/common/src/main/java/com/viaversion/viaversion/protocols/base/InitialBaseProtocol.java index 4bfcace2a..648da9406 100644 --- a/common/src/main/java/com/viaversion/viaversion/protocols/base/InitialBaseProtocol.java +++ b/common/src/main/java/com/viaversion/viaversion/protocols/base/InitialBaseProtocol.java @@ -96,17 +96,22 @@ public class InitialBaseProtocol extends AbstractProtocol protocolPath = protocolManager.getProtocolPath(info.protocolVersion(), serverProtocol); - // Add Base Protocol ProtocolPipeline pipeline = info.getPipeline(); - // Special versions might compare equal to normal versions and would break this getter, - // platforms either need to use the RedirectProtocolVersion API or add the base protocols manually + // Save manually added protocols for later + List alreadyAdded = new ArrayList<>(pipeline.pipes()); + + // Special versions might compare equal to normal versions and would the normal lookup, + // platforms can use the RedirectProtocolVersion API or need to manually handle their base protocols. + ProtocolVersion baseProtocolVersion = null; if (serverProtocol.getVersionType() != VersionType.SPECIAL) { - for (final Protocol protocol : protocolManager.getBaseProtocols(serverProtocol)) { - pipeline.add(protocol); - } + baseProtocolVersion = serverProtocol; } else if (serverProtocol instanceof RedirectProtocolVersion version) { - for (final Protocol protocol : protocolManager.getBaseProtocols(version.getOrigin())) { + baseProtocolVersion = version.getBaseProtocolVersion(); + } + if (baseProtocolVersion != null) { + // Add base protocols + for (final Protocol protocol : protocolManager.getBaseProtocols(baseProtocolVersion)) { pipeline.add(protocol); } } @@ -131,7 +136,7 @@ public class InitialBaseProtocol extends AbstractProtocol protocols = new ArrayList<>(pipeline.pipes()); - protocols.remove(this); + protocols.removeAll(alreadyAdded); // Skip all manually added protocols to prevent double handling wrapper.apply(Direction.SERVERBOUND, State.HANDSHAKE, protocols); } catch (CancelException e) { wrapper.cancel();