route/vendor/github.com/GoRethink/gorethink/connection.go

382 lines
9.2 KiB
Go

package gorethink
import (
"crypto/tls"
"encoding/binary"
"encoding/json"
"net"
"sync"
"sync/atomic"
"time"
p "gopkg.in/gorethink/gorethink.v2/ql2"
)
// respHeaderLen is the size in bytes of the header that precedes every
// response: an 8-byte query token followed by a 4-byte message length.
const respHeaderLen = 12

// defaultKeepAlivePeriod is the TCP keep-alive period used when the connect
// options do not specify a positive KeepAlivePeriod.
const defaultKeepAlivePeriod = 30 * time.Second
// Response represents the raw response from a query. Most of the time you
// should instead use a Cursor when reading from the database.
type Response struct {
// Token identifies the query this response answers. readResponse sets it
// from the response header after the JSON payload has been decoded.
Token int64
// Type is the kind of response; processResponse dispatches on it.
Type p.Response_ResponseType `json:"t"`
// ErrorType is handed to createRuntimeError when Type is a runtime error.
ErrorType p.Response_ErrorType `json:"e"`
// Notes are the response notes; for partial responses the first note
// selects the cursor type.
Notes []p.Response_ResponseNote `json:"n"`
// Responses holds the still-encoded JSON values from the "r" field.
Responses []json.RawMessage `json:"r"`
// Backtrace is decoded from the "b" field.
Backtrace []interface{} `json:"b"`
// Profile is decoded from the "p" field and copied onto newly created
// cursors.
Profile interface{} `json:"p"`
}
// Connection is a connection to a rethinkdb database. Connection is not thread
// safe and should only be accessed by a single goroutine.
type Connection struct {
// Conn is the underlying network connection; it is embedded, so its methods
// are promoted onto Connection.
net.Conn
// address is the server address passed to NewConnection.
address string
// opts are the connect options passed to NewConnection.
opts *ConnectOpts
// NOTE(review): padding, presumably so that token below stays 64-bit aligned
// for the atomic add in nextToken on 32-bit platforms — confirm before
// reordering or adding fields.
_ [4]byte
// mu is locked by Close, Query, isBad and the process*Response helpers. Note
// that sendQuery and readResponse set bad without holding it.
mu sync.Mutex
// token is the last query token handed out; nextToken advances it atomically.
token int64
// cursors holds the cursors registered by partial responses, keyed by query
// token, until a later response removes them.
cursors map[int64]*Cursor
// bad is set when the connection has no underlying Conn or when a write,
// read or response decode fails.
bad bool
// closed is set once Close has closed the underlying net.Conn.
closed bool
}
// NewConnection creates a new connection to the database server at address
// and performs the protocol handshake before returning it.
//
// The connection is dialled over TCP, or over TLS when opts.TLSConfig is set,
// using opts.Timeout as the dial timeout and opts.KeepAlivePeriod (or
// defaultKeepAlivePeriod when it is not positive) as the keep-alive period.
//
// A dial failure is reported as an RQLConnectionError; handshake errors are
// returned unchanged. When the handshake fails the freshly opened socket is
// closed so that it is not leaked.
func NewConnection(address string, opts *ConnectOpts) (*Connection, error) {
	c := &Connection{
		address: address,
		opts:    opts,
		cursors: make(map[int64]*Cursor),
	}

	keepAlivePeriod := defaultKeepAlivePeriod
	if opts.KeepAlivePeriod > 0 {
		keepAlivePeriod = opts.KeepAlivePeriod
	}

	// Connect to the server.
	var err error
	nd := net.Dialer{Timeout: c.opts.Timeout, KeepAlive: keepAlivePeriod}
	if c.opts.TLSConfig == nil {
		c.Conn, err = nd.Dial("tcp", address)
	} else {
		c.Conn, err = tls.DialWithDialer(&nd, "tcp", address, c.opts.TLSConfig)
	}
	if err != nil {
		return nil, RQLConnectionError{rqlError(err.Error())}
	}

	// Send the handshake. From here on the socket is open, so every failure
	// path has to close it before giving up on the connection.
	handshake, err := c.handshake(opts.HandshakeVersion)
	if err != nil {
		// Best effort: the handshake error is the one the caller needs.
		_ = c.Conn.Close()
		return nil, err
	}
	if err = handshake.Send(); err != nil {
		// Best effort: the handshake error is the one the caller needs.
		_ = c.Conn.Close()
		return nil, err
	}

	return c, nil
}
// Close closes the underlying net.Conn and forgets any cursors still
// registered on the connection. Only the first call closes the socket; later
// calls are no-ops that return nil.
//
// A Connection without an underlying net.Conn (the state Query reports as
// ErrConnectionClosed) is simply marked closed instead of panicking on the
// nil interface.
func (c *Connection) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return nil
	}

	var err error
	if c.Conn != nil {
		err = c.Conn.Close()
	}
	c.closed = true
	c.cursors = make(map[int64]*Cursor)

	return err
}
// Query sends a Query to the database, returning both the raw Response and a
// Cursor which should be used to view the query's response.
//
// START, NOREPLY_WAIT and SERVER_INFO queries are given a fresh token. START
// and NOREPLY_WAIT queries additionally get the connection's default database
// (when one is configured) written into their "db" option. When the "noreply"
// option is true the call returns (nil, nil, nil) as soon as the query has
// been written. Otherwise responses are read until the one answering q
// arrives; responses belonging to other cursors registered on this connection
// are handed to those cursors on the way, and anything else is recycled.
//
// ErrConnectionClosed is returned for a nil Connection or for one without an
// underlying net.Conn.
//
// This function is used internally by Run which should be used for most queries.
func (c *Connection) Query(q Query) (*Response, *Cursor, error) {
	if c == nil {
		return nil, nil, ErrConnectionClosed
	}

	c.mu.Lock()
	if c.Conn == nil {
		c.bad = true
		c.mu.Unlock()
		return nil, nil, ErrConnectionClosed
	}

	// Add a token if the query is a START/NOREPLY_WAIT/SERVER_INFO.
	if q.Type == p.Query_START || q.Type == p.Query_NOREPLY_WAIT || q.Type == p.Query_SERVER_INFO {
		q.Token = c.nextToken()
	}
	if q.Type == p.Query_START || q.Type == p.Query_NOREPLY_WAIT {
		if c.opts.Database != "" {
			db, err := DB(c.opts.Database).Build()
			if err != nil {
				c.mu.Unlock()
				return nil, nil, RQLDriverError{rqlError(err.Error())}
			}
			// A query built without options carries a nil map, and assigning
			// into a nil map panics, so allocate one on demand.
			if q.Opts == nil {
				q.Opts = make(map[string]interface{})
			}
			q.Opts["db"] = db
		}
	}
	c.mu.Unlock()

	if err := c.sendQuery(q); err != nil {
		return nil, nil, err
	}

	// The comma-ok assertion treats a non-bool "noreply" value as false
	// instead of panicking.
	if noreply, ok := q.Opts["noreply"].(bool); ok && noreply {
		return nil, nil, nil
	}

	for {
		response, err := c.readResponse()
		if err != nil {
			return nil, nil, err
		}

		if response.Token == q.Token {
			// This is the requested response: process it and return.
			return c.processResponse(q, response)
		}

		// Every other access to c.cursors holds c.mu, so take it for this
		// lookup as well. It must be released before processResponse runs
		// because the process*Response helpers lock it themselves.
		c.mu.Lock()
		_, pending := c.cursors[response.Token]
		c.mu.Unlock()

		if pending {
			// The token is in the cursor cache: hand the response to that
			// cursor. The return values are discarded here.
			c.processResponse(q, response)
		} else {
			putResponse(response)
		}
	}
}
// ServerResponse is the value Connection.Server decodes the result of a
// SERVER_INFO query into.
type ServerResponse struct {
// ID is decoded from the "id" field (the server UUID, per Server's
// documentation).
ID string `gorethink:"id"`
// Name is decoded from the "name" field (the server name).
Name string `gorethink:"name"`
}
// Server returns the server name and server UUID being used by a connection.
//
// It runs a SERVER_INFO query, decodes the single result into a
// ServerResponse and then closes the cursor. The first error encountered is
// returned together with whatever has been decoded so far.
func (c *Connection) Server() (ServerResponse, error) {
	var info ServerResponse

	_, cursor, err := c.Query(Query{Type: p.Query_SERVER_INFO})
	if err != nil {
		return info, err
	}

	err = cursor.One(&info)
	if err != nil {
		return info, err
	}

	err = cursor.Close()
	return info, err
}
// sendQuery encodes the built query as JSON and writes it to the server under
// the query's token. An encoding failure is reported as an RQLDriverError; a
// failed write marks the connection as bad and is reported as an
// RQLConnectionError.
func (c *Connection) sendQuery(q Query) error {
	// Encode the built query.
	payload, err := json.Marshal(q.Build())
	if err != nil {
		return RQLDriverError{rqlError("Error building query")}
	}

	// Apply the write timeout; with no timeout configured the zero deadline
	// is set instead.
	var deadline time.Time
	if timeout := c.opts.WriteTimeout; timeout != 0 {
		deadline = time.Now().Add(timeout)
	}
	c.Conn.SetWriteDeadline(deadline)

	// Write the token together with the JSON payload.
	if err = c.writeQuery(q.Token, payload); err == nil {
		return nil
	}

	c.bad = true
	return RQLConnectionError{rqlError(err.Error())}
}
// nextToken generates the next query token, used to number requests and match
// responses with requests.
func (c *Connection) nextToken() int64 {
	// The atomic add requires c.token to be 64-bit aligned on ARM.
	next := atomic.AddInt64(&c.token, 1)
	return next
}
// readResponse reads one response from the server: a fixed-size header
// carrying the query token and the payload length, followed by the JSON
// payload itself.
//
// Every failure marks the connection as bad. Read failures are returned as an
// RQLConnectionError and decode failures as an RQLDriverError.
func (c *Connection) readResponse() (*Response, error) {
	// Apply the read timeout; with no timeout configured the zero deadline is
	// set instead.
	var deadline time.Time
	if timeout := c.opts.ReadTimeout; timeout != 0 {
		deadline = time.Now().Add(timeout)
	}
	c.Conn.SetReadDeadline(deadline)

	// Header: 8-byte little-endian token, then 4-byte little-endian length.
	var header [respHeaderLen]byte
	if _, err := c.read(header[:], respHeaderLen); err != nil {
		c.bad = true
		return nil, RQLConnectionError{rqlError(err.Error())}
	}
	token := int64(binary.LittleEndian.Uint64(header[:8]))
	size := int(binary.LittleEndian.Uint32(header[8:]))

	// Body: the JSON encoding of the Response.
	body := make([]byte, size)
	if _, err := c.read(body, size); err != nil {
		c.bad = true
		return nil, RQLConnectionError{rqlError(err.Error())}
	}

	// Decode into a recycled Response where one is available.
	resp := newCachedResponse()
	if err := json.Unmarshal(body, resp); err != nil {
		c.bad = true
		return nil, RQLDriverError{rqlError(err.Error())}
	}
	resp.Token = token

	return resp, nil
}
// processResponse dispatches a response to the handler matching its type.
// The three error types are converted into the corresponding error value and
// routed through processErrorResponse; a response of an unknown type is
// recycled and reported as an RQLDriverError.
func (c *Connection) processResponse(q Query, response *Response) (*Response, *Cursor, error) {
	var serverErr error

	switch response.Type {
	// Successful responses are handled by their dedicated helpers.
	case p.Response_SUCCESS_ATOM, p.Response_SERVER_INFO:
		return c.processAtomResponse(q, response)
	case p.Response_SUCCESS_PARTIAL:
		return c.processPartialResponse(q, response)
	case p.Response_SUCCESS_SEQUENCE:
		return c.processSequenceResponse(q, response)
	case p.Response_WAIT_COMPLETE:
		return c.processWaitResponse(q, response)

	// Error responses only differ in the error value they produce.
	case p.Response_CLIENT_ERROR:
		serverErr = RQLClientError{rqlServerError{response, q.Term}}
	case p.Response_COMPILE_ERROR:
		serverErr = RQLCompileError{rqlServerError{response, q.Term}}
	case p.Response_RUNTIME_ERROR:
		serverErr = createRuntimeError(response.ErrorType, response, q.Term)

	default:
		putResponse(response)
		return nil, nil, RQLDriverError{rqlError("Unexpected response type")}
	}

	return c.processErrorResponse(q, response, serverErr)
}
// processErrorResponse removes the cursor registered for the response's token
// (if any) from the connection and returns it together with the response and
// the supplied error. The cursor is nil when none was registered.
func (c *Connection) processErrorResponse(q Query, response *Response, err error) (*Response, *Cursor, error) {
	token := response.Token

	c.mu.Lock()
	cur, registered := c.cursors[token]
	if registered {
		delete(c.cursors, token)
	}
	c.mu.Unlock()

	return response, cur, err
}
// processAtomResponse builds a new cursor of type "Cursor" for the response,
// copies the response's profile onto it and feeds the response into it. The
// cursor is not registered in c.cursors.
func (c *Connection) processAtomResponse(q Query, response *Response) (*Response, *Cursor, error) {
	cur := newCursor(c, "Cursor", response.Token, q.Term, q.Opts)
	cur.profile = response.Profile
	cur.extend(response)

	return response, cur, nil
}
// processPartialResponse handles a partial response. The cursor already
// registered for the response's token is reused; otherwise a new one is
// created, typed according to the first response note, and registered on the
// connection. The response is then fed into the cursor.
func (c *Connection) processPartialResponse(q Query, response *Response) (*Response, *Cursor, error) {
	// The first note, when present, selects the cursor type.
	kind := "Cursor"
	if notes := response.Notes; len(notes) > 0 {
		switch notes[0] {
		case p.Response_SEQUENCE_FEED:
			kind = "Feed"
		case p.Response_ATOM_FEED:
			kind = "AtomFeed"
		case p.Response_ORDER_BY_LIMIT_FEED:
			kind = "OrderByLimitFeed"
		case p.Response_UNIONED_FEED:
			kind = "UnionedFeed"
		case p.Response_INCLUDES_STATES:
			kind = "IncludesFeed"
		}
	}

	token := response.Token

	c.mu.Lock()
	cur, known := c.cursors[token]
	if !known {
		// First response for this token: create and register its cursor.
		cur = newCursor(c, kind, token, q.Term, q.Opts)
		cur.profile = response.Profile
		c.cursors[token] = cur
	}
	c.mu.Unlock()

	// Feed the response into the cursor only after the lock is released.
	cur.extend(response)

	return response, cur, nil
}
// processSequenceResponse handles a sequence response. The cursor registered
// for the response's token is reused, or a new "Cursor" is created when there
// is none; either way the token is removed from the connection's cursor map
// before the response is fed into the cursor.
func (c *Connection) processSequenceResponse(q Query, response *Response) (*Response, *Cursor, error) {
	token := response.Token

	c.mu.Lock()
	cur, known := c.cursors[token]
	if !known {
		// No cursor registered for this token: create one on the spot.
		cur = newCursor(c, "Cursor", token, q.Term, q.Opts)
		cur.profile = response.Profile
	}
	delete(c.cursors, token)
	c.mu.Unlock()

	// Feed the response into the cursor only after the lock is released.
	cur.extend(response)

	return response, cur, nil
}
// processWaitResponse handles a WAIT_COMPLETE response: it drops any cursor
// registered for the response's token and returns the response without a
// cursor.
func (c *Connection) processWaitResponse(q Query, response *Response) (*Response, *Cursor, error) {
	token := response.Token

	c.mu.Lock()
	delete(c.cursors, token)
	c.mu.Unlock()

	var none *Cursor
	return response, none, nil
}
// isBad reports whether the connection has been marked as bad. The flag is
// read while holding c.mu.
func (c *Connection) isBad() bool {
	c.mu.Lock()
	bad := c.bad
	c.mu.Unlock()

	return bad
}
// responseCache is a free list of up to 16 zeroed Response values:
// putResponse adds to it and newCachedResponse takes from it.
var responseCache = make(chan *Response, 16)
// newCachedResponse returns a Response taken from responseCache when one is
// available and a freshly allocated one otherwise. It never blocks.
func newCachedResponse() *Response {
	select {
	case recycled := <-responseCache:
		return recycled
	default:
		// Cache is empty; fall through and allocate.
	}

	return new(Response)
}
// putResponse resets r to its zero value and offers it to responseCache for
// reuse. It never blocks: when the cache is full r is simply not stored.
func putResponse(r *Response) {
	var zero Response
	*r = zero

	select {
	case responseCache <- r:
		// Stored for reuse by newCachedResponse.
	default:
		// Cache is full; r is left to the garbage collector.
	}
}