Shadowsocks2022 Client Implementation Improvements (#2770)

* Shadowsocks2022 Client Implementation Improvements

1. Added UDP Replay Detection
2. Added UDP Processor State Cache
3. Added More Detailed Output for Time Difference Error
4. Replaced Mutex with RWMutex for reduced lock contention
5. Added per server session tracking of decryption cache and anti-replay window
6. Adjust server session track time
7. Increase per session buffer to 128
8. Fix client crash when EIH is not enabled
9. Fix client log contains non-human-friendly content
10.Remove mark and remove for trackedSessions
11. Fixed packet size uint16 overflow issue
This commit is contained in:
Xiaokang Wang (Shelikhoo) 2023-11-24 00:40:07 +00:00 committed by GitHub
parent 62428d8011
commit b6da3e86a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 254 additions and 44 deletions

View File

@ -91,6 +91,7 @@ type AuthenticationReader struct {
transferType protocol.TransferType
padding PaddingLengthGenerator
size uint16
sizeOffset uint16
paddingLen uint16
hasSize bool
done bool
@ -104,6 +105,9 @@ func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, re
padding: paddingLen,
sizeBytes: make([]byte, sizeParser.SizeBytes()),
}
if chunkSizeDecoderWithOffset, ok := sizeParser.(ChunkSizeDecoderWithOffset); ok {
r.sizeOffset = chunkSizeDecoderWithOffset.HasConstantOffset()
}
if breader, ok := reader.(*buf.BufferedReader); ok {
r.reader = breader
} else {
@ -160,12 +164,14 @@ func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) erro
return err
}
if size == uint16(r.auth.Overhead())+padding {
if size+r.sizeOffset == uint16(r.auth.Overhead())+padding {
r.done = true
return io.EOF
}
if soft && int32(size) > r.reader.BufferedBytes() {
effectiveSize := int32(size) + int32(r.sizeOffset)
if soft && effectiveSize > r.reader.BufferedBytes() {
r.size = size
r.paddingLen = padding
r.hasSize = true
@ -173,7 +179,7 @@ func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) erro
}
if size <= buf.Size {
b, err := r.readBuffer(int32(size), int32(padding))
b, err := r.readBuffer(effectiveSize, int32(padding))
if err != nil {
return nil
}
@ -181,16 +187,16 @@ func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) erro
return nil
}
payload := bytespool.Alloc(int32(size))
payload := bytespool.Alloc(effectiveSize)
defer bytespool.Free(payload)
if _, err := io.ReadFull(r.reader, payload[:size]); err != nil {
if _, err := io.ReadFull(r.reader, payload[:effectiveSize]); err != nil {
return err
}
size -= padding
effectiveSize -= int32(padding)
rb, err := r.auth.Open(payload[:0], payload[:size])
rb, err := r.auth.Open(payload[:0], payload[:effectiveSize])
if err != nil {
return err
}

View File

@ -15,6 +15,13 @@ type ChunkSizeDecoder interface {
Decode([]byte) (uint16, error)
}
type ChunkSizeDecoderWithOffset interface {
ChunkSizeDecoder
// HasConstantOffset set the constant offset of Decode
// The effective size should be HasConstantOffset() + Decode(_).[0](uint64)
HasConstantOffset() uint16
}
// ChunkSizeEncoder is a utility class to encode size value into bytes.
type ChunkSizeEncoder interface {
SizeBytes() int32

2
go.mod
View File

@ -17,6 +17,7 @@ require (
github.com/miekg/dns v1.1.57
github.com/mustafaturan/bus v1.0.2
github.com/pelletier/go-toml v1.9.5
github.com/pion/transport/v2 v2.2.1
github.com/pires/go-proxyproto v0.7.0
github.com/quic-go/quic-go v0.40.0
github.com/refraction-networking/utls v1.5.4
@ -68,7 +69,6 @@ require (
github.com/pion/logging v0.2.2 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/sctp v1.8.7 // indirect
github.com/pion/transport/v2 v2.2.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect

View File

@ -2,6 +2,7 @@ package v5cfg
import (
"context"
"github.com/golang/protobuf/proto"
core "github.com/v2fly/v2ray-core/v5"

View File

@ -111,7 +111,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
if err != nil {
return newError("failed to find an available destination").AtWarning().Base(err)
}
newError("tunneling request to ", destination, " via ", network, ":", c.config.Address).WriteToLog(session.ExportIDToError(ctx))
newError("tunneling request to ", destination, " via ", network, ":", net.TCPDestination(c.config.Address.AsAddress(), net.Port(c.config.Port)).NetAddr()).WriteToLog(session.ExportIDToError(ctx))
defer conn.Close()
request := &TCPRequest{

View File

@ -11,14 +11,17 @@ import (
"github.com/v2fly/v2ray-core/v5/common/buf"
"github.com/v2fly/v2ray-core/v5/common/net"
"github.com/v2fly/v2ray-core/v5/transport/internet"
"github.com/pion/transport/v2/replaydetector"
)
func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetProcessor UDPClientPacketProcessor) *ClientUDPSession {
session := &ClientUDPSession{
locker: &sync.Mutex{},
locker: &sync.RWMutex{},
conn: conn,
packetProcessor: packetProcessor,
sessionMap: make(map[string]*ClientUDPSessionConn),
sessionMapAlias: make(map[string]string),
}
session.ctx, session.finish = context.WithCancel(ctx)
@ -27,16 +30,87 @@ func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetPro
}
type ClientUDPSession struct {
locker *sync.Mutex
locker *sync.RWMutex
conn io.ReadWriteCloser
packetProcessor UDPClientPacketProcessor
sessionMap map[string]*ClientUDPSessionConn
sessionMapAlias map[string]string
ctx context.Context
finish func()
}
func (c *ClientUDPSession) GetCachedState(sessionID string) UDPClientPacketProcessorCachedState {
c.locker.RLock()
defer c.locker.RUnlock()
state, ok := c.sessionMap[sessionID]
if !ok {
return nil
}
return state.cachedProcessorState
}
func (c *ClientUDPSession) GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState {
c.locker.RLock()
defer c.locker.RUnlock()
clientSessionID := c.getCachedStateAlias(serverSessionID)
if clientSessionID == "" {
return nil
}
state, ok := c.sessionMap[clientSessionID]
if !ok {
return nil
}
if serverState, ok := state.trackedServerSessionID[serverSessionID]; !ok {
return nil
} else {
return serverState.cachedRecvProcessorState
}
}
func (c *ClientUDPSession) getCachedStateAlias(serverSessionID string) string {
state, ok := c.sessionMapAlias[serverSessionID]
if !ok {
return ""
}
return state
}
func (c *ClientUDPSession) PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState) {
c.locker.RLock()
defer c.locker.RUnlock()
state, ok := c.sessionMap[sessionID]
if !ok {
return
}
state.cachedProcessorState = cache
}
func (c *ClientUDPSession) PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState) {
c.locker.RLock()
defer c.locker.RUnlock()
clientSessionID := c.getCachedStateAlias(serverSessionID)
if clientSessionID == "" {
return
}
state, ok := c.sessionMap[clientSessionID]
if !ok {
return
}
if serverState, ok := state.trackedServerSessionID[serverSessionID]; ok {
serverState.cachedRecvProcessorState = cache
return
}
}
func (c *ClientUDPSession) Close() error {
c.finish()
return c.conn.Close()
@ -45,7 +119,7 @@ func (c *ClientUDPSession) Close() error {
func (c *ClientUDPSession) WriteUDPRequest(request *UDPRequest) error {
buffer := buf.New()
defer buffer.Release()
err := c.packetProcessor.EncodeUDPRequest(request, buffer)
err := c.packetProcessor.EncodeUDPRequest(request, buffer, c)
if request.Payload != nil {
request.Payload.Release()
}
@ -69,7 +143,7 @@ func (c *ClientUDPSession) KeepReading() {
return
}
if n != 0 {
err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp)
err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp, c)
if err != nil {
newError("unable to decode udp response").Base(err).WriteToLog()
continue
@ -78,13 +152,14 @@ func (c *ClientUDPSession) KeepReading() {
{
timeDifference := int64(udpResp.TimeStamp) - time.Now().Unix()
if timeDifference < -30 || timeDifference > 30 {
newError("udp packet timestamp difference too large, packet discarded").WriteToLog()
newError("udp packet timestamp difference too large, packet discarded, time diff = ", timeDifference).WriteToLog()
continue
}
}
c.locker.Lock()
c.locker.RLock()
session, ok := c.sessionMap[string(udpResp.ClientSessionID[:])]
c.locker.RUnlock()
if ok {
select {
case session.readChan <- udpResp:
@ -93,7 +168,6 @@ func (c *ClientUDPSession) KeepReading() {
} else {
newError("misbehaving server: unknown client session ID").Base(err).WriteToLog()
}
c.locker.Unlock()
}
}
}
@ -108,12 +182,13 @@ func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error)
connctx, connfinish := context.WithCancel(c.ctx)
sessionConn := &ClientUDPSessionConn{
sessionID: string(sessionID),
readChan: make(chan *UDPResponse, 16),
parent: c,
ctx: connctx,
finish: connfinish,
nextWritePacketID: 0,
sessionID: string(sessionID),
readChan: make(chan *UDPResponse, 128),
parent: c,
ctx: connctx,
finish: connfinish,
nextWritePacketID: 0,
trackedServerSessionID: make(map[string]*ClientUDPSessionServerTracker),
}
c.locker.Lock()
c.sessionMap[sessionConn.sessionID] = sessionConn
@ -121,19 +196,33 @@ func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error)
return sessionConn, nil
}
type ClientUDPSessionServerTracker struct {
cachedRecvProcessorState UDPClientPacketProcessorCachedState
rxReplayDetector replaydetector.ReplayDetector
lastSeen time.Time
}
type ClientUDPSessionConn struct {
sessionID string
readChan chan *UDPResponse
parent *ClientUDPSession
nextWritePacketID uint64
nextWritePacketID uint64
trackedServerSessionID map[string]*ClientUDPSessionServerTracker
cachedProcessorState UDPClientPacketProcessorCachedState
ctx context.Context
finish func()
}
func (c *ClientUDPSessionConn) Close() error {
c.parent.locker.Lock()
delete(c.parent.sessionMap, c.sessionID)
for k := range c.trackedServerSessionID {
delete(c.parent.sessionMapAlias, k)
}
c.parent.locker.Unlock()
c.finish()
return nil
}
@ -160,13 +249,44 @@ func (c *ClientUDPSessionConn) WriteTo(p []byte, addr gonet.Addr) (n int, err er
}
func (c *ClientUDPSessionConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case <-c.ctx.Done():
return 0, nil, io.EOF
case resp := <-c.readChan:
n = copy(p, resp.Payload.Bytes())
resp.Payload.Release()
addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
for {
select {
case <-c.ctx.Done():
return 0, nil, io.EOF
case resp := <-c.readChan:
n = copy(p, resp.Payload.Bytes())
resp.Payload.Release()
var trackedState *ClientUDPSessionServerTracker
if trackedStateReceived, ok := c.trackedServerSessionID[string(resp.SessionID[:])]; !ok {
for key, value := range c.trackedServerSessionID {
if time.Since(value.lastSeen) > 65*time.Second {
delete(c.trackedServerSessionID, key)
}
}
state := &ClientUDPSessionServerTracker{
rxReplayDetector: replaydetector.New(1024, ^uint64(0)),
}
c.trackedServerSessionID[string(resp.SessionID[:])] = state
c.parent.locker.RLock()
c.parent.sessionMapAlias[string(resp.SessionID[:])] = string(resp.ClientSessionID[:])
c.parent.locker.RUnlock()
trackedState = state
} else {
trackedState = trackedStateReceived
}
if accept, ok := trackedState.rxReplayDetector.Check(resp.PacketID); ok {
accept()
} else {
newError("misbehaving server: replayed packet").Base(err).WriteToLog()
continue
}
trackedState.lastSeen = time.Now()
addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
}
return n, addr, nil
}
return
}

View File

@ -4,10 +4,13 @@ import (
"bytes"
"crypto/cipher"
cryptoRand "crypto/rand"
"encoding/binary"
"io"
"math/rand"
"time"
"github.com/v2fly/v2ray-core/v5/common"
"github.com/lunixbochs/struc"
"github.com/v2fly/v2ray-core/v5/common/buf"
@ -59,7 +62,7 @@ func (t *TCPRequest) EncodeTCPRequestHeader(effectivePsk []byte,
paddingLength := TCPMinPaddingLength
if initialPayload == nil {
initialPayload = []byte{}
paddingLength += rand.Intn(TCPMaxPaddingLength) // TODO INSECURE RANDOM USED
paddingLength += 1 + rand.Intn(TCPMaxPaddingLength) // TODO INSECURE RANDOM USED
}
variableLengthHeader := &TCPRequestHeader3VariableLength{
@ -110,10 +113,14 @@ func (t *TCPRequest) EncodeTCPRequestHeader(effectivePsk []byte,
return newError("failed to pack fixed length header").Base(err)
}
}
eihGenerator := newAESEIHGeneratorContainer(len(eih), effectivePsk, eih)
eihHeader, err := eihGenerator.GenerateEIH(t.keyDerivation, t.method, requestSalt.Bytes())
if err != nil {
return newError("failed to construct EIH").Base(err)
eihHeader := ExtensibleIdentityHeaders(newAESEIH(0))
if len(eih) != 0 {
eihGenerator := newAESEIHGeneratorContainer(len(eih), effectivePsk, eih)
eihHeaderGenerated, err := eihGenerator.GenerateEIH(t.keyDerivation, t.method, requestSalt.Bytes())
if err != nil {
return newError("failed to construct EIH").Base(err)
}
eihHeader = eihHeaderGenerated
}
preSessionKeyHeader := &TCPRequestHeader1PreSessionKey{
Salt: requestSalt,
@ -206,7 +213,7 @@ func (t *TCPRequest) DecodeTCPResponseHeader(effectivePsk []byte, in io.Reader)
}
timeDifference := int64(fixedLengthHeader.Timestamp) - time.Now().Unix()
if timeDifference < -30 || timeDifference > 30 {
return newError("timestamp is too far away")
return newError("timestamp is too far away, timeDifference = ", timeDifference)
}
t.s2cSaltAssert = fixedLengthHeader.RequestSalt
@ -239,7 +246,7 @@ func (t *TCPRequest) CreateClientS2CReader(in io.Reader, initialPayload *buf.Buf
if err != nil {
return nil, newError("failed to decrypt initial payload").Base(err)
}
return crypto.NewAuthenticationReader(AEADAuthenticator, &crypto.AEADChunkSizeParser{
return crypto.NewAuthenticationReader(AEADAuthenticator, &AEADChunkSizeParser{
Auth: AEADAuthenticator,
}, in, protocol.TransferTypeStream, nil), nil
}
@ -255,3 +262,30 @@ func (t *TCPRequest) CreateClientC2SWriter(writer io.Writer) buf.Writer {
}
return crypto.NewAuthenticationWriter(AEADAuthenticator, sizeParser, writer, protocol.TransferTypeStream, nil)
}
type AEADChunkSizeParser struct {
Auth *crypto.AEADAuthenticator
}
func (p *AEADChunkSizeParser) HasConstantOffset() uint16 {
return uint16(p.Auth.Overhead())
}
func (p *AEADChunkSizeParser) SizeBytes() int32 {
return 2 + int32(p.Auth.Overhead())
}
func (p *AEADChunkSizeParser) Encode(size uint16, b []byte) []byte {
binary.BigEndian.PutUint16(b, size-uint16(p.Auth.Overhead()))
b, err := p.Auth.Seal(b[:0], b[:2])
common.Must(err)
return b
}
func (p *AEADChunkSizeParser) Decode(b []byte) (uint16, error) {
b, err := p.Auth.Open(b[:0], b)
if err != nil {
return 0, err
}
return binary.BigEndian.Uint16(b), nil
}

View File

@ -112,9 +112,18 @@ const (
UDPHeaderTypeServerToClientStream = byte(0x01)
)
type UDPClientPacketProcessorCachedStateContainer interface {
GetCachedState(sessionID string) UDPClientPacketProcessorCachedState
PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState)
GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState
PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState)
}
type UDPClientPacketProcessorCachedState interface{}
// UDPClientPacketProcessor
// Caller retain and receive all ownership of the buffer
type UDPClientPacketProcessor interface {
EncodeUDPRequest(request *UDPRequest, out *buf.Buffer) error
DecodeUDPResp(input []byte, resp *UDPResponse) error
EncodeUDPRequest(request *UDPRequest, out *buf.Buffer, cache UDPClientPacketProcessorCachedStateContainer) error
DecodeUDPResp(input []byte, resp *UDPResponse, cache UDPClientPacketProcessorCachedStateContainer) error
}

View File

@ -47,7 +47,14 @@ type respHeader struct {
Padding []byte
}
func (p *AESUDPClientPacketProcessor) EncodeUDPRequest(request *UDPRequest, out *buf.Buffer) error {
type cachedUDPState struct {
sessionAEAD cipher.AEAD
sessionRecvAEAD cipher.AEAD
}
func (p *AESUDPClientPacketProcessor) EncodeUDPRequest(request *UDPRequest, out *buf.Buffer,
cache UDPClientPacketProcessorCachedStateContainer,
) error {
separateHeaderStruct := separateHeader{PacketID: request.PacketID, SessionID: request.SessionID}
separateHeaderBuffer := buf.New()
defer separateHeaderBuffer.Release()
@ -102,14 +109,28 @@ func (p *AESUDPClientPacketProcessor) EncodeUDPRequest(request *UDPRequest, out
}
}
{
mainPacketAEADMaterialized := p.mainPacketAEAD(separateHeaderBufferBytes[0:8])
cacheKey := string(separateHeaderBufferBytes[0:8])
receivedCacheInterface := cache.GetCachedState(cacheKey)
cachedState := &cachedUDPState{}
if receivedCacheInterface != nil {
cachedState = receivedCacheInterface.(*cachedUDPState)
}
if cachedState.sessionAEAD == nil {
cachedState.sessionAEAD = p.mainPacketAEAD(separateHeaderBufferBytes[0:8])
cache.PutCachedState(cacheKey, cachedState)
}
mainPacketAEADMaterialized := cachedState.sessionAEAD
encryptedDest := out.Extend(int32(mainPacketAEADMaterialized.Overhead()) + requestBodyBuffer.Len())
mainPacketAEADMaterialized.Seal(encryptedDest[:0], separateHeaderBuffer.Bytes()[4:16], requestBodyBuffer.Bytes(), nil)
}
return nil
}
func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPResponse) error {
func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPResponse,
cache UDPClientPacketProcessorCachedStateContainer,
) error {
separateHeaderBuffer := buf.New()
defer separateHeaderBuffer.Release()
{
@ -126,7 +147,19 @@ func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPRespo
resp.PacketID = separateHeaderStruct.PacketID
resp.SessionID = separateHeaderStruct.SessionID
{
mainPacketAEADMaterialized := p.mainPacketAEAD(separateHeaderBuffer.Bytes()[0:8])
cacheKey := string(separateHeaderBuffer.Bytes()[0:8])
receivedCacheInterface := cache.GetCachedServerState(cacheKey)
cachedState := &cachedUDPState{}
if receivedCacheInterface != nil {
cachedState = receivedCacheInterface.(*cachedUDPState)
}
if cachedState.sessionRecvAEAD == nil {
cachedState.sessionRecvAEAD = p.mainPacketAEAD(separateHeaderBuffer.Bytes()[0:8])
cache.PutCachedServerState(cacheKey, cachedState)
}
mainPacketAEADMaterialized := cachedState.sessionRecvAEAD
decryptedDestBuffer := buf.New()
decryptedDest := decryptedDestBuffer.Extend(int32(len(input)) - 16 - int32(mainPacketAEADMaterialized.Overhead()))
_, err := mainPacketAEADMaterialized.Open(decryptedDest[:0], separateHeaderBuffer.Bytes()[4:16], input[16:], nil)