mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-11-24 22:16:28 +00:00
360 lines
8.6 KiB
Go
360 lines
8.6 KiB
Go
|
|
package wireguard
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/base64"
|
||
|
|
"encoding/hex"
|
||
|
|
core "github.com/v2fly/v2ray-core/v4"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/common"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/common/buf"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/common/net"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/common/protocol"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/common/session"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/common/signal"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/common/signal/done"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/common/task"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/features/dns"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/features/policy"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/features/routing"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/proxy"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/transport"
|
||
|
|
"github.com/v2fly/v2ray-core/v4/transport/internet"
|
||
|
|
"golang.zx2c4.com/wireguard/conn"
|
||
|
|
"golang.zx2c4.com/wireguard/device"
|
||
|
|
"golang.zx2c4.com/wireguard/tun"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
)
|
||
|
|
|
||
|
|
func init() {
|
||
|
|
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||
|
|
o := new(Outbound)
|
||
|
|
err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher, policyManager policy.Manager, dnsClient dns.Client) error {
|
||
|
|
o.ctx = ctx
|
||
|
|
o.dispatcher = dispatcher
|
||
|
|
o.dnsClient = dnsClient
|
||
|
|
o.init = done.New()
|
||
|
|
return o.Init(config.(*Config), policyManager)
|
||
|
|
})
|
||
|
|
return o, err
|
||
|
|
}))
|
||
|
|
common.Must(common.RegisterConfig((*SimplifiedConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||
|
|
o := new(Outbound)
|
||
|
|
err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher, policyManager policy.Manager, dnsClient dns.Client) error {
|
||
|
|
sf := config.(*SimplifiedConfig)
|
||
|
|
cf := &Config{
|
||
|
|
Server: &protocol.ServerEndpoint{
|
||
|
|
Address: sf.Address,
|
||
|
|
Port: sf.Port,
|
||
|
|
},
|
||
|
|
Network: sf.Network,
|
||
|
|
PrivateKey: sf.PrivateKey,
|
||
|
|
PeerPublicKey: sf.PeerPublicKey,
|
||
|
|
PreSharedKey: sf.PreSharedKey,
|
||
|
|
Mtu: sf.Mtu,
|
||
|
|
UserLevel: sf.UserLevel,
|
||
|
|
}
|
||
|
|
o.ctx = ctx
|
||
|
|
o.dispatcher = dispatcher
|
||
|
|
o.dnsClient = dnsClient
|
||
|
|
o.init = done.New()
|
||
|
|
return o.Init(cf, policyManager)
|
||
|
|
})
|
||
|
|
return o, err
|
||
|
|
}))
|
||
|
|
}
|
||
|
|
|
||
|
|
var _ proxy.Outbound = (*Outbound)(nil)
|
||
|
|
var _ conn.Bind = (*Outbound)(nil)
|
||
|
|
|
||
|
|
type Outbound struct {
|
||
|
|
sync.Mutex
|
||
|
|
|
||
|
|
ctx context.Context
|
||
|
|
dispatcher routing.Dispatcher
|
||
|
|
sessionPolicy policy.Session
|
||
|
|
dnsClient dns.Client
|
||
|
|
|
||
|
|
tun tun.Device
|
||
|
|
dev *device.Device
|
||
|
|
wire *Net
|
||
|
|
dialer internet.Dialer
|
||
|
|
|
||
|
|
init *done.Instance
|
||
|
|
destination net.Destination
|
||
|
|
endpoint *conn.StdNetEndpoint
|
||
|
|
connection *remoteConnection
|
||
|
|
}
|
||
|
|
|
||
|
|
func (o *Outbound) Init(config *Config, policyManager policy.Manager) error {
|
||
|
|
o.sessionPolicy = policyManager.ForLevel(config.UserLevel)
|
||
|
|
spec, err := protocol.NewServerSpecFromPB(config.Server)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
o.destination = spec.Destination()
|
||
|
|
o.endpoint = &conn.StdNetEndpoint{
|
||
|
|
Port: int(o.destination.Port),
|
||
|
|
}
|
||
|
|
|
||
|
|
if o.destination.Address.Family().IsDomain() {
|
||
|
|
o.endpoint.IP = []byte{172, 19, 0, 3}
|
||
|
|
} else {
|
||
|
|
o.endpoint.IP = o.destination.Address.IP()
|
||
|
|
}
|
||
|
|
|
||
|
|
localAddress := make([]net.IP, len(config.LocalAddress))
|
||
|
|
if len(localAddress) == 0 {
|
||
|
|
return newError("empty local address")
|
||
|
|
}
|
||
|
|
for index, address := range config.LocalAddress {
|
||
|
|
localAddress[index] = net.ParseIP(address)
|
||
|
|
}
|
||
|
|
|
||
|
|
var privateKey, peerPublicKey, preSharedKey string
|
||
|
|
{
|
||
|
|
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(config.PrivateKey))
|
||
|
|
bytes, err := buf.ReadAllToBytes(decoder)
|
||
|
|
if err != nil {
|
||
|
|
return newError("failed to decode private key from base64: ", config.PrivateKey).Base(err)
|
||
|
|
}
|
||
|
|
privateKey = hex.EncodeToString(bytes)
|
||
|
|
}
|
||
|
|
{
|
||
|
|
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(config.PeerPublicKey))
|
||
|
|
bytes, err := buf.ReadAllToBytes(decoder)
|
||
|
|
if err != nil {
|
||
|
|
return newError("failed to decode peer public key from base64: ", config.PeerPublicKey).Base(err)
|
||
|
|
}
|
||
|
|
peerPublicKey = hex.EncodeToString(bytes)
|
||
|
|
}
|
||
|
|
if config.PreSharedKey != "" {
|
||
|
|
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(config.PreSharedKey))
|
||
|
|
bytes, err := buf.ReadAllToBytes(decoder)
|
||
|
|
if err != nil {
|
||
|
|
return newError("failed to decode pre share key from base64: ", config.PreSharedKey).Base(err)
|
||
|
|
}
|
||
|
|
preSharedKey = hex.EncodeToString(bytes)
|
||
|
|
}
|
||
|
|
ipcConf := "private_key=" + privateKey
|
||
|
|
ipcConf += "\npublic_key=" + peerPublicKey
|
||
|
|
ipcConf += "\nendpoint=" + o.endpoint.DstToString()
|
||
|
|
|
||
|
|
if preSharedKey != "" {
|
||
|
|
ipcConf += "\npreshared_key=" + preSharedKey
|
||
|
|
}
|
||
|
|
|
||
|
|
var has4, has6 bool
|
||
|
|
|
||
|
|
for _, address := range localAddress {
|
||
|
|
if address.To4() != nil {
|
||
|
|
has4 = true
|
||
|
|
} else {
|
||
|
|
has6 = true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if has4 {
|
||
|
|
ipcConf += "\nallowed_ip=0.0.0.0/0"
|
||
|
|
}
|
||
|
|
|
||
|
|
if has6 {
|
||
|
|
ipcConf += "\nallowed_ip=::/0"
|
||
|
|
}
|
||
|
|
|
||
|
|
mtu := int(config.Mtu)
|
||
|
|
if mtu == 0 {
|
||
|
|
mtu = 1450
|
||
|
|
}
|
||
|
|
tun, wire, err := CreateNetTUN(localAddress, mtu)
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
return newError("failed to create wireguard device").Base(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
dev := device.NewDevice(tun, o, device.NewLogger(device.LogLevelVerbose, ""))
|
||
|
|
err = dev.IpcSet(ipcConf)
|
||
|
|
if err != nil {
|
||
|
|
return newError("failed to set wireguard ipc conf").Base(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
o.tun = tun
|
||
|
|
o.dev = dev
|
||
|
|
o.wire = wire
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
||
|
|
if o.dialer == nil {
|
||
|
|
o.dialer = dialer
|
||
|
|
}
|
||
|
|
o.init.Close()
|
||
|
|
|
||
|
|
outbound := session.OutboundFromContext(ctx)
|
||
|
|
if outbound == nil || !outbound.Target.IsValid() {
|
||
|
|
return newError("target not specified")
|
||
|
|
}
|
||
|
|
destination := outbound.Target
|
||
|
|
|
||
|
|
if destination.Address.Family().IsDomain() {
|
||
|
|
if c, ok := o.dnsClient.(dns.ClientWithIPOption); ok {
|
||
|
|
c.SetFakeDNSOption(false)
|
||
|
|
}
|
||
|
|
ips, err := o.dnsClient.LookupIP(destination.Address.Domain())
|
||
|
|
if err != nil {
|
||
|
|
return newError("failed to lookup ip addresses for domain ", destination.Address.Domain()).Base(err)
|
||
|
|
}
|
||
|
|
destination.Address = net.IPAddress(ips[0])
|
||
|
|
}
|
||
|
|
|
||
|
|
var conn internet.Connection
|
||
|
|
{
|
||
|
|
var err error
|
||
|
|
|
||
|
|
switch destination.Network {
|
||
|
|
case net.Network_TCP:
|
||
|
|
conn, err = o.wire.DialContextTCP(ctx, &net.TCPAddr{
|
||
|
|
IP: destination.Address.IP(),
|
||
|
|
Port: int(destination.Port),
|
||
|
|
})
|
||
|
|
case net.Network_UDP:
|
||
|
|
conn, err = o.wire.DialUDP(nil, &net.UDPAddr{
|
||
|
|
IP: destination.Address.IP(),
|
||
|
|
Port: int(destination.Port),
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
defer conn.Close()
|
||
|
|
|
||
|
|
ctx, cancel := context.WithCancel(ctx)
|
||
|
|
timer := signal.CancelAfterInactivity(ctx, cancel, o.sessionPolicy.Timeouts.ConnectionIdle)
|
||
|
|
ctx = policy.ContextWithBufferPolicy(ctx, o.sessionPolicy.Buffer)
|
||
|
|
|
||
|
|
uplink := func() error {
|
||
|
|
defer timer.SetTimeout(o.sessionPolicy.Timeouts.UplinkOnly)
|
||
|
|
|
||
|
|
if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
|
||
|
|
return newError("failed to transport all TCP response").Base(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
downlink := func() error {
|
||
|
|
defer timer.SetTimeout(o.sessionPolicy.Timeouts.DownlinkOnly)
|
||
|
|
|
||
|
|
if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
|
||
|
|
return newError("failed to transport all TCP request").Base(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := task.Run(ctx, uplink, downlink); err != nil {
|
||
|
|
common.Interrupt(link.Reader)
|
||
|
|
common.Interrupt(link.Writer)
|
||
|
|
return newError("connection ends").Base(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
|
||
|
|
}
|
||
|
|
|
||
|
|
type remoteConnection struct {
|
||
|
|
internet.Connection
|
||
|
|
done *done.Instance
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r remoteConnection) Close() error {
|
||
|
|
if !r.done.Done() {
|
||
|
|
r.done.Close()
|
||
|
|
}
|
||
|
|
return r.Connection.Close()
|
||
|
|
}
|
||
|
|
|
||
|
|
func (o *Outbound) connect() (*remoteConnection, error) {
|
||
|
|
if o.dialer == nil {
|
||
|
|
<-o.init.Wait()
|
||
|
|
}
|
||
|
|
|
||
|
|
if c := o.connection; c != nil && !c.done.Done() {
|
||
|
|
return c, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
o.Lock()
|
||
|
|
defer o.Unlock()
|
||
|
|
|
||
|
|
if c := o.connection; c != nil && !c.done.Done() {
|
||
|
|
return c, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
conn, err := o.dialer.Dial(context.Background(), o.destination)
|
||
|
|
if err == nil {
|
||
|
|
o.connection = &remoteConnection{
|
||
|
|
conn,
|
||
|
|
done.New(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return o.connection, err
|
||
|
|
}
|
||
|
|
|
||
|
|
func (o *Outbound) Open(uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||
|
|
return []conn.ReceiveFunc{o.Receive}, 0, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (o *Outbound) Receive(b []byte) (n int, ep conn.Endpoint, err error) {
|
||
|
|
var c *remoteConnection
|
||
|
|
c, err = o.connect()
|
||
|
|
if err != nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
n, err = c.Read(b)
|
||
|
|
if err != nil {
|
||
|
|
common.Close(c)
|
||
|
|
} else {
|
||
|
|
ep = o.endpoint
|
||
|
|
}
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
func (o *Outbound) Close() error {
|
||
|
|
o.Lock()
|
||
|
|
defer o.Unlock()
|
||
|
|
|
||
|
|
c := o.connection
|
||
|
|
if c != nil {
|
||
|
|
common.Close(c)
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (o *Outbound) SetMark(uint32) error {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (o *Outbound) Send(b []byte, _ conn.Endpoint) (err error) {
|
||
|
|
var c *remoteConnection
|
||
|
|
c, err = o.connect()
|
||
|
|
if err != nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
_, err = c.Write(b)
|
||
|
|
if err != nil {
|
||
|
|
common.Close(c)
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func (o *Outbound) ParseEndpoint(string) (conn.Endpoint, error) {
|
||
|
|
return o.endpoint, nil
|
||
|
|
}
|