diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index 35fcc94f9..bb58787c0 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -77,16 +77,9 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er var iConn internet.Connection = session - if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { - switch securitySettings := securitySettings.(type) { - case *v2tls.Config: - if dest.Address.Family().IsDomain() { - securitySettings.OverrideServerNameIfEmpty(dest.Address.Domain()) - } - config := securitySettings.GetTLSConfig() - tlsConn := tls.Client(iConn, config) - iConn = tlsConn - } + if config := v2tls.ConfigFromContext(ctx, v2tls.WithDestination(dest)); config != nil { + tlsConn := tls.Client(iConn, config.GetTLSConfig()) + iConn = tlsConn } return iConn, nil diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index aabe680d1..99ec2f127 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -59,13 +59,11 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon config: kcpSettings, addConn: addConn, } - securitySettings := internet.SecuritySettingsFromContext(ctx) - if securitySettings != nil { - switch securitySettings := securitySettings.(type) { - case *v2tls.Config: - l.tlsConfig = securitySettings.GetTLSConfig() - } + + if config := v2tls.ConfigFromContext(ctx); config != nil { + l.tlsConfig = config.GetTLSConfig() } + hub, err := udp.ListenUDP(address, port, udp.ListenOption{Callback: l.OnReceive, Concurrency: 2}) if err != nil { return nil, err diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 37921d9d3..bbad0184c 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -19,22 +19,16 @@ func getTCPSettingsFromContext(ctx context.Context) *Config { } func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { - log.Trace(newError("dailing TCP to ", dest)) + log.Trace(newError("dialing TCP to ", dest)) src := internet.DialerSourceFromContext(ctx) conn, err := internet.DialSystem(ctx, src, dest) if err != nil { return nil, err } - if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { - tlsConfig, ok := securitySettings.(*tls.Config) - if ok { - if dest.Address.Family().IsDomain() { - tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain()) - } - config := tlsConfig.GetTLSConfig() - conn = tls.Client(conn, config) - } + + if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest)); config != nil { + conn = tls.Client(conn, config.GetTLSConfig()) } tcpSettings := getTCPSettingsFromContext(ctx) diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 22a6cfb20..8015025d1 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -37,12 +37,11 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, addConn config: tcpSettings, addConn: addConn, } - if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { - tlsConfig, ok := securitySettings.(*tls.Config) - if ok { - l.tlsConfig = tlsConfig.GetTLSConfig() - } + + if config := tls.ConfigFromContext(ctx); config != nil { + l.tlsConfig = config.GetTLSConfig() } + if tcpSettings.HeaderSettings != nil { headerConfig, err := tcpSettings.HeaderSettings.GetInstance() if err != nil { diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 919a91a37..8d9831c07 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -1,9 +1,12 @@ package tls import ( + "context" "crypto/tls" "v2ray.com/core/app/log" + "v2ray.com/core/common/net" + "v2ray.com/core/transport/internet" ) var ( @@ -42,8 +45,26 @@ func (c *Config) GetTLSConfig() *tls.Config { return config } -func (c *Config) OverrideServerNameIfEmpty(serverName string) { - if len(c.ServerName) == 0 { - c.ServerName = serverName +type Option func(*Config) + +func WithDestination(dest net.Destination) Option { + return func(config *Config) { + if dest.Address.Family().IsDomain() && len(config.ServerName) == 0 { + config.ServerName = dest.Address.Domain() + } } } + +func ConfigFromContext(ctx context.Context, opts ...Option) *Config { + securitySettings := internet.SecuritySettingsFromContext(ctx) + if securitySettings == nil { + return nil + } + if config, ok := securitySettings.(*Config); ok { + for _, opt := range opts { + opt(config) + } + return config + } + return nil +} diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 403600638..868e52857 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -42,15 +42,9 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error) protocol := "ws" - if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { - tlsConfig, ok := securitySettings.(*tls.Config) - if ok { - protocol = "wss" - if dest.Address.Family().IsDomain() { - tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain()) - } - dialer.TLSClientConfig = tlsConfig.GetTLSConfig() - } + if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest)); config != nil { + protocol = "wss" + dialer.TLSClientConfig = config.GetTLSConfig() } host := dest.NetAddr() diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index d88ac6583..0226be785 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -59,11 +59,8 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn i config: wsSettings, addConn: addConn, } - if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { - tlsConfig, ok := securitySettings.(*v2tls.Config) - if ok { - l.tlsConfig = tlsConfig.GetTLSConfig() - } + if config := v2tls.ConfigFromContext(ctx); config != nil { + l.tlsConfig = config.GetTLSConfig() } err := l.listenws(address, port)