From cee85bdf26fe181d256da7a0120aa91f455143b6 Mon Sep 17 00:00:00 2001 From: V2Ray Date: Wed, 2 Dec 2015 12:47:54 +0100 Subject: [PATCH] Add Port as a type --- app/router/rules/config/json/fieldrule.go | 2 +- common/net/address.go | 100 ++++++++++----------- common/net/address_test.go | 22 ++--- common/net/json/portrange.go | 21 ++--- common/net/json/portrange_test.go | 12 +-- common/net/port.go | 23 +++++ common/net/portrange.go | 4 +- common/net/testing/portrange.go | 12 ++- common/retry/retry.go | 2 +- proxy/socks/protocol/udp.go | 2 +- proxy/socks/socks.go | 2 +- proxy/vmess/protocol/vmess.go | 2 +- shell/point/config/json/json_test.go | 4 +- shell/point/config/testing/mocks/config.go | 8 +- shell/point/inbound_detour.go | 5 +- testing/scenarios/socks5_helper.go | 2 +- 16 files changed, 125 insertions(+), 98 deletions(-) create mode 100644 common/net/port.go diff --git a/app/router/rules/config/json/fieldrule.go b/app/router/rules/config/json/fieldrule.go index 710c83236..726c6167e 100644 --- a/app/router/rules/config/json/fieldrule.go +++ b/app/router/rules/config/json/fieldrule.go @@ -85,7 +85,7 @@ func (this *FieldRule) Apply(dest v2net.Destination) bool { if this.Port != nil { port := address.Port() - if port < this.Port.From() || port > this.Port.To() { + if port.Value() < this.Port.From().Value() || port.Value() > this.Port.To().Value() { return false } } diff --git a/common/net/address.go b/common/net/address.go index 5a70f18a9..e210d477f 100644 --- a/common/net/address.go +++ b/common/net/address.go @@ -2,7 +2,6 @@ package net import ( "net" - "strconv" "github.com/v2ray/v2ray-core/common/log" ) @@ -10,10 +9,9 @@ import ( // Address represents a network address to be communicated with. It may be an IP address or domain // address, not both. This interface doesn't resolve IP address for a given domain. type Address interface { - IP() net.IP // IP of this Address - Domain() string // Domain of this Address - Port() uint16 // Port of this Address - PortBytes() []byte // Port in bytes, network byte order + IP() net.IP // IP of this Address + Domain() string // Domain of this Address + Port() Port // Port of this Address IsIPv4() bool // True if this Address is an IPv4 address IsIPv6() bool // True if this Address is an IPv6 address @@ -35,16 +33,16 @@ func allZeros(data []byte) bool { func IPAddress(ip []byte, port uint16) Address { switch len(ip) { case net.IPv4len: - return IPv4Address{ - PortAddress: PortAddress{port: port}, - ip: [4]byte{ip[0], ip[1], ip[2], ip[3]}, + return &IPv4Address{ + port: Port(port), + ip: [4]byte{ip[0], ip[1], ip[2], ip[3]}, } case net.IPv6len: if allZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { return IPAddress(ip[12:16], port) } - return IPv6Address{ - PortAddress: PortAddress{port: port}, + return &IPv6Address{ + port: Port(port), ip: [16]byte{ ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], @@ -60,107 +58,107 @@ func IPAddress(ip []byte, port uint16) Address { // DomainAddress creates an Address with given domain and port. func DomainAddress(domain string, port uint16) Address { - return DomainAddressImpl{ - domain: domain, - PortAddress: PortAddress{port: port}, + return &DomainAddressImpl{ + domain: domain, + port: Port(port), } } -type PortAddress struct { - port uint16 -} - -func (addr PortAddress) Port() uint16 { - return addr.port -} - -func (addr PortAddress) PortBytes() []byte { - return []byte{byte(addr.port >> 8), byte(addr.port)} -} - type IPv4Address struct { - PortAddress - ip [4]byte + port Port + ip [4]byte } -func (addr IPv4Address) IP() net.IP { +func (addr *IPv4Address) IP() net.IP { return net.IP(addr.ip[:]) } -func (addr IPv4Address) Domain() string { +func (this *IPv4Address) Port() Port { + return this.port +} + +func (addr *IPv4Address) Domain() string { panic("Calling Domain() on an IPv4Address.") } -func (addr IPv4Address) IsIPv4() bool { +func (addr *IPv4Address) IsIPv4() bool { return true } -func (addr IPv4Address) IsIPv6() bool { +func (addr *IPv4Address) IsIPv6() bool { return false } -func (addr IPv4Address) IsDomain() bool { +func (addr *IPv4Address) IsDomain() bool { return false } -func (addr IPv4Address) String() string { - return addr.IP().String() + ":" + strconv.Itoa(int(addr.PortAddress.port)) +func (this *IPv4Address) String() string { + return this.IP().String() + ":" + this.port.String() } type IPv6Address struct { - PortAddress - ip [16]byte + port Port + ip [16]byte } -func (addr IPv6Address) IP() net.IP { +func (addr *IPv6Address) IP() net.IP { return net.IP(addr.ip[:]) } -func (addr IPv6Address) Domain() string { +func (this *IPv6Address) Port() Port { + return this.port +} + +func (addr *IPv6Address) Domain() string { panic("Calling Domain() on an IPv6Address.") } -func (addr IPv6Address) IsIPv4() bool { +func (addr *IPv6Address) IsIPv4() bool { return false } -func (addr IPv6Address) IsIPv6() bool { +func (addr *IPv6Address) IsIPv6() bool { return true } -func (addr IPv6Address) IsDomain() bool { +func (addr *IPv6Address) IsDomain() bool { return false } -func (addr IPv6Address) String() string { - return "[" + addr.IP().String() + "]:" + strconv.Itoa(int(addr.PortAddress.port)) +func (this *IPv6Address) String() string { + return "[" + this.IP().String() + "]:" + this.port.String() } type DomainAddressImpl struct { - PortAddress + port Port domain string } -func (addr DomainAddressImpl) IP() net.IP { +func (addr *DomainAddressImpl) IP() net.IP { panic("Calling IP() on a DomainAddress.") } -func (addr DomainAddressImpl) Domain() string { +func (this *DomainAddressImpl) Port() Port { + return this.port +} + +func (addr *DomainAddressImpl) Domain() string { return addr.domain } -func (addr DomainAddressImpl) IsIPv4() bool { +func (addr *DomainAddressImpl) IsIPv4() bool { return false } -func (addr DomainAddressImpl) IsIPv6() bool { +func (addr *DomainAddressImpl) IsIPv6() bool { return false } -func (addr DomainAddressImpl) IsDomain() bool { +func (addr *DomainAddressImpl) IsDomain() bool { return true } -func (addr DomainAddressImpl) String() string { - return addr.domain + ":" + strconv.Itoa(int(addr.PortAddress.port)) +func (this *DomainAddressImpl) String() string { + return this.domain + ":" + this.port.String() } diff --git a/common/net/address_test.go b/common/net/address_test.go index bb26a1506..9d5157d7d 100644 --- a/common/net/address_test.go +++ b/common/net/address_test.go @@ -11,14 +11,14 @@ func TestIPv4Address(t *testing.T) { assert := unit.Assert(t) ip := []byte{byte(1), byte(2), byte(3), byte(4)} - port := uint16(80) - addr := IPAddress(ip, port) + port := NewPort(80) + addr := IPAddress(ip, port.Value()) assert.Bool(addr.IsIPv4()).IsTrue() assert.Bool(addr.IsIPv6()).IsFalse() assert.Bool(addr.IsDomain()).IsFalse() assert.Bytes(addr.IP()).Equals(ip) - assert.Uint16(addr.Port()).Equals(port) + assert.Uint16(addr.Port().Value()).Equals(port.Value()) assert.String(addr.String()).Equals("1.2.3.4:80") } @@ -31,14 +31,14 @@ func TestIPv6Address(t *testing.T) { byte(1), byte(2), byte(3), byte(4), byte(1), byte(2), byte(3), byte(4), } - port := uint16(443) - addr := IPAddress(ip, port) + port := NewPort(443) + addr := IPAddress(ip, port.Value()) assert.Bool(addr.IsIPv6()).IsTrue() assert.Bool(addr.IsIPv4()).IsFalse() assert.Bool(addr.IsDomain()).IsFalse() assert.Bytes(addr.IP()).Equals(ip) - assert.Uint16(addr.Port()).Equals(port) + assert.Uint16(addr.Port().Value()).Equals(port.Value()) assert.String(addr.String()).Equals("[102:304:102:304:102:304:102:304]:443") } @@ -46,14 +46,14 @@ func TestDomainAddress(t *testing.T) { assert := unit.Assert(t) domain := "v2ray.com" - port := uint16(443) - addr := DomainAddress(domain, port) + port := NewPort(443) + addr := DomainAddress(domain, port.Value()) assert.Bool(addr.IsDomain()).IsTrue() assert.Bool(addr.IsIPv4()).IsFalse() assert.Bool(addr.IsIPv6()).IsFalse() assert.String(addr.Domain()).Equals(domain) - assert.Uint16(addr.Port()).Equals(port) + assert.Uint16(addr.Port().Value()).Equals(port.Value()) assert.String(addr.String()).Equals("v2ray.com:443") } @@ -61,8 +61,8 @@ func TestNetIPv4Address(t *testing.T) { assert := unit.Assert(t) ip := net.IPv4(1, 2, 3, 4) - port := uint16(80) - addr := IPAddress(ip, port) + port := NewPort(80) + addr := IPAddress(ip, port.Value()) assert.Bool(addr.IsIPv4()).IsTrue() assert.String(addr.String()).Equals("1.2.3.4:80") } diff --git a/common/net/json/portrange.go b/common/net/json/portrange.go index 516f8605d..5768206c8 100644 --- a/common/net/json/portrange.go +++ b/common/net/json/portrange.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/v2ray/v2ray-core/common/log" + v2net "github.com/v2ray/v2ray-core/common/net" ) var ( @@ -14,15 +15,15 @@ var ( ) type PortRange struct { - from uint16 - to uint16 + from v2net.Port + to v2net.Port } -func (this *PortRange) From() uint16 { +func (this *PortRange) From() v2net.Port { return this.from } -func (this *PortRange) To() uint16 { +func (this *PortRange) To() v2net.Port { return this.to } @@ -34,8 +35,8 @@ func (this *PortRange) UnmarshalJSON(data []byte) error { log.Error("Invalid port [%s]", string(data)) return InvalidPortRange } - this.from = uint16(maybeint) - this.to = uint16(maybeint) + this.from = v2net.Port(uint16(maybeint)) + this.to = v2net.Port(uint16(maybeint)) return nil } @@ -49,8 +50,8 @@ func (this *PortRange) UnmarshalJSON(data []byte) error { log.Error("Invalid from port %s", pair[0]) return InvalidPortRange } - this.from = uint16(value) - this.to = uint16(value) + this.from = v2net.Port(uint16(value)) + this.to = v2net.Port(uint16(value)) return nil } else if len(pair) == 2 { from, err := strconv.Atoi(pair[0]) @@ -58,14 +59,14 @@ func (this *PortRange) UnmarshalJSON(data []byte) error { log.Error("Invalid from port %s", pair[0]) return InvalidPortRange } - this.from = uint16(from) + this.from = v2net.Port(uint16(from)) to, err := strconv.Atoi(pair[1]) if err != nil || to <= 0 || to >= 65535 { log.Error("Invalid to port %s", pair[1]) return InvalidPortRange } - this.to = uint16(to) + this.to = v2net.Port(uint16(to)) if this.from > this.to { log.Error("Invalid port range %d -> %d", this.from, this.to) diff --git a/common/net/json/portrange_test.go b/common/net/json/portrange_test.go index 24f0e7225..125ce2acc 100644 --- a/common/net/json/portrange_test.go +++ b/common/net/json/portrange_test.go @@ -14,8 +14,8 @@ func TestIntPort(t *testing.T) { err := json.Unmarshal([]byte("1234"), &portRange) assert.Error(err).IsNil() - assert.Uint16(portRange.from).Equals(uint16(1234)) - assert.Uint16(portRange.to).Equals(uint16(1234)) + assert.Uint16(portRange.from.Value()).Equals(uint16(1234)) + assert.Uint16(portRange.to.Value()).Equals(uint16(1234)) } func TestOverRangeIntPort(t *testing.T) { @@ -36,8 +36,8 @@ func TestSingleStringPort(t *testing.T) { err := json.Unmarshal([]byte("\"1234\""), &portRange) assert.Error(err).IsNil() - assert.Uint16(portRange.from).Equals(uint16(1234)) - assert.Uint16(portRange.to).Equals(uint16(1234)) + assert.Uint16(portRange.from.Value()).Equals(uint16(1234)) + assert.Uint16(portRange.to.Value()).Equals(uint16(1234)) } func TestStringPairPort(t *testing.T) { @@ -47,8 +47,8 @@ func TestStringPairPort(t *testing.T) { err := json.Unmarshal([]byte("\"1234-5678\""), &portRange) assert.Error(err).IsNil() - assert.Uint16(portRange.from).Equals(uint16(1234)) - assert.Uint16(portRange.to).Equals(uint16(5678)) + assert.Uint16(portRange.from.Value()).Equals(uint16(1234)) + assert.Uint16(portRange.to.Value()).Equals(uint16(5678)) } func TestOverRangeStringPort(t *testing.T) { diff --git a/common/net/port.go b/common/net/port.go new file mode 100644 index 000000000..a2edf2719 --- /dev/null +++ b/common/net/port.go @@ -0,0 +1,23 @@ +package net + +import ( + "strconv" +) + +type Port uint16 + +func NewPort(port int) Port { + return Port(uint16(port)) +} + +func (this Port) Value() uint16 { + return uint16(this) +} + +func (this Port) Bytes() []byte { + return []byte{byte(this >> 8), byte(this)} +} + +func (this Port) String() string { + return strconv.Itoa(int(this)) +} diff --git a/common/net/portrange.go b/common/net/portrange.go index 309ce1db6..d42c6d1af 100644 --- a/common/net/portrange.go +++ b/common/net/portrange.go @@ -1,6 +1,6 @@ package net type PortRange interface { - From() uint16 - To() uint16 + From() Port + To() Port } diff --git a/common/net/testing/portrange.go b/common/net/testing/portrange.go index 5f78a8f86..6b1e9f738 100644 --- a/common/net/testing/portrange.go +++ b/common/net/testing/portrange.go @@ -1,14 +1,18 @@ package testing +import ( + v2net "github.com/v2ray/v2ray-core/common/net" +) + type PortRange struct { - FromValue uint16 - ToValue uint16 + FromValue v2net.Port + ToValue v2net.Port } -func (this *PortRange) From() uint16 { +func (this *PortRange) From() v2net.Port { return this.FromValue } -func (this *PortRange) To() uint16 { +func (this *PortRange) To() v2net.Port { return this.ToValue } diff --git a/common/retry/retry.go b/common/retry/retry.go index 4233b2158..a9ba75d5a 100644 --- a/common/retry/retry.go +++ b/common/retry/retry.go @@ -11,7 +11,7 @@ var ( // Strategy is a way to retry on a specific function. type Strategy interface { - // On performs a retry on a specific function, until it doesn't return any error. + // On performs a retry on a specific function, until it doesn't return any error. On(func() error) error } diff --git a/proxy/socks/protocol/udp.go b/proxy/socks/protocol/udp.go index 7204035bb..6f331cfd1 100644 --- a/proxy/socks/protocol/udp.go +++ b/proxy/socks/protocol/udp.go @@ -34,7 +34,7 @@ func (request *Socks5UDPRequest) Write(buffer *alloc.Buffer) { case request.Address.IsDomain(): buffer.AppendBytes(AddrTypeDomain, byte(len(request.Address.Domain()))).Append([]byte(request.Address.Domain())) } - buffer.Append(request.Address.PortBytes()) + buffer.Append(request.Address.Port().Bytes()) buffer.Append(request.Data.Value) } diff --git a/proxy/socks/socks.go b/proxy/socks/socks.go index 29d54f333..202b1fe78 100644 --- a/proxy/socks/socks.go +++ b/proxy/socks/socks.go @@ -193,7 +193,7 @@ func (this *SocksServer) handleUDP(reader *v2net.TimeOutReader, writer io.Writer udpAddr := this.getUDPAddr() - response.Port = udpAddr.Port() + response.Port = udpAddr.Port().Value() switch { case udpAddr.IsIPv4(): response.SetIPv4(udpAddr.IP()) diff --git a/proxy/vmess/protocol/vmess.go b/proxy/vmess/protocol/vmess.go index 7997b3c82..7e30ca657 100644 --- a/proxy/vmess/protocol/vmess.go +++ b/proxy/vmess/protocol/vmess.go @@ -172,7 +172,7 @@ func (this *VMessRequest) ToBytes(idHash user.CounterHash, randomRangeInt64 user buffer.Append(this.RequestKey) buffer.Append(this.ResponseHeader) buffer.AppendBytes(this.Command) - buffer.Append(this.Address.PortBytes()) + buffer.Append(this.Address.Port().Bytes()) switch { case this.Address.IsIPv4(): diff --git a/shell/point/config/json/json_test.go b/shell/point/config/json/json_test.go index e1d830583..297c2387a 100644 --- a/shell/point/config/json/json_test.go +++ b/shell/point/config/json/json_test.go @@ -67,7 +67,7 @@ func TestDetourConfig(t *testing.T) { detour := detours[0] assert.String(detour.Protocol()).Equals("dokodemo-door") - assert.Uint16(detour.PortRange().From()).Equals(uint16(28394)) - assert.Uint16(detour.PortRange().To()).Equals(uint16(28394)) + assert.Uint16(detour.PortRange().From().Value()).Equals(uint16(28394)) + assert.Uint16(detour.PortRange().To().Value()).Equals(uint16(28394)) assert.Pointer(detour.Settings()).IsNotNil() } diff --git a/shell/point/config/testing/mocks/config.go b/shell/point/config/testing/mocks/config.go index ddb64d5a4..7246b7bfb 100644 --- a/shell/point/config/testing/mocks/config.go +++ b/shell/point/config/testing/mocks/config.go @@ -24,15 +24,15 @@ type LogConfig struct { } type PortRange struct { - FromValue uint16 - ToValue uint16 + FromValue v2net.Port + ToValue v2net.Port } -func (this *PortRange) From() uint16 { +func (this *PortRange) From() v2net.Port { return this.FromValue } -func (this *PortRange) To() uint16 { +func (this *PortRange) To() v2net.Port { return this.ToValue } diff --git a/shell/point/inbound_detour.go b/shell/point/inbound_detour.go index 5639618a2..3aa7d41b6 100644 --- a/shell/point/inbound_detour.go +++ b/shell/point/inbound_detour.go @@ -2,13 +2,14 @@ package point import ( "github.com/v2ray/v2ray-core/common/log" + v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/common/retry" "github.com/v2ray/v2ray-core/proxy/common/connhandler" "github.com/v2ray/v2ray-core/shell/point/config" ) type InboundConnectionHandlerWithPort struct { - port uint16 + port v2net.Port handler connhandler.InboundConnectionHandler } @@ -45,7 +46,7 @@ func (this *InboundDetourHandler) Initialize() error { func (this *InboundDetourHandler) Start() error { for _, ich := range this.ich { return retry.Timed(100 /* times */, 100 /* ms */).On(func() error { - err := ich.handler.Listen(ich.port) + err := ich.handler.Listen(ich.port.Value()) if err != nil { return err } diff --git a/testing/scenarios/socks5_helper.go b/testing/scenarios/socks5_helper.go index 4e99117f9..7b09ece6b 100644 --- a/testing/scenarios/socks5_helper.go +++ b/testing/scenarios/socks5_helper.go @@ -40,7 +40,7 @@ func appendAddress(request []byte, address v2net.Address) []byte { request = append(request, []byte(address.Domain())...) } - request = append(request, address.PortBytes()...) + request = append(request, address.Port().Bytes()...) return request }