| 
									
										
										
										
											2016-02-27 00:05:53 +01:00
										 |  |  | package raw | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"errors" | 
					
						
							|  |  |  | 	"io" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-02-27 16:41:21 +01:00
										 |  |  | 	"github.com/v2ray/v2ray-core/common/alloc" | 
					
						
							| 
									
										
										
										
											2016-02-27 00:05:53 +01:00
										 |  |  | 	v2net "github.com/v2ray/v2ray-core/common/net" | 
					
						
							|  |  |  | 	"github.com/v2ray/v2ray-core/common/protocol" | 
					
						
							|  |  |  | 	"github.com/v2ray/v2ray-core/common/serial" | 
					
						
							|  |  |  | 	"github.com/v2ray/v2ray-core/common/uuid" | 
					
						
							|  |  |  | 	"github.com/v2ray/v2ray-core/transport" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
// Sentinel errors returned by the command marshaling routines in this file.
var (
	// ErrorCommandTypeMismatch is returned when a command factory is handed a
	// command whose concrete type it does not handle.
	ErrorCommandTypeMismatch = errors.New("Command type mismatch.")
	// ErrorUnknownCommand is returned when a nil command, an unsupported
	// command type, or an unrecognized command id is encountered.
	ErrorUnknownCommand      = errors.New("Unknown command.")
	// ErrorCommandTooLarge is returned when an encoded command does not fit
	// the one-byte length field of the frame (more than 255 bytes).
	ErrorCommandTooLarge     = errors.New("Command too large.")
)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func MarshalCommand(command interface{}, writer io.Writer) error { | 
					
						
							| 
									
										
										
										
											2016-02-27 17:28:21 +01:00
										 |  |  | 	if command == nil { | 
					
						
							|  |  |  | 		return ErrorUnknownCommand | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-02-27 16:41:21 +01:00
										 |  |  | 	var cmdId byte | 
					
						
							| 
									
										
										
										
											2016-02-27 00:05:53 +01:00
										 |  |  | 	var factory CommandFactory | 
					
						
							|  |  |  | 	switch command.(type) { | 
					
						
							|  |  |  | 	case *protocol.CommandSwitchAccount: | 
					
						
							|  |  |  | 		factory = new(CommandSwitchAccountFactory) | 
					
						
							| 
									
										
										
										
											2016-02-27 16:41:21 +01:00
										 |  |  | 		cmdId = 1 | 
					
						
							| 
									
										
										
										
											2016-02-27 00:05:53 +01:00
										 |  |  | 	default: | 
					
						
							|  |  |  | 		return ErrorUnknownCommand | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2016-02-27 16:41:21 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-02-27 17:28:21 +01:00
										 |  |  | 	buffer := alloc.NewSmallBuffer().Clear() | 
					
						
							| 
									
										
										
										
											2016-02-27 16:41:21 +01:00
										 |  |  | 	err := factory.Marshal(command, buffer) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	auth := Authenticate(buffer.Value) | 
					
						
							|  |  |  | 	len := buffer.Len() + 4 | 
					
						
							|  |  |  | 	if len > 255 { | 
					
						
							|  |  |  | 		return ErrorCommandTooLarge | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	writer.Write([]byte{cmdId, byte(len), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)}) | 
					
						
							|  |  |  | 	writer.Write(buffer.Value) | 
					
						
							|  |  |  | 	return nil | 
					
						
							| 
									
										
										
										
											2016-02-27 00:05:53 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func UnmarshalCommand(cmdId byte, data []byte) (protocol.ResponseCommand, error) { | 
					
						
							| 
									
										
										
										
											2016-02-27 16:41:21 +01:00
										 |  |  | 	if len(data) <= 4 { | 
					
						
							|  |  |  | 		return nil, transport.ErrorCorruptedPacket | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	expectedAuth := Authenticate(data[4:]) | 
					
						
							|  |  |  | 	actualAuth := serial.BytesLiteral(data[:4]).Uint32Value() | 
					
						
							|  |  |  | 	if expectedAuth != actualAuth { | 
					
						
							|  |  |  | 		return nil, transport.ErrorCorruptedPacket | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-02-27 00:05:53 +01:00
										 |  |  | 	var factory CommandFactory | 
					
						
							|  |  |  | 	switch cmdId { | 
					
						
							|  |  |  | 	case 1: | 
					
						
							|  |  |  | 		factory = new(CommandSwitchAccountFactory) | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return nil, ErrorUnknownCommand | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2016-02-27 16:41:21 +01:00
										 |  |  | 	return factory.Unmarshal(data[4:]) | 
					
						
							| 
									
										
										
										
											2016-02-27 00:05:53 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
// CommandFactory converts one kind of command to and from its payload bytes.
type CommandFactory interface {
	// Marshal writes the payload encoding of command to writer.
	Marshal(command interface{}, writer io.Writer) error
	// Unmarshal decodes a command from its payload bytes.
	Unmarshal(data []byte) (interface{}, error)
}
					
						
							|  |  |  | 
 | 
					
						
// CommandSwitchAccountFactory is the CommandFactory for
// *protocol.CommandSwitchAccount (command id 1). It carries no state.
type CommandSwitchAccountFactory struct {
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (this *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error { | 
					
						
							|  |  |  | 	cmd, ok := command.(*protocol.CommandSwitchAccount) | 
					
						
							|  |  |  | 	if !ok { | 
					
						
							|  |  |  | 		return ErrorCommandTypeMismatch | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	hostStr := "" | 
					
						
							|  |  |  | 	if cmd.Host != nil { | 
					
						
							|  |  |  | 		hostStr = cmd.Host.String() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	writer.Write([]byte{byte(len(hostStr))}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if len(hostStr) > 0 { | 
					
						
							|  |  |  | 		writer.Write([]byte(hostStr)) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	writer.Write(cmd.Port.Bytes()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	idBytes := cmd.ID.Bytes() | 
					
						
							|  |  |  | 	writer.Write(idBytes) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	writer.Write(cmd.AlterIds.Bytes()) | 
					
						
							|  |  |  | 	writer.Write([]byte{byte(cmd.Level)}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	writer.Write([]byte{cmd.ValidMin}) | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (this *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) { | 
					
						
							|  |  |  | 	cmd := new(protocol.CommandSwitchAccount) | 
					
						
							|  |  |  | 	if len(data) == 0 { | 
					
						
							|  |  |  | 		return nil, transport.ErrorCorruptedPacket | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	lenHost := int(data[0]) | 
					
						
							|  |  |  | 	if len(data) < lenHost+1 { | 
					
						
							|  |  |  | 		return nil, transport.ErrorCorruptedPacket | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if lenHost > 0 { | 
					
						
							|  |  |  | 		cmd.Host = v2net.ParseAddress(string(data[1 : 1+lenHost])) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	portStart := 1 + lenHost | 
					
						
							|  |  |  | 	if len(data) < portStart+2 { | 
					
						
							|  |  |  | 		return nil, transport.ErrorCorruptedPacket | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	cmd.Port = v2net.PortFromBytes(data[portStart : portStart+2]) | 
					
						
							|  |  |  | 	idStart := portStart + 2 | 
					
						
							|  |  |  | 	if len(data) < idStart+16 { | 
					
						
							|  |  |  | 		return nil, transport.ErrorCorruptedPacket | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	cmd.ID, _ = uuid.ParseBytes(data[idStart : idStart+16]) | 
					
						
							|  |  |  | 	alterIdStart := idStart + 16 | 
					
						
							|  |  |  | 	if len(data) < alterIdStart+2 { | 
					
						
							|  |  |  | 		return nil, transport.ErrorCorruptedPacket | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	cmd.AlterIds = serial.BytesLiteral(data[alterIdStart : alterIdStart+2]).Uint16() | 
					
						
							|  |  |  | 	levelStart := alterIdStart + 2 | 
					
						
							|  |  |  | 	if len(data) < levelStart+1 { | 
					
						
							|  |  |  | 		return nil, transport.ErrorCorruptedPacket | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	cmd.Level = protocol.UserLevel(data[levelStart]) | 
					
						
							|  |  |  | 	timeStart := levelStart + 1 | 
					
						
							|  |  |  | 	if len(data) < timeStart { | 
					
						
							|  |  |  | 		return nil, transport.ErrorCorruptedPacket | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	cmd.ValidMin = data[timeStart] | 
					
						
							|  |  |  | 	return cmd, nil | 
					
						
							|  |  |  | } |