From 3b545abe02df106f7ccd29e65e0773c92cd1c730 Mon Sep 17 00:00:00 2001 From: v2ray Date: Sun, 15 May 2016 23:09:28 -0700 Subject: [PATCH] dns client implementation --- app/dns/dns.go | 24 +++-- app/dns/internal/config.go | 11 +-- app/dns/internal/config_json.go | 20 ++-- app/dns/internal/dns.go | 97 +++++++++++--------- app/dns/internal/dns_test.go | 54 ++++++++--- app/dns/internal/nameserver.go | 158 ++++++++++++++++++++++++++++++++ common/net/address.go | 4 + proxy/dokodemo/dokodemo.go | 20 ++-- 8 files changed, 292 insertions(+), 96 deletions(-) create mode 100644 app/dns/internal/nameserver.go diff --git a/app/dns/dns.go b/app/dns/dns.go index 96c34efc2..b80bf4d3b 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -11,33 +11,31 @@ const ( ) // A DnsCache is an internal cache of DNS resolutions. -type DnsCache interface { - Get(domain string) net.IP - Add(domain string, ip net.IP) +type Server interface { + Get(domain string) []net.IP } -type dnsCacheWithContext interface { - Get(context app.Context, domain string) net.IP - Add(contaxt app.Context, domain string, ip net.IP) +type dnsServerWithContext interface { + Get(context app.Context, domain string) []net.IP } -type contextedDnsCache struct { +type contextedDnsServer struct { context app.Context - dnsCache dnsCacheWithContext + dnsCache dnsServerWithContext } -func (this *contextedDnsCache) Get(domain string) net.IP { +func (this *contextedDnsServer) Get(domain string) []net.IP { return this.dnsCache.Get(this.context, domain) } -func (this *contextedDnsCache) Add(domain string, ip net.IP) { - this.dnsCache.Add(this.context, domain, ip) +func CreateDNSServer(rawConfig interface{}) (Server, error) { + return nil, nil } func init() { app.Register(APP_ID, func(context app.Context, obj interface{}) interface{} { - dcContext := obj.(dnsCacheWithContext) - return &contextedDnsCache{ + dcContext := obj.(dnsServerWithContext) + return &contextedDnsServer{ context: context, dnsCache: dcContext, } diff --git a/app/dns/internal/config.go b/app/dns/internal/config.go index 0b28aeb69..7d18323ec 100644 --- a/app/dns/internal/config.go +++ b/app/dns/internal/config.go @@ -1,14 +1,9 @@ package internal import ( - "github.com/v2ray/v2ray-core/common/serial" + v2net "github.com/v2ray/v2ray-core/common/net" ) -type CacheConfig struct { - TrustedTags map[serial.StringLiteral]bool -} - -func (this *CacheConfig) IsTrustedSource(tag serial.StringLiteral) bool { - _, found := this.TrustedTags[tag] - return found +type Config struct { + NameServers []v2net.Destination } diff --git a/app/dns/internal/config_json.go b/app/dns/internal/config_json.go index 6b4462c02..351cb0ecb 100644 --- a/app/dns/internal/config_json.go +++ b/app/dns/internal/config_json.go @@ -5,19 +5,21 @@ package internal import ( "encoding/json" - "github.com/v2ray/v2ray-core/common/serial" + v2net "github.com/v2ray/v2ray-core/common/net" ) -func (this *CacheConfig) UnmarshalJSON(data []byte) error { - var strlist serial.StringLiteralList - if err := json.Unmarshal(data, strlist); err != nil { +func (this *Config) UnmarshalJSON(data []byte) error { + type JsonConfig struct { + Servers []v2net.Address `json:"servers"` + } + jsonConfig := new(JsonConfig) + if err := json.Unmarshal(data, jsonConfig); err != nil { return err } - config := &CacheConfig{ - TrustedTags: make(map[serial.StringLiteral]bool, strlist.Len()), - } - for _, str := range strlist { - config.TrustedTags[str.TrimSpace()] = true + this.NameServers = make([]v2net.Destination, len(jsonConfig.Servers)) + for idx, server := range jsonConfig.Servers { + this.NameServers[idx] = v2net.UDPDestination(server, v2net.Port(53)) } + return nil } diff --git a/app/dns/internal/dns.go b/app/dns/internal/dns.go index 0445360fe..817ed11e1 100644 --- a/app/dns/internal/dns.go +++ b/app/dns/internal/dns.go @@ -2,61 +2,72 @@ package internal import ( "net" + "sync" "time" "github.com/v2ray/v2ray-core/app" - "github.com/v2ray/v2ray-core/common/collect" - "github.com/v2ray/v2ray-core/common/serial" + "github.com/v2ray/v2ray-core/app/dispatcher" + + "github.com/miekg/dns" ) -type entry struct { - domain string - ip net.IP - validUntil time.Time +const ( + QueryTimeout = time.Second * 2 +) + +type DomainRecord struct { + A *ARecord } -func newEntry(domain string, ip net.IP) *entry { - this := &entry{ - domain: domain, - ip: ip, +type Server struct { + sync.RWMutex + records map[string]*DomainRecord + servers []NameServer +} + +func NewServer(space app.Space, config *Config) *Server { + server := &Server{ + records: make(map[string]*DomainRecord), + servers: make([]NameServer, len(config.NameServers)), } - this.Extend() - return this -} - -func (this *entry) IsValid() bool { - return this.validUntil.After(time.Now()) -} - -func (this *entry) Extend() { - this.validUntil = time.Now().Add(time.Hour) -} - -type DnsCache struct { - cache *collect.ValidityMap - config *CacheConfig -} - -func NewCache(config *CacheConfig) *DnsCache { - cache := &DnsCache{ - cache: collect.NewValidityMap(3600), - config: config, + dispatcher := space.GetApp(dispatcher.APP_ID).(dispatcher.PacketDispatcher) + for idx, ns := range config.NameServers { + server.servers[idx] = NewUDPNameServer(ns, dispatcher) } - return cache + return server } -func (this *DnsCache) Add(context app.Context, domain string, ip net.IP) { - callerTag := context.CallerTag() - if !this.config.IsTrustedSource(serial.StringLiteral(callerTag)) { - return - } +//@Private +func (this *Server) GetCached(domain string) []net.IP { + this.RLock() + defer this.RUnlock() - this.cache.Set(serial.StringLiteral(domain), newEntry(domain, ip)) -} - -func (this *DnsCache) Get(context app.Context, domain string) net.IP { - if value := this.cache.Get(serial.StringLiteral(domain)); value != nil { - return value.(*entry).ip + if record, found := this.records[domain]; found && record.A.Expire.After(time.Now()) { + return record.A.IPs } return nil } + +func (this *Server) Get(context app.Context, domain string) []net.IP { + domain = dns.Fqdn(domain) + ips := this.GetCached(domain) + if ips != nil { + return ips + } + + for _, server := range this.servers { + response := server.QueryA(domain) + select { + case a := <-response: + this.Lock() + this.records[domain] = &DomainRecord{ + A: a, + } + this.Unlock() + return a.IPs + case <-time.Tick(QueryTimeout): + } + } + + return nil +} diff --git a/app/dns/internal/dns_test.go b/app/dns/internal/dns_test.go index 09d347c7d..111ec1a14 100644 --- a/app/dns/internal/dns_test.go +++ b/app/dns/internal/dns_test.go @@ -4,30 +4,56 @@ import ( "net" "testing" + "github.com/v2ray/v2ray-core/app" + "github.com/v2ray/v2ray-core/app/dispatcher" . "github.com/v2ray/v2ray-core/app/dns/internal" apptesting "github.com/v2ray/v2ray-core/app/testing" + v2net "github.com/v2ray/v2ray-core/common/net" netassert "github.com/v2ray/v2ray-core/common/net/testing/assert" - "github.com/v2ray/v2ray-core/common/serial" + "github.com/v2ray/v2ray-core/proxy/freedom" v2testing "github.com/v2ray/v2ray-core/testing" + "github.com/v2ray/v2ray-core/testing/assert" + "github.com/v2ray/v2ray-core/transport/ray" ) +type TestDispatcher struct { + freedom *freedom.FreedomConnection +} + +func (this *TestDispatcher) DispatchToOutbound(context app.Context, dest v2net.Destination) ray.InboundRay { + direct := ray.NewRay() + + go func() { + payload, err := direct.OutboundInput().Read() + if err != nil { + direct.OutboundInput().Release() + direct.OutboundOutput().Release() + return + } + this.freedom.Dispatch(dest, payload, direct) + }() + return direct +} + func TestDnsAdd(t *testing.T) { v2testing.Current(t) + d := &TestDispatcher{ + freedom: &freedom.FreedomConnection{}, + } + spaceController := app.NewController() + spaceController.Bind(dispatcher.APP_ID, d) + space := spaceController.ForContext("test") + domain := "v2ray.com" - cache := NewCache(&CacheConfig{ - TrustedTags: map[serial.StringLiteral]bool{ - serial.StringLiteral("testtag"): true, + server := NewServer(space, &Config{ + NameServers: []v2net.Destination{ + v2net.UDPDestination(v2net.IPAddress([]byte{8, 8, 8, 8}), v2net.Port(53)), }, }) - ip := cache.Get(&apptesting.Context{}, domain) - netassert.IP(ip).IsNil() - - cache.Add(&apptesting.Context{CallerTagValue: "notvalidtag"}, domain, []byte{1, 2, 3, 4}) - ip = cache.Get(&apptesting.Context{}, domain) - netassert.IP(ip).IsNil() - - cache.Add(&apptesting.Context{CallerTagValue: "testtag"}, domain, []byte{1, 2, 3, 4}) - ip = cache.Get(&apptesting.Context{}, domain) - netassert.IP(ip).Equals(net.IP([]byte{1, 2, 3, 4})) + ips := server.Get(&apptesting.Context{ + CallerTagValue: "a", + }, domain) + assert.Int(len(ips)).Equals(2) + netassert.IP(ips[0].To4()).Equals(net.IP([]byte{104, 27, 154, 107})) } diff --git a/app/dns/internal/nameserver.go b/app/dns/internal/nameserver.go new file mode 100644 index 000000000..a87826696 --- /dev/null +++ b/app/dns/internal/nameserver.go @@ -0,0 +1,158 @@ +package internal + +import ( + "math/rand" + "net" + "sync" + "time" + + "github.com/v2ray/v2ray-core/app/dispatcher" + "github.com/v2ray/v2ray-core/common/alloc" + "github.com/v2ray/v2ray-core/common/log" + v2net "github.com/v2ray/v2ray-core/common/net" + "github.com/v2ray/v2ray-core/transport/hub" + + "github.com/miekg/dns" +) + +const ( + DefaultTTL = uint32(3600) +) + +type ARecord struct { + IPs []net.IP + Expire time.Time +} + +type NameServer interface { + QueryA(domain string) <-chan *ARecord +} + +type PendingRequest struct { + expire time.Time + response chan<- *ARecord +} + +type UDPNameServer struct { + sync.Mutex + address v2net.Destination + requests map[uint16]*PendingRequest + udpServer *hub.UDPServer +} + +func NewUDPNameServer(address v2net.Destination, dispatcher dispatcher.PacketDispatcher) *UDPNameServer { + s := &UDPNameServer{ + address: address, + requests: make(map[uint16]*PendingRequest), + udpServer: hub.NewUDPServer(dispatcher), + } + go s.Cleanup() + return s +} + +// @Private +func (this *UDPNameServer) Cleanup() { + for { + time.Sleep(time.Second * 60) + expiredRequests := make([]uint16, 0, 16) + now := time.Now() + this.Lock() + for id, r := range this.requests { + if r.expire.Before(now) { + expiredRequests = append(expiredRequests, id) + close(r.response) + } + } + for _, id := range expiredRequests { + delete(this.requests, id) + } + this.Unlock() + expiredRequests = nil + } +} + +// @Private +func (this *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 { + var id uint16 + this.Lock() + for { + id = uint16(rand.Intn(65536)) + if _, found := this.requests[id]; found { + continue + } + log.Debug("DNS: Add pending request id ", id) + this.requests[id] = &PendingRequest{ + expire: time.Now().Add(time.Second * 16), + response: response, + } + break + } + this.Unlock() + return id +} + +// @Private +func (this *UDPNameServer) HandleResponse(dest v2net.Destination, payload *alloc.Buffer) { + msg := new(dns.Msg) + err := msg.Unpack(payload.Value) + if err != nil { + log.Warning("DNS: Failed to parse DNS response: ", err) + return + } + record := &ARecord{ + IPs: make([]net.IP, 0, 16), + } + id := msg.Id + ttl := DefaultTTL + + this.Lock() + request, found := this.requests[id] + if !found { + this.Unlock() + return + } + delete(this.requests, id) + this.Unlock() + + for _, rr := range msg.Answer { + if a, ok := rr.(*dns.A); ok { + record.IPs = append(record.IPs, a.A) + if a.Hdr.Ttl < ttl { + ttl = a.Hdr.Ttl + } + } + } + record.Expire = time.Now().Add(time.Second * time.Duration(ttl)) + + request.response <- record + close(request.response) +} + +func (this *UDPNameServer) QueryA(domain string) <-chan *ARecord { + response := make(chan *ARecord) + + buffer := alloc.NewBuffer() + msg := new(dns.Msg) + msg.Id = this.AssignUnusedID(response) + msg.RecursionDesired = true + msg.Question = []dns.Question{ + dns.Question{ + Name: dns.Fqdn(domain), + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + dns.Question{ + Name: dns.Fqdn(domain), + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + }, + } + + writtenBuffer, _ := msg.PackBuffer(buffer.Value) + buffer.Slice(0, len(writtenBuffer)) + + fakeDestination := v2net.UDPDestination(v2net.LocalHostIP, v2net.Port(53)) + this.udpServer.Dispatch(fakeDestination, this.address, buffer, this.HandleResponse) + + return response +} diff --git a/common/net/address.go b/common/net/address.go index c886df90d..bd3f132fb 100644 --- a/common/net/address.go +++ b/common/net/address.go @@ -7,6 +7,10 @@ import ( "github.com/v2ray/v2ray-core/common/serial" ) +var ( + LocalHostIP = IPAddress([]byte{127, 0, 0, 1}) +) + // Address represents a network address to be communicated with. It may be an IP address or domain // address, not both. This interface doesn't resolve IP address for a given domain. type Address interface { diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index c1b2a0e5d..7898fe997 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -95,15 +95,17 @@ func (this *DokodemoDoor) ListenUDP(port v2net.Port) error { } func (this *DokodemoDoor) handleUDPPackets(payload *alloc.Buffer, dest v2net.Destination) { - this.udpServer.Dispatch(dest, v2net.UDPDestination(this.address, this.port), payload, func(destination v2net.Destination, payload *alloc.Buffer) { - defer payload.Release() - this.udpMutex.RLock() - defer this.udpMutex.RUnlock() - if !this.accepting { - return - } - this.udpHub.WriteTo(payload.Value, destination) - }) + this.udpServer.Dispatch(dest, v2net.UDPDestination(this.address, this.port), payload, this.handleUDPResponse) +} + +func (this *DokodemoDoor) handleUDPResponse(dest v2net.Destination, payload *alloc.Buffer) { + defer payload.Release() + this.udpMutex.RLock() + defer this.udpMutex.RUnlock() + if !this.accepting { + return + } + this.udpHub.WriteTo(payload.Value, dest) } func (this *DokodemoDoor) ListenTCP(port v2net.Port) error {