diff --git a/api/src/main/java/com/viaversion/viaversion/api/connection/ProtocolInfo.java b/api/src/main/java/com/viaversion/viaversion/api/connection/ProtocolInfo.java index 5aa83a46c..1d25712b5 100644 --- a/api/src/main/java/com/viaversion/viaversion/api/connection/ProtocolInfo.java +++ b/api/src/main/java/com/viaversion/viaversion/api/connection/ProtocolInfo.java @@ -23,6 +23,7 @@ package com.viaversion.viaversion.api.connection; import com.viaversion.viaversion.api.protocol.ProtocolPipeline; +import com.viaversion.viaversion.api.protocol.packet.Direction; import com.viaversion.viaversion.api.protocol.packet.State; import java.util.UUID; import org.checkerframework.checker.nullness.qual.Nullable; @@ -33,10 +34,63 @@ public interface ProtocolInfo { * Returns the protocol state the user is currently in. * * @return protocol state + * @deprecated server and client can be in different states, use {@link #getClientState()} or {@link #getServerState()} */ - State getState(); + @Deprecated/*(forRemoval = true)*/ + default State getState() { + return this.getServerState(); + } - void setState(State state); + /** + * Returns the protocol state the client is currently in. + * + * @return the client protocol state + */ + State getClientState(); + + /** + * Returns the protocol state the server is currently in. + * + * @return the server protocol state + */ + State getServerState(); + + /** + * Returns the protocol state for the given direction. + * + * @param direction protocol direction + * @return state for the given direction + */ + default State getState(final Direction direction) { + // Return the state the packet is coming from + return direction == Direction.CLIENTBOUND ? this.getServerState() : this.getClientState(); + } + + /** + * Sets both client and server state. + * + * @param state the new protocol state + * @see #setClientState(State) + * @see #setServerState(State) + */ + default void setState(final State state) { + this.setClientState(state); + this.setServerState(state); + } + + /** + * Sets the client protocol state. + * + * @param clientState the new client protocol state + */ + void setClientState(State clientState); + + /** + * Sets the server protocol state. + * + * @param serverState the new server protocol state + */ + void setServerState(State serverState); /** * Returns the user's protocol version, or -1 if not set. diff --git a/api/src/main/java/com/viaversion/viaversion/api/protocol/AbstractProtocol.java b/api/src/main/java/com/viaversion/viaversion/api/protocol/AbstractProtocol.java index 6eefd2aea..726768b6a 100644 --- a/api/src/main/java/com/viaversion/viaversion/api/protocol/AbstractProtocol.java +++ b/api/src/main/java/com/viaversion/viaversion/api/protocol/AbstractProtocol.java @@ -96,19 +96,7 @@ public abstract class AbstractProtocol wrapper.user().getProtocolInfo().setState(State.CONFIGURATION)); - } - - final ServerboundPacketType finishConfigurationPacket = finishConfigurationPacket(); - if (finishConfigurationPacket != null) { - final int id = finishConfigurationPacket.getId(); - registerServerbound(State.CONFIGURATION, id, id, wrapper -> wrapper.user().getProtocolInfo().setState(State.PLAY)); - } + registerConfigurationChangeHandlers(); // Register the rest of the ids with no handlers if necessary if (unmappedClientboundPacketType != null && mappedClientboundPacketType != null @@ -131,6 +119,33 @@ public abstract class AbstractProtocol void registerPacketIdChanges( Map> unmappedPacketTypes, Map> mappedPacketTypes, @@ -228,7 +243,18 @@ public abstract class AbstractProtocol> packetTypes = packetTypesProvider.unmappedClientboundPacketTypes(); + final PacketTypeMap packetTypeMap = packetTypes.get(State.PLAY); + return packetTypeMap != null ? packetTypeMap.typeByName("START_CONFIGURATION") : null; + } + + protected @Nullable ServerboundPacketType serverboundFinishConfigurationPacket() { + // To be overridden + return null; + } + + protected @Nullable ClientboundPacketType clientboundFinishConfigurationPacket() { // To be overridden return null; } @@ -429,6 +455,14 @@ public abstract class AbstractProtocol wrapper.user().getProtocolInfo().setClientState(state); + } + + protected PacketHandler setServerStateHandler(final State state) { + return wrapper -> wrapper.user().getProtocolInfo().setClientState(state); + } + @Override public PacketTypesProvider getPacketTypesProvider() { return packetTypesProvider; diff --git a/api/src/main/java/com/viaversion/viaversion/api/protocol/ProtocolPipeline.java b/api/src/main/java/com/viaversion/viaversion/api/protocol/ProtocolPipeline.java index dd588de96..bd43a9fcc 100644 --- a/api/src/main/java/com/viaversion/viaversion/api/protocol/ProtocolPipeline.java +++ b/api/src/main/java/com/viaversion/viaversion/api/protocol/ProtocolPipeline.java @@ -73,6 +73,8 @@ public interface ProtocolPipeline extends SimpleProtocol { */ List pipes(); + List reversedPipes(); + /** * Returns whether this pipe has protocols that are not base protocols, as given by {@link Protocol#isBaseProtocol()}. * diff --git a/api/src/main/java/com/viaversion/viaversion/api/protocol/packet/State.java b/api/src/main/java/com/viaversion/viaversion/api/protocol/packet/State.java index 1122c764b..9b5ee0329 100644 --- a/api/src/main/java/com/viaversion/viaversion/api/protocol/packet/State.java +++ b/api/src/main/java/com/viaversion/viaversion/api/protocol/packet/State.java @@ -27,6 +27,6 @@ public enum State { HANDSHAKE, STATUS, LOGIN, - PLAY, - CONFIGURATION + CONFIGURATION, + PLAY } diff --git a/common/src/main/java/com/viaversion/viaversion/connection/ProtocolInfoImpl.java b/common/src/main/java/com/viaversion/viaversion/connection/ProtocolInfoImpl.java index 34922fa4a..508ac6387 100644 --- a/common/src/main/java/com/viaversion/viaversion/connection/ProtocolInfoImpl.java +++ b/common/src/main/java/com/viaversion/viaversion/connection/ProtocolInfoImpl.java @@ -17,6 +17,7 @@ */ package com.viaversion.viaversion.connection; +import com.viaversion.viaversion.api.Via; import com.viaversion.viaversion.api.connection.ProtocolInfo; import com.viaversion.viaversion.api.connection.UserConnection; import com.viaversion.viaversion.api.protocol.ProtocolPipeline; @@ -26,25 +27,42 @@ import java.util.UUID; public class ProtocolInfoImpl implements ProtocolInfo { private final UserConnection connection; - private State state = State.HANDSHAKE; + private State clientState = State.HANDSHAKE; + private State serverState = State.HANDSHAKE; private int protocolVersion = -1; private int serverProtocolVersion = -1; private String username; private UUID uuid; private ProtocolPipeline pipeline; - public ProtocolInfoImpl(UserConnection connection) { + public ProtocolInfoImpl(final UserConnection connection) { this.connection = connection; } @Override - public State getState() { - return state; + public State getClientState() { + return clientState; } @Override - public void setState(State state) { - this.state = state; + public void setClientState(final State clientState) { + if (Via.getManager().debugHandler().enabled()) { + Via.getPlatform().getLogger().info("Client state changed from " + this.clientState + " to " + clientState + " for " + connection.getProtocolInfo().getUuid()); + } + this.clientState = clientState; + } + + @Override + public State getServerState() { + return serverState; + } + + @Override + public void setServerState(final State serverState) { + if (Via.getManager().debugHandler().enabled()) { + Via.getPlatform().getLogger().info("Server state changed from " + this.serverState + " to " + serverState + " for " + connection.getProtocolInfo().getUuid()); + } + this.serverState = serverState; } @Override @@ -108,7 +126,8 @@ public class ProtocolInfoImpl implements ProtocolInfo { @Override public String toString() { return "ProtocolInfo{" + - "state=" + state + + "clientState=" + clientState + + ", serverState=" + serverState + ", protocolVersion=" + protocolVersion + ", serverProtocolVersion=" + serverProtocolVersion + ", username='" + username + '\'' + diff --git a/common/src/main/java/com/viaversion/viaversion/connection/UserConnectionImpl.java b/common/src/main/java/com/viaversion/viaversion/connection/UserConnectionImpl.java index 444d7feed..7d3e9d196 100644 --- a/common/src/main/java/com/viaversion/viaversion/connection/UserConnectionImpl.java +++ b/common/src/main/java/com/viaversion/viaversion/connection/UserConnectionImpl.java @@ -27,6 +27,7 @@ import com.viaversion.viaversion.api.protocol.Protocol; import com.viaversion.viaversion.api.protocol.packet.Direction; import com.viaversion.viaversion.api.protocol.packet.PacketTracker; import com.viaversion.viaversion.api.protocol.packet.PacketWrapper; +import com.viaversion.viaversion.api.protocol.packet.State; import com.viaversion.viaversion.api.type.Type; import com.viaversion.viaversion.exception.CancelException; import com.viaversion.viaversion.protocol.packet.PacketWrapperImpl; @@ -324,8 +325,9 @@ public class UserConnectionImpl implements UserConnection { } PacketWrapper wrapper = new PacketWrapperImpl(id, buf, this); + State state = protocolInfo.getState(direction); try { - protocolInfo.getPipeline().transform(direction, protocolInfo.getState(), wrapper); + protocolInfo.getPipeline().transform(direction, state, wrapper); } catch (CancelException ex) { throw cancelSupplier.apply(ex); } diff --git a/common/src/main/java/com/viaversion/viaversion/protocol/ProtocolPipelineImpl.java b/common/src/main/java/com/viaversion/viaversion/protocol/ProtocolPipelineImpl.java index c3082c7ee..de4884920 100644 --- a/common/src/main/java/com/viaversion/viaversion/protocol/ProtocolPipelineImpl.java +++ b/common/src/main/java/com/viaversion/viaversion/protocol/ProtocolPipelineImpl.java @@ -30,6 +30,7 @@ import com.viaversion.viaversion.api.protocol.packet.PacketWrapper; import com.viaversion.viaversion.api.protocol.packet.State; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -43,6 +44,7 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot * Protocol list ordered from client to server transforation with the base protocols at the end. */ private List protocolList; + private List reversedProtocolList; private Set> protocolSet; public ProtocolPipelineImpl(UserConnection userConnection) { @@ -54,12 +56,14 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot @Override protected void registerPackets() { protocolList = new CopyOnWriteArrayList<>(); + reversedProtocolList = new CopyOnWriteArrayList<>(); // Create a backing set for faster contains calls with larger pipes protocolSet = Sets.newSetFromMap(new ConcurrentHashMap<>()); // This is a pipeline so we register basic pipes - Protocol baseProtocol = Via.getManager().getProtocolManager().getBaseProtocol(); + final Protocol baseProtocol = Via.getManager().getProtocolManager().getBaseProtocol(); protocolList.add(baseProtocol); + reversedProtocolList.add(baseProtocol); protocolSet.add(baseProtocol.getClass()); } @@ -71,12 +75,22 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot @Override public void add(Protocol protocol) { protocolList.add(protocol); + protocolSet.add(protocol.getClass()); protocol.init(userConnection); if (!protocol.isBaseProtocol()) { - moveBaseProtocolsToTail(); + moveBaseProtocolsToTail(protocolList); } + + setReversedProtocolList(); + } + + private void setReversedProtocolList() { + final List reversedProtocolList = new ArrayList<>(protocolList); + Collections.reverse(this.reversedProtocolList); + moveBaseProtocolsToTail(reversedProtocolList); + this.reversedProtocolList = new CopyOnWriteArrayList<>(reversedProtocolList); } @Override @@ -87,25 +101,26 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot this.protocolSet.add(protocol.getClass()); } - moveBaseProtocolsToTail(); + moveBaseProtocolsToTail(protocolList); + setReversedProtocolList(); } - private void moveBaseProtocolsToTail() { - // Move base Protocols to the end, so the login packets can be modified by other protocols - List baseProtocols = null; - for (Protocol protocol : protocolList) { + private List filterBaseProtocols(final List protocols) { + final List baseProtocols = new ArrayList<>(); + for (final Protocol protocol : protocolList) { if (protocol.isBaseProtocol()) { - if (baseProtocols == null) { - baseProtocols = new ArrayList<>(); - } - baseProtocols.add(protocol); } } + return baseProtocols; + } - if (baseProtocols != null) { - protocolList.removeAll(baseProtocols); - protocolList.addAll(baseProtocols); + private void moveBaseProtocolsToTail(final List protocols) { + // Move base Protocols to the end, so the login packets can be modified by other protocols + final List baseProtocols = filterBaseProtocols(protocols); + if (!baseProtocols.isEmpty()) { + protocols.removeAll(baseProtocols); + protocols.addAll(baseProtocols); } } @@ -119,7 +134,7 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot } // Apply protocols - packetWrapper.apply(direction, state, 0, protocolList, direction == Direction.CLIENTBOUND); + packetWrapper.apply(direction, state, 0, protocolListFor(direction), true); super.transform(direction, state, packetWrapper); if (debugHandler.enabled() && debugHandler.logPostPacketTransform() && debugHandler.shouldLog(packetWrapper, direction)) { @@ -127,6 +142,10 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot } } + private List protocolListFor(final Direction direction) { + return direction == Direction.CLIENTBOUND ? reversedProtocolList : protocolList; + } + private void logPacket(Direction direction, State state, PacketWrapper packetWrapper, int originalID) { // Debug packet int clientProtocol = userConnection.getProtocolInfo().getProtocolVersion(); @@ -169,6 +188,11 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot return protocolList; } + @Override + public List reversedPipes() { + return reversedProtocolList; + } + @Override public boolean hasNonBaseProtocols() { for (Protocol protocol : protocolList) { diff --git a/common/src/main/java/com/viaversion/viaversion/protocol/packet/PacketWrapperImpl.java b/common/src/main/java/com/viaversion/viaversion/protocol/packet/PacketWrapperImpl.java index f6c3c504e..7af754c8f 100644 --- a/common/src/main/java/com/viaversion/viaversion/protocol/packet/PacketWrapperImpl.java +++ b/common/src/main/java/com/viaversion/viaversion/protocol/packet/PacketWrapperImpl.java @@ -19,6 +19,7 @@ package com.viaversion.viaversion.protocol.packet; import com.google.common.base.Preconditions; import com.viaversion.viaversion.api.Via; +import com.viaversion.viaversion.api.connection.ProtocolInfo; import com.viaversion.viaversion.api.connection.UserConnection; import com.viaversion.viaversion.api.protocol.Protocol; import com.viaversion.viaversion.api.protocol.packet.Direction; @@ -309,8 +310,10 @@ public class PacketWrapperImpl implements PacketWrapper { */ private ByteBuf constructPacket(Class packetProtocol, boolean skipCurrentPipeline, Direction direction) throws Exception { // Apply current pipeline - for outgoing protocol, the collection will be reversed in the apply method - Protocol[] protocols = user().getProtocolInfo().getPipeline().pipes().toArray(PROTOCOL_ARRAY); - boolean reverse = direction == Direction.CLIENTBOUND; + final ProtocolInfo protocolInfo = user().getProtocolInfo(); + final boolean reverse = direction == Direction.CLIENTBOUND; + final List pipes = reverse ? protocolInfo.getPipeline().reversedPipes() : protocolInfo.getPipeline().pipes(); + final Protocol[] protocols = pipes.toArray(PROTOCOL_ARRAY); int index = -1; for (int i = 0; i < protocols.length; i++) { if (protocols[i].getClass() == packetProtocol) { @@ -332,8 +335,8 @@ public class PacketWrapperImpl implements PacketWrapper { resetReader(); // Apply other protocols - apply(direction, user().getProtocolInfo().getState(), index, protocols, reverse); - ByteBuf output = inputBuffer == null ? user().getChannel().alloc().buffer() : inputBuffer.alloc().buffer(); + apply(direction, protocolInfo.getState(direction), index, protocols, true); + final ByteBuf output = inputBuffer == null ? user().getChannel().alloc().buffer() : inputBuffer.alloc().buffer(); try { writeToBuffer(output); return output.retain(); @@ -402,15 +405,22 @@ public class PacketWrapperImpl implements PacketWrapper { private PacketWrapperImpl apply(Direction direction, State state, int index, Protocol[] pipeline, boolean reverse) throws Exception { // Reset the reader after every transformation for the packetWrapper, so it can be recycled across packets + State updatedState = state; // The state might change while transforming, so we need to check for that if (reverse) { for (int i = index; i >= 0; i--) { - pipeline[i].transform(direction, state, this); + pipeline[i].transform(direction, updatedState, this); resetReader(); + if (this.packetType != null) { + updatedState = this.packetType.state(); + } } } else { for (int i = index; i < pipeline.length; i++) { - pipeline[i].transform(direction, state, this); + pipeline[i].transform(direction, updatedState, this); resetReader(); + if (this.packetType != null) { + updatedState = this.packetType.state(); + } } } return this; diff --git a/common/src/main/java/com/viaversion/viaversion/protocols/base/BaseProtocol1_7.java b/common/src/main/java/com/viaversion/viaversion/protocols/base/BaseProtocol1_7.java index 979756085..a689c89e9 100644 --- a/common/src/main/java/com/viaversion/viaversion/protocols/base/BaseProtocol1_7.java +++ b/common/src/main/java/com/viaversion/viaversion/protocols/base/BaseProtocol1_7.java @@ -33,12 +33,10 @@ import com.viaversion.viaversion.api.protocol.version.VersionProvider; import com.viaversion.viaversion.api.type.Type; import com.viaversion.viaversion.protocol.ProtocolManagerImpl; import com.viaversion.viaversion.protocol.ServerProtocolVersionSingleton; -import com.viaversion.viaversion.protocols.protocol1_20_2to1_20.packet.ServerboundConfigurationPackets1_20_2; import com.viaversion.viaversion.protocols.protocol1_9to1_8.Protocol1_9To1_8; import com.viaversion.viaversion.util.ChatColorUtil; import com.viaversion.viaversion.util.GsonUtil; import io.netty.channel.ChannelFuture; - import java.util.List; import java.util.UUID; import java.util.logging.Level; @@ -121,8 +119,9 @@ public class BaseProtocol1_7 extends AbstractProtocol { // Login Success Packet registerClientbound(ClientboundLoginPackets.GAME_PROFILE, wrapper -> { ProtocolInfo info = wrapper.user().getProtocolInfo(); - if (info.getProtocolVersion() < ProtocolVersion.v1_20_2.getVersion()) { - info.setState(State.PLAY); + info.setServerState(State.PLAY); + if (info.getProtocolVersion() < ProtocolVersion.v1_20_2.getVersion()) { // 1.20.2+ clients will send a login ack first + info.setClientState(State.PLAY); } UUID uuid = passthroughLoginUUID(wrapper); @@ -165,7 +164,10 @@ public class BaseProtocol1_7 extends AbstractProtocol { } }); - registerServerbound(ServerboundLoginPackets.LOGIN_ACKNOWLEDGED, wrapper -> wrapper.user().getProtocolInfo().setState(State.CONFIGURATION)); + registerServerbound(ServerboundLoginPackets.LOGIN_ACKNOWLEDGED, wrapper -> { + final ProtocolInfo info = wrapper.user().getProtocolInfo(); + info.setClientState(State.CONFIGURATION); + }); } @Override diff --git a/common/src/main/java/com/viaversion/viaversion/protocols/protocol1_20_2to1_20/Protocol1_20_2To1_20.java b/common/src/main/java/com/viaversion/viaversion/protocols/protocol1_20_2to1_20/Protocol1_20_2To1_20.java index 72b24a405..c5d7c04e2 100644 --- a/common/src/main/java/com/viaversion/viaversion/protocols/protocol1_20_2to1_20/Protocol1_20_2To1_20.java +++ b/common/src/main/java/com/viaversion/viaversion/protocols/protocol1_20_2to1_20/Protocol1_20_2To1_20.java @@ -20,6 +20,7 @@ package com.viaversion.viaversion.protocols.protocol1_20_2to1_20; import com.github.steveice10.opennbt.tag.builtin.CompoundTag; import com.google.gson.JsonElement; import com.viaversion.viaversion.api.Via; +import com.viaversion.viaversion.api.connection.ProtocolInfo; import com.viaversion.viaversion.api.connection.UserConnection; import com.viaversion.viaversion.api.data.MappingData; import com.viaversion.viaversion.api.data.MappingDataBase; @@ -114,10 +115,6 @@ public final class Protocol1_20_2To1_20 extends AbstractProtocol { wrapper.user().get(ConfigurationState.class).setBridgePhase(BridgePhase.PROFILE_SENT); - - // Set the state according to what the server expects. All packets between now and when the client - // switches to PLAY as well will be discarded after being dealt with. - wrapper.user().getProtocolInfo().setState(State.PLAY); }); registerServerbound(State.LOGIN, ServerboundLoginPackets.LOGIN_ACKNOWLEDGED.getId(), -1, wrapper -> { @@ -131,6 +128,8 @@ public final class Protocol1_20_2To1_20 extends AbstractProtocol { wrapper.cancel(); + wrapper.user().getProtocolInfo().setClientState(State.PLAY); + final ConfigurationState configurationState = wrapper.user().get(ConfigurationState.class); configurationState.setBridgePhase(BridgePhase.NONE); configurationState.sendQueuedPackets(wrapper.user()); @@ -153,6 +152,7 @@ public final class Protocol1_20_2To1_20 extends AbstractProtocol