382 lines
9.2 KiB
Go
382 lines
9.2 KiB
Go
|
package gorethink
|
||
|
|
||
|
import (
|
||
|
"crypto/tls"
|
||
|
"encoding/binary"
|
||
|
"encoding/json"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||
|
)
|
||
|
|
||
|
const (
|
||
|
respHeaderLen = 12
|
||
|
defaultKeepAlivePeriod = time.Second * 30
|
||
|
)
|
||
|
|
||
|
// Response represents the raw response from a query, most of the time you
|
||
|
// should instead use a Cursor when reading from the database.
|
||
|
type Response struct {
|
||
|
Token int64
|
||
|
Type p.Response_ResponseType `json:"t"`
|
||
|
ErrorType p.Response_ErrorType `json:"e"`
|
||
|
Notes []p.Response_ResponseNote `json:"n"`
|
||
|
Responses []json.RawMessage `json:"r"`
|
||
|
Backtrace []interface{} `json:"b"`
|
||
|
Profile interface{} `json:"p"`
|
||
|
}
|
||
|
|
||
|
// Connection is a connection to a rethinkdb database. Connection is not thread
|
||
|
// safe and should only be accessed be a single goroutine
|
||
|
type Connection struct {
|
||
|
net.Conn
|
||
|
|
||
|
address string
|
||
|
opts *ConnectOpts
|
||
|
|
||
|
_ [4]byte
|
||
|
mu sync.Mutex
|
||
|
token int64
|
||
|
cursors map[int64]*Cursor
|
||
|
bad bool
|
||
|
closed bool
|
||
|
}
|
||
|
|
||
|
// NewConnection creates a new connection to the database server
|
||
|
func NewConnection(address string, opts *ConnectOpts) (*Connection, error) {
|
||
|
var err error
|
||
|
c := &Connection{
|
||
|
address: address,
|
||
|
opts: opts,
|
||
|
cursors: make(map[int64]*Cursor),
|
||
|
}
|
||
|
|
||
|
keepAlivePeriod := defaultKeepAlivePeriod
|
||
|
if opts.KeepAlivePeriod > 0 {
|
||
|
keepAlivePeriod = opts.KeepAlivePeriod
|
||
|
}
|
||
|
|
||
|
// Connect to Server
|
||
|
nd := net.Dialer{Timeout: c.opts.Timeout, KeepAlive: keepAlivePeriod}
|
||
|
if c.opts.TLSConfig == nil {
|
||
|
c.Conn, err = nd.Dial("tcp", address)
|
||
|
} else {
|
||
|
c.Conn, err = tls.DialWithDialer(&nd, "tcp", address, c.opts.TLSConfig)
|
||
|
}
|
||
|
if err != nil {
|
||
|
return nil, RQLConnectionError{rqlError(err.Error())}
|
||
|
}
|
||
|
|
||
|
// Send handshake
|
||
|
handshake, err := c.handshake(opts.HandshakeVersion)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if err = handshake.Send(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return c, nil
|
||
|
}
|
||
|
|
||
|
// Close closes the underlying net.Conn
|
||
|
func (c *Connection) Close() error {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
|
||
|
var err error
|
||
|
|
||
|
if !c.closed {
|
||
|
err = c.Conn.Close()
|
||
|
c.closed = true
|
||
|
c.cursors = make(map[int64]*Cursor)
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Query sends a Query to the database, returning both the raw Response and a
|
||
|
// Cursor which should be used to view the query's response.
|
||
|
//
|
||
|
// This function is used internally by Run which should be used for most queries.
|
||
|
func (c *Connection) Query(q Query) (*Response, *Cursor, error) {
|
||
|
if c == nil {
|
||
|
return nil, nil, ErrConnectionClosed
|
||
|
}
|
||
|
c.mu.Lock()
|
||
|
if c.Conn == nil {
|
||
|
c.bad = true
|
||
|
c.mu.Unlock()
|
||
|
return nil, nil, ErrConnectionClosed
|
||
|
}
|
||
|
|
||
|
// Add token if query is a START/NOREPLY_WAIT
|
||
|
if q.Type == p.Query_START || q.Type == p.Query_NOREPLY_WAIT || q.Type == p.Query_SERVER_INFO {
|
||
|
q.Token = c.nextToken()
|
||
|
}
|
||
|
if q.Type == p.Query_START || q.Type == p.Query_NOREPLY_WAIT {
|
||
|
if c.opts.Database != "" {
|
||
|
var err error
|
||
|
q.Opts["db"], err = DB(c.opts.Database).Build()
|
||
|
if err != nil {
|
||
|
c.mu.Unlock()
|
||
|
return nil, nil, RQLDriverError{rqlError(err.Error())}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
err := c.sendQuery(q)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
|
||
|
if noreply, ok := q.Opts["noreply"]; ok && noreply.(bool) {
|
||
|
return nil, nil, nil
|
||
|
}
|
||
|
|
||
|
for {
|
||
|
response, err := c.readResponse()
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
|
||
|
if response.Token == q.Token {
|
||
|
// If this was the requested response process and return
|
||
|
return c.processResponse(q, response)
|
||
|
} else if _, ok := c.cursors[response.Token]; ok {
|
||
|
// If the token is in the cursor cache then process the response
|
||
|
c.processResponse(q, response)
|
||
|
} else {
|
||
|
putResponse(response)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type ServerResponse struct {
|
||
|
ID string `gorethink:"id"`
|
||
|
Name string `gorethink:"name"`
|
||
|
}
|
||
|
|
||
|
// Server returns the server name and server UUID being used by a connection.
|
||
|
func (c *Connection) Server() (ServerResponse, error) {
|
||
|
var response ServerResponse
|
||
|
|
||
|
_, cur, err := c.Query(Query{
|
||
|
Type: p.Query_SERVER_INFO,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return response, err
|
||
|
}
|
||
|
|
||
|
if err = cur.One(&response); err != nil {
|
||
|
return response, err
|
||
|
}
|
||
|
|
||
|
if err = cur.Close(); err != nil {
|
||
|
return response, err
|
||
|
}
|
||
|
|
||
|
return response, nil
|
||
|
}
|
||
|
|
||
|
// sendQuery marshals the Query and sends the JSON to the server.
|
||
|
func (c *Connection) sendQuery(q Query) error {
|
||
|
// Build query
|
||
|
b, err := json.Marshal(q.Build())
|
||
|
if err != nil {
|
||
|
return RQLDriverError{rqlError("Error building query")}
|
||
|
}
|
||
|
|
||
|
// Set timeout
|
||
|
if c.opts.WriteTimeout == 0 {
|
||
|
c.Conn.SetWriteDeadline(time.Time{})
|
||
|
} else {
|
||
|
c.Conn.SetWriteDeadline(time.Now().Add(c.opts.WriteTimeout))
|
||
|
}
|
||
|
|
||
|
// Send the JSON encoding of the query itself.
|
||
|
if err = c.writeQuery(q.Token, b); err != nil {
|
||
|
c.bad = true
|
||
|
return RQLConnectionError{rqlError(err.Error())}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// getToken generates the next query token, used to number requests and match
|
||
|
// responses with requests.
|
||
|
func (c *Connection) nextToken() int64 {
|
||
|
// requires c.token to be 64-bit aligned on ARM
|
||
|
return atomic.AddInt64(&c.token, 1)
|
||
|
}
|
||
|
|
||
|
// readResponse attempts to read a Response from the server, if no response
|
||
|
// could be read then an error is returned.
|
||
|
func (c *Connection) readResponse() (*Response, error) {
|
||
|
// Set timeout
|
||
|
if c.opts.ReadTimeout == 0 {
|
||
|
c.Conn.SetReadDeadline(time.Time{})
|
||
|
} else {
|
||
|
c.Conn.SetReadDeadline(time.Now().Add(c.opts.ReadTimeout))
|
||
|
}
|
||
|
|
||
|
// Read response header (token+length)
|
||
|
headerBuf := [respHeaderLen]byte{}
|
||
|
if _, err := c.read(headerBuf[:], respHeaderLen); err != nil {
|
||
|
c.bad = true
|
||
|
return nil, RQLConnectionError{rqlError(err.Error())}
|
||
|
}
|
||
|
|
||
|
responseToken := int64(binary.LittleEndian.Uint64(headerBuf[:8]))
|
||
|
messageLength := binary.LittleEndian.Uint32(headerBuf[8:])
|
||
|
|
||
|
// Read the JSON encoding of the Response itself.
|
||
|
b := make([]byte, int(messageLength))
|
||
|
|
||
|
if _, err := c.read(b, int(messageLength)); err != nil {
|
||
|
c.bad = true
|
||
|
return nil, RQLConnectionError{rqlError(err.Error())}
|
||
|
}
|
||
|
|
||
|
// Decode the response
|
||
|
var response = newCachedResponse()
|
||
|
if err := json.Unmarshal(b, response); err != nil {
|
||
|
c.bad = true
|
||
|
return nil, RQLDriverError{rqlError(err.Error())}
|
||
|
}
|
||
|
response.Token = responseToken
|
||
|
|
||
|
return response, nil
|
||
|
}
|
||
|
|
||
|
func (c *Connection) processResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||
|
switch response.Type {
|
||
|
case p.Response_CLIENT_ERROR:
|
||
|
return c.processErrorResponse(q, response, RQLClientError{rqlServerError{response, q.Term}})
|
||
|
case p.Response_COMPILE_ERROR:
|
||
|
return c.processErrorResponse(q, response, RQLCompileError{rqlServerError{response, q.Term}})
|
||
|
case p.Response_RUNTIME_ERROR:
|
||
|
return c.processErrorResponse(q, response, createRuntimeError(response.ErrorType, response, q.Term))
|
||
|
case p.Response_SUCCESS_ATOM, p.Response_SERVER_INFO:
|
||
|
return c.processAtomResponse(q, response)
|
||
|
case p.Response_SUCCESS_PARTIAL:
|
||
|
return c.processPartialResponse(q, response)
|
||
|
case p.Response_SUCCESS_SEQUENCE:
|
||
|
return c.processSequenceResponse(q, response)
|
||
|
case p.Response_WAIT_COMPLETE:
|
||
|
return c.processWaitResponse(q, response)
|
||
|
default:
|
||
|
putResponse(response)
|
||
|
return nil, nil, RQLDriverError{rqlError("Unexpected response type")}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *Connection) processErrorResponse(q Query, response *Response, err error) (*Response, *Cursor, error) {
|
||
|
c.mu.Lock()
|
||
|
cursor := c.cursors[response.Token]
|
||
|
|
||
|
delete(c.cursors, response.Token)
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
return response, cursor, err
|
||
|
}
|
||
|
|
||
|
func (c *Connection) processAtomResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||
|
// Create cursor
|
||
|
cursor := newCursor(c, "Cursor", response.Token, q.Term, q.Opts)
|
||
|
cursor.profile = response.Profile
|
||
|
|
||
|
cursor.extend(response)
|
||
|
|
||
|
return response, cursor, nil
|
||
|
}
|
||
|
|
||
|
func (c *Connection) processPartialResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||
|
cursorType := "Cursor"
|
||
|
if len(response.Notes) > 0 {
|
||
|
switch response.Notes[0] {
|
||
|
case p.Response_SEQUENCE_FEED:
|
||
|
cursorType = "Feed"
|
||
|
case p.Response_ATOM_FEED:
|
||
|
cursorType = "AtomFeed"
|
||
|
case p.Response_ORDER_BY_LIMIT_FEED:
|
||
|
cursorType = "OrderByLimitFeed"
|
||
|
case p.Response_UNIONED_FEED:
|
||
|
cursorType = "UnionedFeed"
|
||
|
case p.Response_INCLUDES_STATES:
|
||
|
cursorType = "IncludesFeed"
|
||
|
}
|
||
|
}
|
||
|
|
||
|
c.mu.Lock()
|
||
|
cursor, ok := c.cursors[response.Token]
|
||
|
if !ok {
|
||
|
// Create a new cursor if needed
|
||
|
cursor = newCursor(c, cursorType, response.Token, q.Term, q.Opts)
|
||
|
cursor.profile = response.Profile
|
||
|
|
||
|
c.cursors[response.Token] = cursor
|
||
|
}
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
cursor.extend(response)
|
||
|
|
||
|
return response, cursor, nil
|
||
|
}
|
||
|
|
||
|
func (c *Connection) processSequenceResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||
|
c.mu.Lock()
|
||
|
cursor, ok := c.cursors[response.Token]
|
||
|
if !ok {
|
||
|
// Create a new cursor if needed
|
||
|
cursor = newCursor(c, "Cursor", response.Token, q.Term, q.Opts)
|
||
|
cursor.profile = response.Profile
|
||
|
}
|
||
|
|
||
|
delete(c.cursors, response.Token)
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
cursor.extend(response)
|
||
|
|
||
|
return response, cursor, nil
|
||
|
}
|
||
|
|
||
|
func (c *Connection) processWaitResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||
|
c.mu.Lock()
|
||
|
delete(c.cursors, response.Token)
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
return response, nil, nil
|
||
|
}
|
||
|
|
||
|
func (c *Connection) isBad() bool {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
|
||
|
return c.bad
|
||
|
}
|
||
|
|
||
|
var responseCache = make(chan *Response, 16)
|
||
|
|
||
|
func newCachedResponse() *Response {
|
||
|
select {
|
||
|
case r := <-responseCache:
|
||
|
return r
|
||
|
default:
|
||
|
return new(Response)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func putResponse(r *Response) {
|
||
|
*r = Response{} // zero it
|
||
|
select {
|
||
|
case responseCache <- r:
|
||
|
default:
|
||
|
}
|
||
|
}
|