diff --git a/common/net/packetaddr/connection_adaptor.go b/common/net/packetaddr/connection_adaptor.go index 429fc9c2f..a95b1b12f 100644 --- a/common/net/packetaddr/connection_adaptor.go +++ b/common/net/packetaddr/connection_adaptor.go @@ -69,16 +69,17 @@ func (c *packetConnectionAdaptor) ReadFrom(p []byte) (n int, addr gonet.Addr, er } } c.readerBuffer, n = buf.SplitFirstBytes(c.readerBuffer, p) - var w []byte - w, addr = ExtractAddressFromPacket(p[:n]) - n = copy(p, w) + var w *buf.Buffer + w, addr, err = ExtractAddressFromPacket(buf.FromBytes(p[:n])) + n = copy(p, w.Bytes()) + w.Release() return } func (c *packetConnectionAdaptor) WriteTo(p []byte, addr gonet.Addr) (n int, err error) { payloadLen := len(p) - p = AttachAddressToPacket(p, addr) - buffer := buf.FromBytes(p) + var buffer *buf.Buffer + buffer, err = AttachAddressToPacket(buf.FromBytes(p), addr) mb := buf.MultiBuffer{buffer} err = c.link.Writer.WriteMultiBuffer(mb) if err != nil { @@ -134,18 +135,26 @@ func (pc *packetConnWrapper) Read(p []byte) (n int, err error) { if err != nil { return 0, err } - result := AttachAddressToPacket(recbuf.Bytes()[0:n], addr) - n = copy(p, result) - recbuf.Release() + recbuf.Resize(0, int32(n)) + result, err := AttachAddressToPacket(&recbuf, addr) + if err != nil { + return 0, err + } + n = copy(p, result.Bytes()) + result.Release() return n, nil } func (pc *packetConnWrapper) Write(p []byte) (n int, err error) { - data, addr := ExtractAddressFromPacket(p) - _, err = pc.PacketConn.WriteTo(data, addr) + data, addr, err := ExtractAddressFromPacket(buf.FromBytes(p)) if err != nil { return 0, err } + _, err = pc.PacketConn.WriteTo(data.Bytes(), addr) + if err != nil { + return 0, err + } + data.Release() return len(p), nil } diff --git a/common/net/packetaddr/packetaddr.go b/common/net/packetaddr/packetaddr.go index eabf179ed..b1add37ff 100644 --- a/common/net/packetaddr/packetaddr.go +++ b/common/net/packetaddr/packetaddr.go @@ -13,35 +13,43 @@ var addrParser = protocol.NewAddressParser( protocol.AddressFamilyByte(0x02, net.AddressFamilyIPv6), ) -func AttachAddressToPacket(data []byte, address gonet.Addr) []byte { - packetBuf := buf.StackNew() +// AttachAddressToPacket +// relinquish ownership of data +// gain ownership of the returning value +func AttachAddressToPacket(data *buf.Buffer, address gonet.Addr) (*buf.Buffer, error) { + packetBuf := buf.New() udpaddr := address.(*gonet.UDPAddr) port, err := net.PortFromInt(uint32(udpaddr.Port)) if err != nil { - panic(err) + return nil, err } - err = addrParser.WriteAddressPort(&packetBuf, net.IPAddress(udpaddr.IP), port) + err = addrParser.WriteAddressPort(packetBuf, net.IPAddress(udpaddr.IP), port) if err != nil { - panic(err) + return nil, err } - //Incorrect buffer reuse - data = append(packetBuf.Bytes(), data...) - //packetBuf.Release() - return data + _, err = packetBuf.Write(data.Bytes()) + if err != nil { + return nil, err + } + data.Release() + return packetBuf, nil } -func ExtractAddressFromPacket(data []byte) ([]byte, gonet.Addr) { +// ExtractAddressFromPacket +// relinquish ownership of data +// gain ownership of the returning value +func ExtractAddressFromPacket(data *buf.Buffer) (*buf.Buffer, gonet.Addr, error) { packetBuf := buf.StackNew() - address, port, err := addrParser.ReadAddressPort(&packetBuf, bytes.NewReader(data)) + address, port, err := addrParser.ReadAddressPort(&packetBuf, bytes.NewReader(data.Bytes())) if err != nil { - panic(err) + return nil, nil, err } var addr = &gonet.UDPAddr{ IP: address.IP(), Port: int(port.Value()), Zone: "", } - payload := data[int(packetBuf.Len()):] + data.Advance(packetBuf.Len()) packetBuf.Release() - return payload, addr + return data, addr, nil } diff --git a/common/net/packetaddr/packetaddr_test.go b/common/net/packetaddr/packetaddr_test.go index 9962286fb..e324ce2fe 100644 --- a/common/net/packetaddr/packetaddr_test.go +++ b/common/net/packetaddr/packetaddr_test.go @@ -2,6 +2,7 @@ package packetaddr import ( "github.com/stretchr/testify/assert" + "github.com/v2fly/v2ray-core/v4/common/buf" sysnet "net" "testing" ) @@ -12,11 +13,13 @@ func TestPacketEncodingIPv4(t *testing.T) { Port: 1234, } var packetData [256]byte - wrapped := AttachAddressToPacket(packetData[:], packetAddress) + wrapped, err := AttachAddressToPacket(buf.FromBytes(packetData[:]), packetAddress) + assert.NoError(t, err) - packetPayload, decodedAddress := ExtractAddressFromPacket(wrapped) + packetPayload, decodedAddress, err := ExtractAddressFromPacket(wrapped) + assert.NoError(t, err) - assert.Equal(t, packetPayload, packetData[:]) + assert.Equal(t, packetPayload.Bytes(), packetData[:]) assert.Equal(t, packetAddress, decodedAddress) } @@ -26,10 +29,12 @@ func TestPacketEncodingIPv6(t *testing.T) { Port: 1234, } var packetData [256]byte - wrapped := AttachAddressToPacket(packetData[:], packetAddress) + wrapped, err := AttachAddressToPacket(buf.FromBytes(packetData[:]), packetAddress) + assert.NoError(t, err) - packetPayload, decodedAddress := ExtractAddressFromPacket(wrapped) + packetPayload, decodedAddress, err := ExtractAddressFromPacket(wrapped) + assert.NoError(t, err) - assert.Equal(t, packetPayload, packetData[:]) + assert.Equal(t, packetPayload.Bytes(), packetData[:]) assert.Equal(t, packetAddress, decodedAddress) }