// route/vendor/github.com/GoRethink/gorethink/connection_handshake.go
//
// 451 lines
// 11 KiB
// Go

package gorethink
import (
"bufio"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"hash"
"io"
"strconv"
"strings"
"golang.org/x/crypto/pbkdf2"
p "gopkg.in/gorethink/gorethink.v2/ql2"
)
// HandshakeVersion selects which handshake protocol a connection uses when it
// is opened.
type HandshakeVersion int

const (
	// HandshakeV1_0 selects the V1_0 handshake (SCRAM-SHA-256 authentication).
	// It is the zero value of HandshakeVersion.
	HandshakeV1_0 HandshakeVersion = 0
	// HandshakeV0_4 selects the V0_4 handshake (auth key authentication).
	HandshakeV0_4 HandshakeVersion = 1
)
// connectionHandshake is the behaviour shared by every supported handshake
// protocol version.
type connectionHandshake interface {
// Send runs the handshake exchange with the server and returns a non-nil
// error if any step of it fails.
Send() error
}
// handshake returns the handshake implementation for the requested protocol
// version, bound to this connection. An error is returned when the version is
// not one of the known HandshakeVersion values.
func (c *Connection) handshake(version HandshakeVersion) (connectionHandshake, error) {
	if version == HandshakeV1_0 {
		return &connectionHandshakeV1_0{conn: c}, nil
	}
	if version == HandshakeV0_4 {
		return &connectionHandshakeV0_4{conn: c}, nil
	}
	return nil, fmt.Errorf("Unrecognised handshake version")
}
// connectionHandshakeV0_4 implements the V0_4 handshake: it sends the protocol
// version, the auth key and the wire protocol type, then waits for the server
// to answer with "SUCCESS".
type connectionHandshakeV0_4 struct {
// conn is the connection the handshake is performed on.
conn *Connection
}
// Send performs the V0_4 handshake: it writes the handshake request and then
// reads the server's reply. If either step fails the connection is closed and
// the failure is reported as an RQLConnectionError carrying the original
// error's message.
func (c *connectionHandshakeV0_4) Send() error {
	// Run the request first; only read the reply if the request was written.
	err := c.writeHandshakeReq()
	if err == nil {
		err = c.readHandshakeSuccess()
	}
	if err == nil {
		return nil
	}

	// Either step failed: the connection is unusable, so release it.
	c.conn.Close()
	return RQLConnectionError{rqlError(err.Error())}
}
// writeHandshakeReq builds and sends the V0_4 handshake request. The request
// is laid out as:
//
//	bytes 0-3      protocol version, little-endian uint32
//	bytes 4-7      auth key length, little-endian uint32
//	bytes 8-8+n    auth key (n bytes, may be empty)
//	last 4 bytes   wire protocol type (JSON), little-endian uint32
func (c *connectionHandshakeV0_4) writeHandshakeReq() error {
	key := c.conn.opts.AuthKey
	keyLen := len(key)
	le := binary.LittleEndian

	const headerLen = 4 + 4 // version + key length
	buf := make([]byte, headerLen+keyLen+4)

	le.PutUint32(buf[0:4], uint32(p.VersionDummy_V0_4))
	le.PutUint32(buf[4:8], uint32(keyLen))
	// copy is a no-op for an empty key, so no length guard is needed.
	copy(buf[headerLen:headerLen+keyLen], key)
	le.PutUint32(buf[headerLen+keyLen:], uint32(p.VersionDummy_JSON))

	return c.conn.writeData(buf)
}
// readHandshakeSuccess reads the server's NUL-terminated reply to the V0_4
// handshake request. It returns nil only when the reply is exactly "SUCCESS";
// any other reply is reported as an RQLDriverError containing the trimmed
// server message. Read failures are returned as-is, except io.EOF which is
// reported together with whatever bytes were received before the stream ended.
func (c *connectionHandshakeV0_4) readHandshakeSuccess() error {
	raw, err := bufio.NewReader(c.conn.Conn).ReadBytes('\x00')
	switch {
	case err == io.EOF:
		return fmt.Errorf("Unexpected EOF: %s", string(raw))
	case err != nil:
		return err
	}

	// Drop the trailing NUL delimiter before inspecting the reply.
	reply := string(raw[:len(raw)-1])
	if reply == "SUCCESS" {
		return nil
	}

	// Authorization failed or the server rejected the handshake.
	reason := strings.TrimSpace(reply)
	return RQLDriverError{rqlError(fmt.Sprintf("Server dropped connection with message: \"%s\"", reason))}
}
// handshakeV1_0_protocolVersionNumber is the sub-protocol version this driver
// announces in the V1_0 handshake and checks against the range the server
// reports supporting.
const handshakeV1_0_protocolVersionNumber = 0

// handshakeV1_0_authenticationMethod is the authentication method named in the
// first V1_0 handshake message.
const handshakeV1_0_authenticationMethod = "SCRAM-SHA-256"
// connectionHandshakeV1_0 implements the V1_0 handshake, which authenticates
// against the server using SCRAM-SHA-256.
type connectionHandshakeV1_0 struct {
// conn is the connection being authenticated.
conn *Connection
// reader is the buffered reader over conn; it is created in Send and used by
// readResponse for every message read during the handshake.
reader *bufio.Reader
// authMsg accumulates the SCRAM auth message across the exchange: it is set
// by writeFirstMessage, extended by readFirstMessage and calculateProof, and
// is the data signed for the client proof and the expected server signature.
authMsg string
}
// Send performs the complete V1_0 (SCRAM-SHA-256) handshake:
//
//  1. generate a client nonce and send the client first message,
//  2. read the server's supported protocol versions,
//  3. read the server first message (iteration count, salt, server nonce),
//  4. verify the server nonce extends the client nonce,
//  5. send the client proof and verify the server signature.
//
// On any failure the connection is closed and the error is returned. This
// includes the invalid-server-nonce case, which previously returned without
// closing the connection and therefore leaked it.
func (c *connectionHandshakeV1_0) Send() error {
	// fail releases the connection and passes the error through; the result of
	// Close is intentionally ignored because the handshake error is the one
	// worth reporting.
	fail := func(err error) error {
		c.conn.Close()
		return err
	}

	c.reader = bufio.NewReader(c.conn.Conn)

	// Generate client nonce
	clientNonce, err := c.generateNonce()
	if err != nil {
		return fail(RQLDriverError{rqlError(fmt.Sprintf("Failed to generate client nonce: %s", err))})
	}

	// Send client first message
	if err := c.writeFirstMessage(clientNonce); err != nil {
		return fail(err)
	}

	// Read status
	if err := c.checkServerVersions(); err != nil {
		return fail(err)
	}

	// Read server first message
	i, salt, serverNonce, err := c.readFirstMessage()
	if err != nil {
		return fail(err)
	}

	// Check server nonce: it must start with the nonce this client generated.
	if !strings.HasPrefix(serverNonce, clientNonce) {
		return fail(RQLAuthError{RQLDriverError{rqlError("Invalid nonce from server")}})
	}

	// Generate proof. calculateProof appends the final part of the auth
	// message to c.authMsg, so serverSignature must be computed after it.
	saltedPass := c.saltPassword(i, salt)
	clientProof := c.calculateProof(saltedPass, clientNonce, serverNonce)
	serverSignature := c.serverSignature(saltedPass)

	// Send client final message
	if err := c.writeFinalMessage(serverNonce, clientProof); err != nil {
		return fail(err)
	}

	// Read server final message
	if err := c.readFinalMessage(serverSignature); err != nil {
		return fail(err)
	}

	return nil
}
// writeFirstMessage sends the client first message of the V1_0 handshake: the
// protocol magic number followed by a NUL-terminated JSON object carrying the
// protocol version, the SCRAM client-first-message and the authentication
// method. It also initialises c.authMsg with the client-first-message-bare
// ("n=<user>,r=<nonce>") for the later proof calculation.
//
// The username defaults to "admin" when none is configured. Two kinds of
// escaping are applied so that arbitrary usernames produce a valid message:
//   - SCRAM saslname escaping ("=" becomes "=3D", "," becomes "=2C"), because
//     both characters are structural in the SCRAM message;
//   - JSON string escaping of the authentication field, because the value is
//     embedded in a JSON document.
//
// For usernames without special characters the bytes sent are unchanged.
func (c *connectionHandshakeV1_0) writeFirstMessage(clientNonce string) error {
	// Default username to admin if not set
	username := "admin"
	if c.conn.opts.Username != "" {
		username = c.conn.opts.Username
	}
	// A single-pass replacer avoids re-escaping the "=" introduced by "=2C".
	username = strings.NewReplacer("=", "=3D", ",", "=2C").Replace(username)

	c.authMsg = fmt.Sprintf("n=%s,r=%s", username, clientNonce)

	// "n,," is the GS2 header (no channel binding, no authzid); it is part of
	// what is sent but not part of the bare message kept in c.authMsg.
	authField, err := json.Marshal("n,," + c.authMsg)
	if err != nil {
		return RQLDriverError{rqlError(fmt.Sprintf("Error encoding auth message: %s", err))}
	}
	msg := fmt.Sprintf(
		`{"protocol_version": %d,"authentication": %s,"authentication_method": "%s"}`,
		handshakeV1_0_protocolVersionNumber, authField, handshakeV1_0_authenticationMethod,
	)

	// Layout: 4-byte little-endian protocol version, JSON message, NUL byte.
	data := make([]byte, 4+len(msg)+1)
	binary.LittleEndian.PutUint32(data, uint32(p.VersionDummy_V1_0))
	copy(data[4:], msg)
	data[len(data)-1] = '\x00'

	return c.writeData(data)
}
// checkServerVersions reads the server's reply to the first handshake message
// and verifies that the protocol version used by this driver lies within the
// range the server reports supporting.
//
// Errors:
//   - errors from readResponse are returned unchanged;
//   - a non-JSON reply starting with "ERROR: " is returned as an
//     RQLConnectionError with that prefix removed;
//   - any other undecodable reply is an RQLDriverError;
//   - a reply with success=false is mapped through handshakeError;
//   - an unsupported protocol version is an RQLDriverError.
//
// The reply is decoded into a struct value rather than a pointer: a JSON
// `null` reply would otherwise leave the pointer nil without any decode error
// and cause a nil-pointer panic on the first field access. With a value, such
// a reply is treated as an unsuccessful response.
func (c *connectionHandshakeV1_0) checkServerVersions() error {
	b, err := c.readResponse()
	if err != nil {
		return err
	}

	// Read status
	var rsp struct {
		Success            bool   `json:"success"`
		MinProtocolVersion int    `json:"min_protocol_version"`
		MaxProtocolVersion int    `json:"max_protocol_version"`
		ServerVersion      string `json:"server_version"`
		ErrorCode          int    `json:"error_code"`
		Error              string `json:"error"`
	}
	if err := json.Unmarshal(b, &rsp); err != nil {
		// The reply may be a plain-text error line instead of JSON.
		statusStr := string(b)
		if strings.HasPrefix(statusStr, "ERROR: ") {
			statusStr = strings.TrimPrefix(statusStr, "ERROR: ")
			return RQLConnectionError{rqlError(statusStr)}
		}
		return RQLDriverError{rqlError(fmt.Sprintf("Error reading versions: %s", err))}
	}

	if !rsp.Success {
		return c.handshakeError(rsp.ErrorCode, rsp.Error)
	}

	if rsp.MinProtocolVersion > handshakeV1_0_protocolVersionNumber ||
		rsp.MaxProtocolVersion < handshakeV1_0_protocolVersionNumber {
		return RQLDriverError{rqlError(
			fmt.Sprintf(
				"Unsupported protocol version %d, expected between %d and %d.",
				handshakeV1_0_protocolVersionNumber,
				rsp.MinProtocolVersion,
				rsp.MaxProtocolVersion,
			),
		)}
	}

	return nil
}
// readFirstMessage reads the SCRAM server-first-message and extracts its
// attributes.
//
// Returns:
//   - i: the iteration count (attribute "i"), 0 if absent;
//   - salt: the base64-decoded salt (attribute "s"), nil if absent;
//   - serverNonce: the combined nonce (attribute "r"), "" if absent.
//
// The raw authentication string is appended to c.authMsg (after a comma) as
// soon as the server reports success, because it forms part of the message
// that is signed later.
//
// Errors from readResponse are returned unchanged, an unsuccessful reply is
// mapped through handshakeError, and any decoding problem (invalid JSON,
// non-numeric iteration count, invalid base64 salt) is reported as an
// RQLDriverError. All results are zero when an error is returned.
//
// The reply is decoded into a struct value rather than a pointer so that a
// JSON `null` reply is handled as an unsuccessful response instead of causing
// a nil-pointer panic.
func (c *connectionHandshakeV1_0) readFirstMessage() (i int64, salt []byte, serverNonce string, err error) {
	b, err := c.readResponse()
	if err != nil {
		return 0, nil, "", err
	}

	// Read server message
	var rsp struct {
		Success        bool   `json:"success"`
		Authentication string `json:"authentication"`
		ErrorCode      int    `json:"error_code"`
		Error          string `json:"error"`
	}
	if err := json.Unmarshal(b, &rsp); err != nil {
		return 0, nil, "", RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err))}
	}
	if !rsp.Success {
		return 0, nil, "", c.handshakeError(rsp.ErrorCode, rsp.Error)
	}

	c.authMsg += "," + rsp.Authentication

	// Parse authentication field: comma-separated key=value attributes. Only
	// the first "=" separates key from value, since base64 values may contain
	// "=" padding.
	auth := map[string]string{}
	for _, part := range strings.Split(rsp.Authentication, ",") {
		if eq := strings.Index(part, "="); eq != -1 {
			auth[part[:eq]] = part[eq+1:]
		}
	}

	// Extract return values
	if v, ok := auth["i"]; ok {
		i, err = strconv.ParseInt(v, 10, 64)
		if err != nil {
			return 0, nil, "", RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err))}
		}
	}
	if v, ok := auth["s"]; ok {
		salt, err = base64.StdEncoding.DecodeString(v)
		if err != nil {
			return 0, nil, "", RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err))}
		}
	}
	serverNonce = auth["r"]

	return i, salt, serverNonce, nil
}
// writeFinalMessage sends the SCRAM client-final-message, wrapped in a JSON
// object and terminated by a NUL byte. The message consists of the channel
// binding attribute ("c=biws"), the server nonce and the client proof.
func (c *connectionHandshakeV1_0) writeFinalMessage(serverNonce, clientProof string) error {
	final := "c=biws,r=" + serverNonce + ",p=" + clientProof
	payload := fmt.Sprintf(`{"authentication": "%s"}`, final)

	// JSON payload followed by the NUL terminator.
	buf := make([]byte, 0, len(payload)+1)
	buf = append(buf, payload...)
	buf = append(buf, '\x00')

	return c.writeData(buf)
}
// readFinalMessage reads the SCRAM server-final-message and checks that the
// server signature it carries (attribute "v") equals serverSignature, the
// value this client computed.
//
// Errors from readResponse are returned unchanged, invalid JSON is an
// RQLDriverError, an unsuccessful reply is mapped through handshakeError, and
// a missing or mismatching signature is an RQLAuthError.
//
// The reply is decoded into a struct value rather than a pointer so that a
// JSON `null` reply is handled as an unsuccessful response instead of causing
// a nil-pointer panic.
func (c *connectionHandshakeV1_0) readFinalMessage(serverSignature string) error {
	b, err := c.readResponse()
	if err != nil {
		return err
	}

	// Read server message
	var rsp struct {
		Success        bool   `json:"success"`
		Authentication string `json:"authentication"`
		ErrorCode      int    `json:"error_code"`
		Error          string `json:"error"`
	}
	if err := json.Unmarshal(b, &rsp); err != nil {
		return RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err))}
	}
	if !rsp.Success {
		return c.handshakeError(rsp.ErrorCode, rsp.Error)
	}

	// Parse authentication field: comma-separated key=value attributes, split
	// on the first "=" only because base64 values may end in "=" padding.
	auth := map[string]string{}
	for _, part := range strings.Split(rsp.Authentication, ",") {
		if eq := strings.Index(part, "="); eq != -1 {
			auth[part[:eq]] = part[eq+1:]
		}
	}

	// Validate server response
	if serverSignature != auth["v"] {
		return RQLAuthError{RQLDriverError{rqlError("Invalid server signature")}}
	}

	return nil
}
// writeData writes data to the underlying connection, converting any write
// failure into an RQLConnectionError that carries the original message.
func (c *connectionHandshakeV1_0) writeData(data []byte) error {
	err := c.conn.writeData(data)
	if err == nil {
		return nil
	}
	return RQLConnectionError{rqlError(err.Error())}
}
// readResponse reads one NUL-terminated message from the handshake reader and
// returns it without the terminator. Read failures are reported as
// RQLConnectionError; for io.EOF the message includes whatever bytes arrived
// before the stream ended.
func (c *connectionHandshakeV1_0) readResponse() ([]byte, error) {
	raw, err := c.reader.ReadBytes('\x00')
	if err == nil {
		// Strip the trailing NUL delimiter.
		return raw[:len(raw)-1], nil
	}

	reason := err.Error()
	if err == io.EOF {
		reason = fmt.Sprintf("Unexpected EOF: %s", string(raw))
	}
	return nil, RQLConnectionError{rqlError(reason)}
}
// generateNonce returns a fresh client nonce: 24 bytes from crypto/rand,
// encoded with standard base64. The error from the random source is returned
// unchanged.
func (c *connectionHandshakeV1_0) generateNonce() (string, error) {
	const nonceSize = 24

	var raw [nonceSize]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(raw[:]), nil
}
// saltPassword derives the SCRAM salted password from the configured password
// using PBKDF2 with SHA-256, the given salt and iteration count, producing a
// key of sha256.Size bytes.
func (c *connectionHandshakeV1_0) saltPassword(iter int64, salt []byte) []byte {
	password := []byte(c.conn.opts.Password)
	iterations := int(iter)

	derived := pbkdf2.Key(password, salt, iterations, sha256.Size, sha256.New)
	return derived
}
// calculateProof computes the base64-encoded SCRAM client proof.
//
// As a side effect it completes c.authMsg by appending the
// client-final-message-without-proof (",c=biws,r=<serverNonce>"); the server
// signature is computed over that same completed message.
//
// The proof is ClientKey XOR ClientSignature, where
//
//	ClientKey       = HMAC(saltedPass, "Client Key")
//	StoredKey       = H(ClientKey)
//	ClientSignature = HMAC(StoredKey, authMsg)
//
// clientNonce is accepted but not used by the computation.
func (c *connectionHandshakeV1_0) calculateProof(saltedPass []byte, clientNonce, serverNonce string) string {
	c.authMsg += ",c=biws,r=" + serverNonce

	newHash := c.hashFunc()

	keyMAC := hmac.New(newHash, saltedPass)
	keyMAC.Write([]byte("Client Key"))
	clientKey := keyMAC.Sum(nil)

	digest := newHash()
	digest.Write(clientKey)
	storedKey := digest.Sum(nil)

	sigMAC := hmac.New(newHash, storedKey)
	sigMAC.Write([]byte(c.authMsg))
	clientSignature := sigMAC.Sum(nil)

	proof := make([]byte, len(clientKey))
	for idx := range proof {
		proof[idx] = clientKey[idx] ^ clientSignature[idx]
	}

	return base64.StdEncoding.EncodeToString(proof)
}
// serverSignature computes the base64-encoded signature the server is expected
// to send in its final message:
//
//	ServerKey       = HMAC(saltedPass, "Server Key")
//	ServerSignature = HMAC(ServerKey, authMsg)
//
// It reads c.authMsg as it is at the time of the call.
func (c *connectionHandshakeV1_0) serverSignature(saltedPass []byte) string {
	newHash := c.hashFunc()

	keyMAC := hmac.New(newHash, saltedPass)
	keyMAC.Write([]byte("Server Key"))
	serverKey := keyMAC.Sum(nil)

	sigMAC := hmac.New(newHash, serverKey)
	sigMAC.Write([]byte(c.authMsg))
	signature := sigMAC.Sum(nil)

	return base64.StdEncoding.EncodeToString(signature)
}
// handshakeError converts an error reported by the server during the
// handshake into a driver error. Error codes from 10 to 20 inclusive denote
// authentication failures and yield an RQLAuthError; every other code yields a
// plain RQLDriverError.
//
// The range test must use a logical AND: with OR the condition holds for every
// integer, which classified all handshake errors as authentication errors.
func (c *connectionHandshakeV1_0) handshakeError(code int, message string) error {
	if code >= 10 && code <= 20 {
		return RQLAuthError{RQLDriverError{rqlError(message)}}
	}

	return RQLDriverError{rqlError(message)}
}
// hashFunc returns the constructor for the hash used throughout the V1_0
// handshake, which is SHA-256.
func (c *connectionHandshakeV1_0) hashFunc() func() hash.Hash {
	var constructor func() hash.Hash = sha256.New
	return constructor
}