Shadowsocks2022 Client Implementation Improvements (#2770)

* Shadowsocks2022 Client Implementation Improvements

1. Added UDP Replay Detection
2. Added UDP Processor State Cache
3. Added More Detailed Output for Time Difference Error
4. Replaced Mutex with RWMutex for reduced lock contention
5. Added per server session tracking of decryption cache and anti-replay window
6. Adjust server session track time
7. Increase per session buffer to 128
8. Fix client crash when EIH is not enabled
9. Fix client log contains non-human-friendly content
10.Remove mark and remove for trackedSessions
11. Fixed packet size uint16 overflow issue
This commit is contained in:
Xiaokang Wang (Shelikhoo) 2023-11-24 00:40:07 +00:00 committed by GitHub
parent 62428d8011
commit b6da3e86a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 254 additions and 44 deletions

View File

@ -91,6 +91,7 @@ type AuthenticationReader struct {
transferType protocol.TransferType transferType protocol.TransferType
padding PaddingLengthGenerator padding PaddingLengthGenerator
size uint16 size uint16
sizeOffset uint16
paddingLen uint16 paddingLen uint16
hasSize bool hasSize bool
done bool done bool
@ -104,6 +105,9 @@ func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, re
padding: paddingLen, padding: paddingLen,
sizeBytes: make([]byte, sizeParser.SizeBytes()), sizeBytes: make([]byte, sizeParser.SizeBytes()),
} }
if chunkSizeDecoderWithOffset, ok := sizeParser.(ChunkSizeDecoderWithOffset); ok {
r.sizeOffset = chunkSizeDecoderWithOffset.HasConstantOffset()
}
if breader, ok := reader.(*buf.BufferedReader); ok { if breader, ok := reader.(*buf.BufferedReader); ok {
r.reader = breader r.reader = breader
} else { } else {
@ -160,12 +164,14 @@ func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) erro
return err return err
} }
if size == uint16(r.auth.Overhead())+padding { if size+r.sizeOffset == uint16(r.auth.Overhead())+padding {
r.done = true r.done = true
return io.EOF return io.EOF
} }
if soft && int32(size) > r.reader.BufferedBytes() { effectiveSize := int32(size) + int32(r.sizeOffset)
if soft && effectiveSize > r.reader.BufferedBytes() {
r.size = size r.size = size
r.paddingLen = padding r.paddingLen = padding
r.hasSize = true r.hasSize = true
@ -173,7 +179,7 @@ func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) erro
} }
if size <= buf.Size { if size <= buf.Size {
b, err := r.readBuffer(int32(size), int32(padding)) b, err := r.readBuffer(effectiveSize, int32(padding))
if err != nil { if err != nil {
return nil return nil
} }
@ -181,16 +187,16 @@ func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) erro
return nil return nil
} }
payload := bytespool.Alloc(int32(size)) payload := bytespool.Alloc(effectiveSize)
defer bytespool.Free(payload) defer bytespool.Free(payload)
if _, err := io.ReadFull(r.reader, payload[:size]); err != nil { if _, err := io.ReadFull(r.reader, payload[:effectiveSize]); err != nil {
return err return err
} }
size -= padding effectiveSize -= int32(padding)
rb, err := r.auth.Open(payload[:0], payload[:size]) rb, err := r.auth.Open(payload[:0], payload[:effectiveSize])
if err != nil { if err != nil {
return err return err
} }

View File

@ -15,6 +15,13 @@ type ChunkSizeDecoder interface {
Decode([]byte) (uint16, error) Decode([]byte) (uint16, error)
} }
type ChunkSizeDecoderWithOffset interface {
ChunkSizeDecoder
// HasConstantOffset set the constant offset of Decode
// The effective size should be HasConstantOffset() + Decode(_).[0](uint64)
HasConstantOffset() uint16
}
// ChunkSizeEncoder is a utility class to encode size value into bytes. // ChunkSizeEncoder is a utility class to encode size value into bytes.
type ChunkSizeEncoder interface { type ChunkSizeEncoder interface {
SizeBytes() int32 SizeBytes() int32

2
go.mod
View File

@ -17,6 +17,7 @@ require (
github.com/miekg/dns v1.1.57 github.com/miekg/dns v1.1.57
github.com/mustafaturan/bus v1.0.2 github.com/mustafaturan/bus v1.0.2
github.com/pelletier/go-toml v1.9.5 github.com/pelletier/go-toml v1.9.5
github.com/pion/transport/v2 v2.2.1
github.com/pires/go-proxyproto v0.7.0 github.com/pires/go-proxyproto v0.7.0
github.com/quic-go/quic-go v0.40.0 github.com/quic-go/quic-go v0.40.0
github.com/refraction-networking/utls v1.5.4 github.com/refraction-networking/utls v1.5.4
@ -68,7 +69,6 @@ require (
github.com/pion/logging v0.2.2 // indirect github.com/pion/logging v0.2.2 // indirect
github.com/pion/randutil v0.1.0 // indirect github.com/pion/randutil v0.1.0 // indirect
github.com/pion/sctp v1.8.7 // indirect github.com/pion/sctp v1.8.7 // indirect
github.com/pion/transport/v2 v2.2.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect

View File

@ -2,6 +2,7 @@ package v5cfg
import ( import (
"context" "context"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
core "github.com/v2fly/v2ray-core/v5" core "github.com/v2fly/v2ray-core/v5"

View File

@ -111,7 +111,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
if err != nil { if err != nil {
return newError("failed to find an available destination").AtWarning().Base(err) return newError("failed to find an available destination").AtWarning().Base(err)
} }
newError("tunneling request to ", destination, " via ", network, ":", c.config.Address).WriteToLog(session.ExportIDToError(ctx)) newError("tunneling request to ", destination, " via ", network, ":", net.TCPDestination(c.config.Address.AsAddress(), net.Port(c.config.Port)).NetAddr()).WriteToLog(session.ExportIDToError(ctx))
defer conn.Close() defer conn.Close()
request := &TCPRequest{ request := &TCPRequest{

View File

@ -11,14 +11,17 @@ import (
"github.com/v2fly/v2ray-core/v5/common/buf" "github.com/v2fly/v2ray-core/v5/common/buf"
"github.com/v2fly/v2ray-core/v5/common/net" "github.com/v2fly/v2ray-core/v5/common/net"
"github.com/v2fly/v2ray-core/v5/transport/internet" "github.com/v2fly/v2ray-core/v5/transport/internet"
"github.com/pion/transport/v2/replaydetector"
) )
func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetProcessor UDPClientPacketProcessor) *ClientUDPSession { func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetProcessor UDPClientPacketProcessor) *ClientUDPSession {
session := &ClientUDPSession{ session := &ClientUDPSession{
locker: &sync.Mutex{}, locker: &sync.RWMutex{},
conn: conn, conn: conn,
packetProcessor: packetProcessor, packetProcessor: packetProcessor,
sessionMap: make(map[string]*ClientUDPSessionConn), sessionMap: make(map[string]*ClientUDPSessionConn),
sessionMapAlias: make(map[string]string),
} }
session.ctx, session.finish = context.WithCancel(ctx) session.ctx, session.finish = context.WithCancel(ctx)
@ -27,16 +30,87 @@ func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetPro
} }
type ClientUDPSession struct { type ClientUDPSession struct {
locker *sync.Mutex locker *sync.RWMutex
conn io.ReadWriteCloser conn io.ReadWriteCloser
packetProcessor UDPClientPacketProcessor packetProcessor UDPClientPacketProcessor
sessionMap map[string]*ClientUDPSessionConn sessionMap map[string]*ClientUDPSessionConn
sessionMapAlias map[string]string
ctx context.Context ctx context.Context
finish func() finish func()
} }
func (c *ClientUDPSession) GetCachedState(sessionID string) UDPClientPacketProcessorCachedState {
c.locker.RLock()
defer c.locker.RUnlock()
state, ok := c.sessionMap[sessionID]
if !ok {
return nil
}
return state.cachedProcessorState
}
func (c *ClientUDPSession) GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState {
c.locker.RLock()
defer c.locker.RUnlock()
clientSessionID := c.getCachedStateAlias(serverSessionID)
if clientSessionID == "" {
return nil
}
state, ok := c.sessionMap[clientSessionID]
if !ok {
return nil
}
if serverState, ok := state.trackedServerSessionID[serverSessionID]; !ok {
return nil
} else {
return serverState.cachedRecvProcessorState
}
}
func (c *ClientUDPSession) getCachedStateAlias(serverSessionID string) string {
state, ok := c.sessionMapAlias[serverSessionID]
if !ok {
return ""
}
return state
}
func (c *ClientUDPSession) PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState) {
c.locker.RLock()
defer c.locker.RUnlock()
state, ok := c.sessionMap[sessionID]
if !ok {
return
}
state.cachedProcessorState = cache
}
func (c *ClientUDPSession) PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState) {
c.locker.RLock()
defer c.locker.RUnlock()
clientSessionID := c.getCachedStateAlias(serverSessionID)
if clientSessionID == "" {
return
}
state, ok := c.sessionMap[clientSessionID]
if !ok {
return
}
if serverState, ok := state.trackedServerSessionID[serverSessionID]; ok {
serverState.cachedRecvProcessorState = cache
return
}
}
func (c *ClientUDPSession) Close() error { func (c *ClientUDPSession) Close() error {
c.finish() c.finish()
return c.conn.Close() return c.conn.Close()
@ -45,7 +119,7 @@ func (c *ClientUDPSession) Close() error {
func (c *ClientUDPSession) WriteUDPRequest(request *UDPRequest) error { func (c *ClientUDPSession) WriteUDPRequest(request *UDPRequest) error {
buffer := buf.New() buffer := buf.New()
defer buffer.Release() defer buffer.Release()
err := c.packetProcessor.EncodeUDPRequest(request, buffer) err := c.packetProcessor.EncodeUDPRequest(request, buffer, c)
if request.Payload != nil { if request.Payload != nil {
request.Payload.Release() request.Payload.Release()
} }
@ -69,7 +143,7 @@ func (c *ClientUDPSession) KeepReading() {
return return
} }
if n != 0 { if n != 0 {
err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp) err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp, c)
if err != nil { if err != nil {
newError("unable to decode udp response").Base(err).WriteToLog() newError("unable to decode udp response").Base(err).WriteToLog()
continue continue
@ -78,13 +152,14 @@ func (c *ClientUDPSession) KeepReading() {
{ {
timeDifference := int64(udpResp.TimeStamp) - time.Now().Unix() timeDifference := int64(udpResp.TimeStamp) - time.Now().Unix()
if timeDifference < -30 || timeDifference > 30 { if timeDifference < -30 || timeDifference > 30 {
newError("udp packet timestamp difference too large, packet discarded").WriteToLog() newError("udp packet timestamp difference too large, packet discarded, time diff = ", timeDifference).WriteToLog()
continue continue
} }
} }
c.locker.Lock() c.locker.RLock()
session, ok := c.sessionMap[string(udpResp.ClientSessionID[:])] session, ok := c.sessionMap[string(udpResp.ClientSessionID[:])]
c.locker.RUnlock()
if ok { if ok {
select { select {
case session.readChan <- udpResp: case session.readChan <- udpResp:
@ -93,7 +168,6 @@ func (c *ClientUDPSession) KeepReading() {
} else { } else {
newError("misbehaving server: unknown client session ID").Base(err).WriteToLog() newError("misbehaving server: unknown client session ID").Base(err).WriteToLog()
} }
c.locker.Unlock()
} }
} }
} }
@ -108,12 +182,13 @@ func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error)
connctx, connfinish := context.WithCancel(c.ctx) connctx, connfinish := context.WithCancel(c.ctx)
sessionConn := &ClientUDPSessionConn{ sessionConn := &ClientUDPSessionConn{
sessionID: string(sessionID), sessionID: string(sessionID),
readChan: make(chan *UDPResponse, 16), readChan: make(chan *UDPResponse, 128),
parent: c, parent: c,
ctx: connctx, ctx: connctx,
finish: connfinish, finish: connfinish,
nextWritePacketID: 0, nextWritePacketID: 0,
trackedServerSessionID: make(map[string]*ClientUDPSessionServerTracker),
} }
c.locker.Lock() c.locker.Lock()
c.sessionMap[sessionConn.sessionID] = sessionConn c.sessionMap[sessionConn.sessionID] = sessionConn
@ -121,19 +196,33 @@ func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error)
return sessionConn, nil return sessionConn, nil
} }
type ClientUDPSessionServerTracker struct {
cachedRecvProcessorState UDPClientPacketProcessorCachedState
rxReplayDetector replaydetector.ReplayDetector
lastSeen time.Time
}
type ClientUDPSessionConn struct { type ClientUDPSessionConn struct {
sessionID string sessionID string
readChan chan *UDPResponse readChan chan *UDPResponse
parent *ClientUDPSession parent *ClientUDPSession
nextWritePacketID uint64 nextWritePacketID uint64
trackedServerSessionID map[string]*ClientUDPSessionServerTracker
cachedProcessorState UDPClientPacketProcessorCachedState
ctx context.Context ctx context.Context
finish func() finish func()
} }
func (c *ClientUDPSessionConn) Close() error { func (c *ClientUDPSessionConn) Close() error {
c.parent.locker.Lock()
delete(c.parent.sessionMap, c.sessionID) delete(c.parent.sessionMap, c.sessionID)
for k := range c.trackedServerSessionID {
delete(c.parent.sessionMapAlias, k)
}
c.parent.locker.Unlock()
c.finish() c.finish()
return nil return nil
} }
@ -160,13 +249,44 @@ func (c *ClientUDPSessionConn) WriteTo(p []byte, addr gonet.Addr) (n int, err er
} }
func (c *ClientUDPSessionConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *ClientUDPSessionConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select { for {
case <-c.ctx.Done(): select {
return 0, nil, io.EOF case <-c.ctx.Done():
case resp := <-c.readChan: return 0, nil, io.EOF
n = copy(p, resp.Payload.Bytes()) case resp := <-c.readChan:
resp.Payload.Release() n = copy(p, resp.Payload.Bytes())
addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port} resp.Payload.Release()
var trackedState *ClientUDPSessionServerTracker
if trackedStateReceived, ok := c.trackedServerSessionID[string(resp.SessionID[:])]; !ok {
for key, value := range c.trackedServerSessionID {
if time.Since(value.lastSeen) > 65*time.Second {
delete(c.trackedServerSessionID, key)
}
}
state := &ClientUDPSessionServerTracker{
rxReplayDetector: replaydetector.New(1024, ^uint64(0)),
}
c.trackedServerSessionID[string(resp.SessionID[:])] = state
c.parent.locker.RLock()
c.parent.sessionMapAlias[string(resp.SessionID[:])] = string(resp.ClientSessionID[:])
c.parent.locker.RUnlock()
trackedState = state
} else {
trackedState = trackedStateReceived
}
if accept, ok := trackedState.rxReplayDetector.Check(resp.PacketID); ok {
accept()
} else {
newError("misbehaving server: replayed packet").Base(err).WriteToLog()
continue
}
trackedState.lastSeen = time.Now()
addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
}
return n, addr, nil
} }
return
} }

View File

@ -4,10 +4,13 @@ import (
"bytes" "bytes"
"crypto/cipher" "crypto/cipher"
cryptoRand "crypto/rand" cryptoRand "crypto/rand"
"encoding/binary"
"io" "io"
"math/rand" "math/rand"
"time" "time"
"github.com/v2fly/v2ray-core/v5/common"
"github.com/lunixbochs/struc" "github.com/lunixbochs/struc"
"github.com/v2fly/v2ray-core/v5/common/buf" "github.com/v2fly/v2ray-core/v5/common/buf"
@ -59,7 +62,7 @@ func (t *TCPRequest) EncodeTCPRequestHeader(effectivePsk []byte,
paddingLength := TCPMinPaddingLength paddingLength := TCPMinPaddingLength
if initialPayload == nil { if initialPayload == nil {
initialPayload = []byte{} initialPayload = []byte{}
paddingLength += rand.Intn(TCPMaxPaddingLength) // TODO INSECURE RANDOM USED paddingLength += 1 + rand.Intn(TCPMaxPaddingLength) // TODO INSECURE RANDOM USED
} }
variableLengthHeader := &TCPRequestHeader3VariableLength{ variableLengthHeader := &TCPRequestHeader3VariableLength{
@ -110,10 +113,14 @@ func (t *TCPRequest) EncodeTCPRequestHeader(effectivePsk []byte,
return newError("failed to pack fixed length header").Base(err) return newError("failed to pack fixed length header").Base(err)
} }
} }
eihGenerator := newAESEIHGeneratorContainer(len(eih), effectivePsk, eih) eihHeader := ExtensibleIdentityHeaders(newAESEIH(0))
eihHeader, err := eihGenerator.GenerateEIH(t.keyDerivation, t.method, requestSalt.Bytes()) if len(eih) != 0 {
if err != nil { eihGenerator := newAESEIHGeneratorContainer(len(eih), effectivePsk, eih)
return newError("failed to construct EIH").Base(err) eihHeaderGenerated, err := eihGenerator.GenerateEIH(t.keyDerivation, t.method, requestSalt.Bytes())
if err != nil {
return newError("failed to construct EIH").Base(err)
}
eihHeader = eihHeaderGenerated
} }
preSessionKeyHeader := &TCPRequestHeader1PreSessionKey{ preSessionKeyHeader := &TCPRequestHeader1PreSessionKey{
Salt: requestSalt, Salt: requestSalt,
@ -206,7 +213,7 @@ func (t *TCPRequest) DecodeTCPResponseHeader(effectivePsk []byte, in io.Reader)
} }
timeDifference := int64(fixedLengthHeader.Timestamp) - time.Now().Unix() timeDifference := int64(fixedLengthHeader.Timestamp) - time.Now().Unix()
if timeDifference < -30 || timeDifference > 30 { if timeDifference < -30 || timeDifference > 30 {
return newError("timestamp is too far away") return newError("timestamp is too far away, timeDifference = ", timeDifference)
} }
t.s2cSaltAssert = fixedLengthHeader.RequestSalt t.s2cSaltAssert = fixedLengthHeader.RequestSalt
@ -239,7 +246,7 @@ func (t *TCPRequest) CreateClientS2CReader(in io.Reader, initialPayload *buf.Buf
if err != nil { if err != nil {
return nil, newError("failed to decrypt initial payload").Base(err) return nil, newError("failed to decrypt initial payload").Base(err)
} }
return crypto.NewAuthenticationReader(AEADAuthenticator, &crypto.AEADChunkSizeParser{ return crypto.NewAuthenticationReader(AEADAuthenticator, &AEADChunkSizeParser{
Auth: AEADAuthenticator, Auth: AEADAuthenticator,
}, in, protocol.TransferTypeStream, nil), nil }, in, protocol.TransferTypeStream, nil), nil
} }
@ -255,3 +262,30 @@ func (t *TCPRequest) CreateClientC2SWriter(writer io.Writer) buf.Writer {
} }
return crypto.NewAuthenticationWriter(AEADAuthenticator, sizeParser, writer, protocol.TransferTypeStream, nil) return crypto.NewAuthenticationWriter(AEADAuthenticator, sizeParser, writer, protocol.TransferTypeStream, nil)
} }
type AEADChunkSizeParser struct {
Auth *crypto.AEADAuthenticator
}
func (p *AEADChunkSizeParser) HasConstantOffset() uint16 {
return uint16(p.Auth.Overhead())
}
func (p *AEADChunkSizeParser) SizeBytes() int32 {
return 2 + int32(p.Auth.Overhead())
}
func (p *AEADChunkSizeParser) Encode(size uint16, b []byte) []byte {
binary.BigEndian.PutUint16(b, size-uint16(p.Auth.Overhead()))
b, err := p.Auth.Seal(b[:0], b[:2])
common.Must(err)
return b
}
func (p *AEADChunkSizeParser) Decode(b []byte) (uint16, error) {
b, err := p.Auth.Open(b[:0], b)
if err != nil {
return 0, err
}
return binary.BigEndian.Uint16(b), nil
}

View File

@ -112,9 +112,18 @@ const (
UDPHeaderTypeServerToClientStream = byte(0x01) UDPHeaderTypeServerToClientStream = byte(0x01)
) )
type UDPClientPacketProcessorCachedStateContainer interface {
GetCachedState(sessionID string) UDPClientPacketProcessorCachedState
PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState)
GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState
PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState)
}
type UDPClientPacketProcessorCachedState interface{}
// UDPClientPacketProcessor // UDPClientPacketProcessor
// Caller retain and receive all ownership of the buffer // Caller retain and receive all ownership of the buffer
type UDPClientPacketProcessor interface { type UDPClientPacketProcessor interface {
EncodeUDPRequest(request *UDPRequest, out *buf.Buffer) error EncodeUDPRequest(request *UDPRequest, out *buf.Buffer, cache UDPClientPacketProcessorCachedStateContainer) error
DecodeUDPResp(input []byte, resp *UDPResponse) error DecodeUDPResp(input []byte, resp *UDPResponse, cache UDPClientPacketProcessorCachedStateContainer) error
} }

View File

@ -47,7 +47,14 @@ type respHeader struct {
Padding []byte Padding []byte
} }
func (p *AESUDPClientPacketProcessor) EncodeUDPRequest(request *UDPRequest, out *buf.Buffer) error { type cachedUDPState struct {
sessionAEAD cipher.AEAD
sessionRecvAEAD cipher.AEAD
}
func (p *AESUDPClientPacketProcessor) EncodeUDPRequest(request *UDPRequest, out *buf.Buffer,
cache UDPClientPacketProcessorCachedStateContainer,
) error {
separateHeaderStruct := separateHeader{PacketID: request.PacketID, SessionID: request.SessionID} separateHeaderStruct := separateHeader{PacketID: request.PacketID, SessionID: request.SessionID}
separateHeaderBuffer := buf.New() separateHeaderBuffer := buf.New()
defer separateHeaderBuffer.Release() defer separateHeaderBuffer.Release()
@ -102,14 +109,28 @@ func (p *AESUDPClientPacketProcessor) EncodeUDPRequest(request *UDPRequest, out
} }
} }
{ {
mainPacketAEADMaterialized := p.mainPacketAEAD(separateHeaderBufferBytes[0:8]) cacheKey := string(separateHeaderBufferBytes[0:8])
receivedCacheInterface := cache.GetCachedState(cacheKey)
cachedState := &cachedUDPState{}
if receivedCacheInterface != nil {
cachedState = receivedCacheInterface.(*cachedUDPState)
}
if cachedState.sessionAEAD == nil {
cachedState.sessionAEAD = p.mainPacketAEAD(separateHeaderBufferBytes[0:8])
cache.PutCachedState(cacheKey, cachedState)
}
mainPacketAEADMaterialized := cachedState.sessionAEAD
encryptedDest := out.Extend(int32(mainPacketAEADMaterialized.Overhead()) + requestBodyBuffer.Len()) encryptedDest := out.Extend(int32(mainPacketAEADMaterialized.Overhead()) + requestBodyBuffer.Len())
mainPacketAEADMaterialized.Seal(encryptedDest[:0], separateHeaderBuffer.Bytes()[4:16], requestBodyBuffer.Bytes(), nil) mainPacketAEADMaterialized.Seal(encryptedDest[:0], separateHeaderBuffer.Bytes()[4:16], requestBodyBuffer.Bytes(), nil)
} }
return nil return nil
} }
func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPResponse) error { func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPResponse,
cache UDPClientPacketProcessorCachedStateContainer,
) error {
separateHeaderBuffer := buf.New() separateHeaderBuffer := buf.New()
defer separateHeaderBuffer.Release() defer separateHeaderBuffer.Release()
{ {
@ -126,7 +147,19 @@ func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPRespo
resp.PacketID = separateHeaderStruct.PacketID resp.PacketID = separateHeaderStruct.PacketID
resp.SessionID = separateHeaderStruct.SessionID resp.SessionID = separateHeaderStruct.SessionID
{ {
mainPacketAEADMaterialized := p.mainPacketAEAD(separateHeaderBuffer.Bytes()[0:8]) cacheKey := string(separateHeaderBuffer.Bytes()[0:8])
receivedCacheInterface := cache.GetCachedServerState(cacheKey)
cachedState := &cachedUDPState{}
if receivedCacheInterface != nil {
cachedState = receivedCacheInterface.(*cachedUDPState)
}
if cachedState.sessionRecvAEAD == nil {
cachedState.sessionRecvAEAD = p.mainPacketAEAD(separateHeaderBuffer.Bytes()[0:8])
cache.PutCachedServerState(cacheKey, cachedState)
}
mainPacketAEADMaterialized := cachedState.sessionRecvAEAD
decryptedDestBuffer := buf.New() decryptedDestBuffer := buf.New()
decryptedDest := decryptedDestBuffer.Extend(int32(len(input)) - 16 - int32(mainPacketAEADMaterialized.Overhead())) decryptedDest := decryptedDestBuffer.Extend(int32(len(input)) - 16 - int32(mainPacketAEADMaterialized.Overhead()))
_, err := mainPacketAEADMaterialized.Open(decryptedDest[:0], separateHeaderBuffer.Bytes()[4:16], input[16:], nil) _, err := mainPacketAEADMaterialized.Open(decryptedDest[:0], separateHeaderBuffer.Bytes()[4:16], input[16:], nil)