vendor: gorqlite

This commit is contained in:
Cadey Ratio 2017-08-29 10:23:26 -07:00
parent 464d947714
commit 85d874b020
No known key found for this signature in database
GPG Key ID: D607EE27C2E7F89A
8 changed files with 1543 additions and 0 deletions

View File

@ -39,3 +39,4 @@ f5079bd7f6f74e23c4d65efa0f4ce14cbd6a3c0f golang.org/x/net/context
f51c12702a4d776e4c1fa9b0fabab841babae631 golang.org/x/time/rate
ae77be60afb1dcacde03767a8c37337fad28ac14 github.com/kardianos/osext
6a18b51d929caddbe795e1609195dee1d1cc729e github.com/justinian/dice
40a5e952d22c3ef520c6ab7bdb9b1a010ec9a524 git.xeserv.us/xena/gorqlite

View File

@ -0,0 +1,203 @@
package gorqlite
/*
this file has low level stuff:
rqliteApiGet()
rqliteApiPost()
There is some code duplication between those and they should
probably be combined into one function.
nothing public here.
*/
import "bytes"
import "encoding/json"
import "errors"
import "fmt"
import "io/ioutil"
import "net/http"
import "time"
/* *****************************************************************
method: rqliteApiGet() - for api_STATUS
- lowest level interface - does not do any JSON unmarshaling
- handles retries
- handles timeouts
* *****************************************************************/
func (conn *Connection) rqliteApiGet(apiOp apiOperation) ([]byte, error) {
var responseBody []byte
trace("%s: rqliteApiGet() called", conn.ID)
// only api_STATUS now - maybe someday BACKUP
if apiOp != api_STATUS {
return responseBody, errors.New("rqliteApiGet() called for invalid api operation")
}
// just to be safe, check this
peersToTry := conn.cluster.makePeerList()
if len(peersToTry) < 1 {
return responseBody, errors.New("I don't have any cluster info")
}
trace("%s: I have a peer list %d peers long", conn.ID, len(peersToTry))
// failure log is used so that if all peers fail, we can say something
// about why each failed
failureLog := make([]string, 0)
PeerLoop:
for peerNum, peerToTry := range peersToTry {
trace("%s: attemping to contact peer %d", conn.ID, peerNum)
// docs say default GET policy is up to 10 follows automatically
url := conn.assembleURL(api_STATUS, peerToTry)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
trace("%s: got error '%s' doing http.NewRequest", conn.ID, err.Error())
failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error()))
continue PeerLoop
}
trace("%s: http.NewRequest() OK")
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
client.Timeout = time.Duration(conn.timeout) * time.Second
response, err := client.Do(req)
if err != nil {
trace("%s: got error '%s' doing client.Do", conn.ID, err.Error())
failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error()))
continue PeerLoop
}
defer response.Body.Close()
trace("%s: client.Do() OK")
responseBody, err := ioutil.ReadAll(response.Body)
if err != nil {
trace("%s: got error '%s' doing ioutil.ReadAll", conn.ID, err.Error())
failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error()))
continue PeerLoop
}
trace("%s: ioutil.ReadAll() OK")
if response.Status != "200 OK" {
trace("%s: got code %s", conn.ID, response.Status)
failureLog = append(failureLog, fmt.Sprintf("%s failed, got: %s", url, response.Status))
continue PeerLoop
}
// if we got here, we succeeded
trace("%s: api call OK, returning", conn.ID)
return responseBody, nil
}
// if we got here, all peers failed. Let's build a verbose error message
var stringBuffer bytes.Buffer
stringBuffer.WriteString("tried all peers unsuccessfully. here are the results:\n")
for n, v := range failureLog {
stringBuffer.WriteString(fmt.Sprintf(" peer #%d: %s\n", n, v))
}
return responseBody, errors.New(stringBuffer.String())
}
/* *****************************************************************
method: rqliteApiPost() - for api_QUERY and api_WRITE
- lowest level interface - does not do any JSON unmarshaling
- handles 301s, etc.
- handles retries
- handles timeouts
it is called with an apiOperation type because the URL it will use varies
depending on the API operation type (api_QUERY vs. api_WRITE)
* *****************************************************************/
func (conn *Connection) rqliteApiPost(apiOp apiOperation, sqlStatements []string) ([]byte, error) {
var responseBody []byte
switch apiOp {
case api_QUERY:
trace("%s: rqliteApiGet() post called for a QUERY of %d statements", conn.ID, len(sqlStatements))
case api_WRITE:
trace("%s: rqliteApiGet() post called for a QUERY of %d statements", conn.ID, len(sqlStatements))
default:
return responseBody, errors.New("weird! called for an invalid apiOperation in rqliteApiPost()")
}
// jsonify the statements. not really needed in the
// case of api_STATUS but doesn't hurt
jStatements, err := json.Marshal(sqlStatements)
if err != nil {
return nil, err
}
// just to be safe, check this
peersToTry := conn.cluster.makePeerList()
if len(peersToTry) < 1 {
return responseBody, errors.New("I don't have any cluster info")
}
// failure log is used so that if all peers fail, we can say something
// about why each failed
failureLog := make([]string, 0)
PeerLoop:
for peerNum, peer := range peersToTry {
trace("%s: trying peer #%d", conn.ID, peerNum)
// we're doing a post, and the RFCs say that if you get a 301, it's not
// automatically followed, so we have to do that ourselves
responseStatus := "Haven't Tried Yet"
var url string
for responseStatus == "Haven't Tried Yet" || responseStatus == "301 Moved Permanently" {
url = conn.assembleURL(apiOp, peer)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jStatements))
if err != nil {
trace("%s: got error '%s' doing http.NewRequest", conn.ID, err.Error())
failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error()))
continue PeerLoop
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
response, err := client.Do(req)
if err != nil {
trace("%s: got error '%s' doing client.Do", conn.ID, err.Error())
failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error()))
continue PeerLoop
}
defer response.Body.Close()
responseBody, err = ioutil.ReadAll(response.Body)
if err != nil {
trace("%s: got error '%s' doing ioutil.ReadAll", conn.ID, err.Error())
failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error()))
continue PeerLoop
}
responseStatus = response.Status
if responseStatus == "301 Moved Permanently" {
v := response.Header["Location"]
failureLog = append(failureLog, fmt.Sprintf("%s redirected me to %s", url, v[0]))
url = v[0]
continue PeerLoop
} else if responseStatus == "200 OK" {
trace("%s: api call OK, returning", conn.ID)
return responseBody, nil
} else {
trace("%s: got error in responseStatus: %s", conn.ID, responseStatus)
failureLog = append(failureLog, fmt.Sprintf("%s failed, got: %s", url, response.Status))
continue PeerLoop
}
}
}
// if we got here, all peers failed. Let's build a verbose error message
var stringBuffer bytes.Buffer
stringBuffer.WriteString("tried all peers unsuccessfully. here are the results:\n")
for n, v := range failureLog {
stringBuffer.WriteString(fmt.Sprintf(" peer #%d: %s\n", n, v))
}
return responseBody, errors.New(stringBuffer.String())
}

View File

@ -0,0 +1,223 @@
package gorqlite
/*
this file holds most of the cluster-related stuff:
types:
peer
rqliteCluster
Connection methods:
assembleURL (from a peer)
updateClusterInfo (does the full cluster discovery via status)
*/
/* *****************************************************************
imports
* *****************************************************************/
import "bytes"
import "encoding/json"
import "errors"
import "fmt"
import "strings"
//import "os"
//import "reflect"
/* *****************************************************************
type: peer
this is an internal type to abstact peer info.
note that hostname is sometimes used for "has this struct been
inialized" checks.
* *****************************************************************/
type peer struct {
hostname string // hostname or "localhost"
port string // "4001" or port, only ever used as a string
}
func (p *peer) String() string {
return fmt.Sprintf("%s:%s", p.hostname, p.port)
}
/* *****************************************************************
type: rqliteCluster
internal type that abstracts the full cluster state (leader, peers)
* *****************************************************************/
type rqliteCluster struct {
leader peer
otherPeers []peer
conn *Connection
}
/* *****************************************************************
method: rqliteCluster.makePeerList()
in the api calls, we'll want to try the leader first, then the other
peers. to make looping easy, this function returns a list of peers
in the order the try them: leader, other peer, other peer, etc.
* *****************************************************************/
func (rc *rqliteCluster) makePeerList() []peer {
trace("%s: makePeerList() called", rc.conn.ID)
var peerList []peer
peerList = append(peerList, rc.leader)
for _, p := range rc.otherPeers {
peerList = append(peerList, p)
}
trace("%s: makePeerList() returning this list:", rc.conn.ID)
for n, v := range peerList {
trace("%s: makePeerList() peer %d -> %s", rc.conn.ID, n, v.hostname+":"+v.port)
}
return peerList
}
/* *****************************************************************
method: Connection.assembleURL()
tell it what peer to talk to and what kind of API operation you're
making, and it will return the full URL, from start to finish.
e.g.:
https://mary:secret2@server1.example.com:1234/db/query?transaction&level=strong
note: this func needs to live at the Connection level because the
Connection holds the username, password, consistencyLevel, etc.
* *****************************************************************/
func (conn *Connection) assembleURL(apiOp apiOperation, p peer) string {
var stringBuffer bytes.Buffer
if conn.wantsHTTPS == true {
stringBuffer.WriteString("https")
} else {
stringBuffer.WriteString("http")
}
stringBuffer.WriteString("://")
if conn.username != "" && conn.password != "" {
stringBuffer.WriteString(conn.username)
stringBuffer.WriteString(":")
stringBuffer.WriteString(conn.password)
stringBuffer.WriteString("@")
}
stringBuffer.WriteString(p.hostname)
stringBuffer.WriteString(":")
stringBuffer.WriteString(p.port)
switch apiOp {
case api_STATUS:
stringBuffer.WriteString("/status")
case api_QUERY:
stringBuffer.WriteString("/db/query")
case api_WRITE:
stringBuffer.WriteString("/db/execute")
}
if apiOp == api_QUERY || apiOp == api_WRITE {
stringBuffer.WriteString("?timings&transaction&level=")
stringBuffer.WriteString(consistencyLevelNames[conn.consistencyLevel])
}
switch apiOp {
case api_QUERY:
trace("%s: assembled URL for an api_QUERY: %s", conn.ID, stringBuffer.String())
case api_STATUS:
trace("%s: assembled URL for an api_STATUS: %s", conn.ID, stringBuffer.String())
case api_WRITE:
trace("%s: assembled URL for an api_WRITE: %s", conn.ID, stringBuffer.String())
}
return stringBuffer.String()
}
/* *****************************************************************
method: Connection.updateClusterInfo()
upon invocation, updateClusterInfo() completely erases and refreshes
the Connection's cluster info, replacing its rqliteCluster object
with current info.
the web heavy lifting (retrying, etc.) is done in rqliteApiGet()
* *****************************************************************/
func (conn *Connection) updateClusterInfo() error {
trace("%s: updateClusterInfo() called", conn.ID)
// start with a fresh new cluster
var rc rqliteCluster
rc.conn = conn
responseBody, err := conn.rqliteApiGet(api_STATUS)
if err != nil {
return err
}
trace("%s: updateClusterInfo() back from api call OK", conn.ID)
sections := make(map[string]interface{})
err = json.Unmarshal(responseBody, &sections)
if err != nil {
return err
}
sMap := sections["store"].(map[string]interface{})
leaderRaftAddr := sMap["leader"].(string)
trace("%s: leader from store section is %s", conn.ID, leaderRaftAddr)
// leader in this case is the RAFT address
// we want the HTTP address, so we'll use this as
// a key as we sift through APIPeers
meta := sMap["meta"].(map[string]interface{})
apiPeers := meta["APIPeers"].(map[string]interface{})
for raftAddr, httpAddr := range apiPeers {
trace("%s: examining httpAddr %s", conn.ID, httpAddr)
/* httpAddr are usually hostname:port */
var p peer
parts := strings.Split(httpAddr.(string), ":")
p.hostname = parts[0]
p.port = parts[1]
// so is this the leader?
if leaderRaftAddr == raftAddr {
trace("%s: found leader at %s", conn.ID, httpAddr)
rc.leader = p
} else {
rc.otherPeers = append(rc.otherPeers, p)
}
}
if rc.leader.hostname == "" {
return errors.New("could not determine leader from API status call")
}
// dump to trace
trace("%s: here is my cluster config:", conn.ID)
trace("%s: leader : %s", conn.ID, rc.leader.String())
for n, v := range rc.otherPeers {
trace("%s: otherPeer #%d: %s", conn.ID, n, v.String())
}
// now make it official
conn.cluster = rc
return nil
}

View File

@ -0,0 +1,300 @@
package gorqlite
/*
this file contains some high-level Connection-oriented stuff
*/
/* *****************************************************************
imports
* *****************************************************************/
import "errors"
import "fmt"
import "io"
import "net"
import nurl "net/url"
import "strings"
var errClosed = errors.New("gorqlite: connection is closed")
var traceOut io.Writer
// defaults to false. This is used in trace() to quickly
// return if tracing is off, so that we don't do a perhaps
// expensive Sprintf() call only to send it to Discard
var wantsTrace bool
/* *****************************************************************
type: Connection
* *****************************************************************/
/*
The connection abstraction. Note that since rqlite is stateless,
there really is no "connection". However, this type holds
information such as the current leader, peers, connection
string to build URLs, etc.
Connections are assigned a "connection ID" which is a pseudo-UUID
for connection identification in trace output only. This helps
sort out what's going on if you have multiple connections going
at once. It's generated using a non-standards-or-anything-else-compliant
function that uses crypto/rand to generate 16 random bytes.
Note that the Connection objection holds info on all peers, gathered
at time of Open() from the node specified.
*/
type Connection struct {
cluster rqliteCluster
/*
name type default
*/
username string // username or ""
password string // username or ""
consistencyLevel consistencyLevel // WEAK
wantsHTTPS bool // false unless connection URL is https
// variables below this line need to be initialized in Open()
timeout int // 10
hasBeenClosed bool // false
ID string // generated in init()
}
/* *****************************************************************
method: Connection.Close()
* *****************************************************************/
func (conn *Connection) Close() {
conn.hasBeenClosed = true
trace("%s: %s", conn.ID, "closing connection")
}
/* *****************************************************************
method: Connection.ConsistencyLevel()
* *****************************************************************/
func (conn *Connection) ConsistencyLevel() (string, error) {
if conn.hasBeenClosed {
return "", errClosed
}
return consistencyLevelNames[conn.consistencyLevel], nil
}
/* *****************************************************************
method: Connection.Leader()
* *****************************************************************/
func (conn *Connection) Leader() (string, error) {
if conn.hasBeenClosed {
return "", errClosed
}
trace("%s: Leader(), calling updateClusterInfo()", conn.ID)
err := conn.updateClusterInfo()
if err != nil {
trace("%s: Leader() got error from updateClusterInfo(): %s", conn.ID, err.Error())
return "", err
} else {
trace("%s: Leader(), updateClusterInfo() OK", conn.ID)
}
return conn.cluster.leader.String(), nil
}
/* *****************************************************************
method: Connection.Peers()
* *****************************************************************/
func (conn *Connection) Peers() ([]string, error) {
if conn.hasBeenClosed {
var ans []string
return ans, errClosed
}
plist := make([]string, 0)
trace("%s: Peers(), calling updateClusterInfo()", conn.ID)
err := conn.updateClusterInfo()
if err != nil {
trace("%s: Peers() got error from updateClusterInfo(): %s", conn.ID, err.Error())
return plist, err
} else {
trace("%s: Peers(), updateClusterInfo() OK", conn.ID)
}
plist = append(plist, conn.cluster.leader.String())
for _, p := range conn.cluster.otherPeers {
plist = append(plist, p.String())
}
return plist, nil
}
/* *****************************************************************
method: Connection.SetConsistencyLevel()
* *****************************************************************/
func (conn *Connection) SetConsistencyLevel(levelDesired string) error {
if conn.hasBeenClosed {
return errClosed
}
_, ok := consistencyLevels[levelDesired]
if ok {
conn.consistencyLevel = consistencyLevels[levelDesired]
return nil
}
return errors.New(fmt.Sprintf("unknown consistency level: %s", levelDesired))
}
/* *****************************************************************
method: Connection.initConnection()
* *****************************************************************/
/*
initConnection takes the initial connection URL specified by
the user, and parses it into a peer. This peer is assumed to
be the leader. The next thing Open() does is updateClusterInfo()
so the truth will be revealed soon enough.
initConnection() does not talk to rqlite. It only parses the
connection URL and prepares the new connection for work.
URL format:
http[s]://${USER}:${PASSWORD}@${HOSTNAME}:${PORT}/db?[OPTIONS]
Examples:
https://mary:secret2@localhost:4001/db
https://mary:secret2@server1.example.com:4001/db?level=none
https://mary:secret2@server2.example.com:4001/db?level=weak
https://mary:secret2@localhost:2265/db?level=strong
to use default connection to localhost:4001 with no auth:
http://
https://
guaranteed map fields - will be set to "" if not specified
field name default if not specified
username ""
password ""
hostname "localhost"
port "4001"
consistencyLevel "weak"
*/
func (conn *Connection) initConnection(url string) error {
// do some sanity checks. You know users.
if len(url) < 7 {
return errors.New("url specified is impossibly short")
}
if strings.HasPrefix(url, "http") == false {
return errors.New("url does not start with 'http'")
}
u, err := nurl.Parse(url)
if err != nil {
return err
}
trace("%s: net.url.Parse() OK", conn.ID)
if u.Scheme == "https" {
conn.wantsHTTPS = true
}
// specs say Username() is always populated even if empty
if u.User == nil {
conn.username = ""
conn.password = ""
} else {
// guaranteed, but could be empty which is ok
conn.username = u.User.Username()
// not guaranteed, so test if set
pass, isset := u.User.Password()
if isset {
conn.password = pass
} else {
conn.password = ""
}
}
if u.Host == "" {
conn.cluster.leader.hostname = "localhost"
} else {
conn.cluster.leader.hostname = u.Host
}
if u.Host == "" {
conn.cluster.leader.hostname = "localhost"
conn.cluster.leader.port = "4001"
} else {
// SplitHostPort() should only return an error if there is no host port.
// I think.
h, p, err := net.SplitHostPort(u.Host)
if err != nil {
conn.cluster.leader.hostname = u.Host
} else {
conn.cluster.leader.hostname = h
conn.cluster.leader.port = p
}
}
/*
at the moment, the only allowed query is "level=" with
the desired consistency level
*/
// default
conn.consistencyLevel = cl_WEAK
if u.RawQuery != "" {
if u.RawQuery == "level=weak" {
// that's ok but nothing to do
} else if u.RawQuery == "level=strong" {
conn.consistencyLevel = cl_STRONG
} else if u.RawQuery == "level=none" { // the fools!
conn.consistencyLevel = cl_NONE
} else {
return errors.New("don't know what to do with this query: " + u.RawQuery)
}
}
trace("%s: parseDefaultPeer() is done:", conn.ID)
if conn.wantsHTTPS == true {
trace("%s: %s -> %s", conn.ID, "wants https?", "yes")
} else {
trace("%s: %s -> %s", conn.ID, "wants https?", "no")
}
trace("%s: %s -> %s", conn.ID, "username", conn.username)
trace("%s: %s -> %s", conn.ID, "password", conn.password)
trace("%s: %s -> %s", conn.ID, "hostname", conn.cluster.leader.hostname)
trace("%s: %s -> %s", conn.ID, "port", conn.cluster.leader.port)
trace("%s: %s -> %s", conn.ID, "consistencyLevel", consistencyLevelNames[conn.consistencyLevel])
conn.cluster.conn = conn
return nil
}

View File

@ -0,0 +1,189 @@
/*
gorqlite
A golang database/sql driver for rqlite, the distributed consistent sqlite.
Copyright (c)2016 andrew fabbro (andrew@fabbro.org)
See LICENSE.md for license. tl;dr: MIT. Conveniently, the same licese as rqlite.
Project home page: https://github.com/raindo308/gorqlite
Learn more about rqlite at: https://github.com/rqlite/rqlite
*/
package gorqlite
/*
this file contains package-level stuff:
consts
init()
Open, TraceOn(), TraceOff()
*/
import "crypto/rand"
import "fmt"
import "io"
import "io/ioutil"
import "strings"
/* *****************************************************************
const
* *****************************************************************/
type consistencyLevel int
const (
cl_NONE consistencyLevel = iota
cl_WEAK
cl_STRONG
)
// used in several places, actually
var consistencyLevelNames map[consistencyLevel]string
var consistencyLevels map[string]consistencyLevel
type apiOperation int
const (
api_QUERY apiOperation = iota
api_STATUS
api_WRITE
)
/* *****************************************************************
init()
* *****************************************************************/
func init() {
traceOut = ioutil.Discard
consistencyLevelNames = make(map[consistencyLevel]string)
consistencyLevelNames[cl_NONE] = "none"
consistencyLevelNames[cl_WEAK] = "weak"
consistencyLevelNames[cl_STRONG] = "strong"
consistencyLevels = make(map[string]consistencyLevel)
consistencyLevels["none"] = cl_NONE
consistencyLevels["weak"] = cl_WEAK
consistencyLevels["strong"] = cl_STRONG
}
/* *****************************************************************
Open() creates and returns a "connection" to rqlite.
Since rqlite is stateless, there is no actual connection. Open() creates and initializes a gorqlite Connection type, which represents various config information.
The URL should be in a form like this:
http://localhost:4001
http:// default, no auth, localhost:4001
https:// default, no auth, localhost:4001, using https
http://localhost:1234
http://mary:secret2@localhost:1234
https://mary:secret2@somewhere.example.com:1234
https://mary:secret2@somewhere.example.com // will use 4001
* *****************************************************************/
func Open(connURL string) (Connection, error) {
var conn Connection
// generate our uuid for trace
b := make([]byte, 16)
_, err := rand.Read(b)
if err != nil {
return conn, err
}
conn.ID = fmt.Sprintf("%X-%X-%X-%X-%X", b[0:4], b[4:6], b[6:8], b[8:10], b[10:])
trace("%s: Open() called for url: %s", conn.ID, connURL)
// set defaults
conn.timeout = 10
conn.hasBeenClosed = false
// parse the URL given
err = conn.initConnection(connURL)
if err != nil {
return conn, err
}
// call updateClusterInfo() to populate the cluster
// also tests the user's default
err = conn.updateClusterInfo()
// and the err from updateClusterInfo() will be our err as well
return conn, err
}
/* *****************************************************************
func: trace()
adds a message to the trace output
not a public function. we (inside) can add - outside they can
only see.
Call trace as: Sprintf pattern , args...
This is done so that the more expensive Sprintf() stuff is
done only if truly needed. When tracing is off, calls to
trace() just hit a bool check and return. If tracing is on,
then the Sprintfing is done at a leisurely pace because, well,
we're tracing.
Premature optimization is the root of all evil, so this is
probably sinful behavior.
Don't put a \n in your Sprintf pattern becuase trace() adds one
* *****************************************************************/
func trace(pattern string, args ...interface{}) {
// don't do the probably expensive Sprintf() if not needed
if wantsTrace == false {
return
}
// this could all be made into one long statement but we have
// compilers to do such things for us. let's sip a mint julep
// and spell this out in glorious exposition.
// make sure there is one and only one newline
nlPattern := strings.TrimSpace(pattern) + "\n"
msg := fmt.Sprintf(nlPattern, args...)
traceOut.Write([]byte(msg))
}
/*
TraceOn()
Turns on tracing output to the io.Writer of your choice.
Trace output is very detailed and verbose, as you might expect.
Normally, you should run with tracing off, as it makes absolutely
no concession to performance and is intended for debugging/dev use.
*/
func TraceOn(w io.Writer) {
traceOut = w
wantsTrace = true
}
/*
TraceOff()
Turns off tracing output. Once you call TraceOff(), no further
info is sent to the io.Writer, unless it is TraceOn'd again.
*/
func TraceOff() {
wantsTrace = false
traceOut = ioutil.Discard
}

View File

@ -0,0 +1,54 @@
package gorqlite
import (
"fmt"
"strings"
)
// EscapeString sql-escapes a string.
func EscapeString(value string) string {
replace := [][2]string{
{`\`, `\\`},
{`\0`, `\\0`},
{`\n`, `\\n`},
{`\r`, `\\r`},
{`"`, `\"`},
{`'`, `\'`},
}
for _, val := range replace {
value = strings.Replace(value, val[0], val[1], -1)
}
return value
}
// PreparedStatement is a simple wrapper around fmt.Sprintf for prepared SQL
// statements.
type PreparedStatement struct {
body string
}
// NewPreparedStatement takes a sprintf syntax SQL query for later binding of
// parameters.
func NewPreparedStatement(body string) PreparedStatement {
return PreparedStatement{body: body}
}
// Bind takes arguments and SQL-escapes them, then calling fmt.Sprintf.
func (p PreparedStatement) Bind(args ...interface{}) string {
var spargs []interface{}
for _, arg := range args {
switch arg.(type) {
case string:
spargs = append(spargs, `'`+EscapeString(arg.(string))+`'`)
case fmt.Stringer:
spargs = append(spargs, `'`+EscapeString(arg.(fmt.Stringer).String())+`'`)
default:
spargs = append(spargs, arg)
}
}
return fmt.Sprintf(p.body, spargs...)
}

View File

@ -0,0 +1,395 @@
package gorqlite
import "errors"
import "fmt"
import "encoding/json"
/* *****************************************************************
method: Connection.Query()
This is the JSON we get back:
{
"results": [
{
"columns": [
"id",
"name"
],
"types": [
"integer",
"text"
],
"values": [
[
1,
"fiona"
],
[
2,
"sinead"
]
],
"time": 0.0150043
}
],
"time": 0.0220043
}
or
{
"results": [
{
"columns": [
"id",
"name"
],
"types": [
"number",
"text"
],
"values": [
[
null,
"Hulk"
]
],
"time": 4.8958e-05
},
{
"columns": [
"id",
"name"
],
"types": [
"number",
"text"
],
"time": 1.8460000000000003e-05
}
],
"time": 0.000134776
}
or
{
"results": [
{
"error": "near \"nonsense\": syntax error"
}
],
"time": 2.478862
}
* *****************************************************************/
/*
QueryOne() is a convenience method that wraps Query() into a single-statement
method.
*/
func (conn *Connection) QueryOne(sqlStatement string) (qr QueryResult, err error) {
if conn.hasBeenClosed {
qr.Err = errClosed
return qr, errClosed
}
sqlStatements := make([]string, 0)
sqlStatements = append(sqlStatements, sqlStatement)
qra, err := conn.Query(sqlStatements)
return qra[0], err
}
/*
Query() is used to perform SELECT operations in the database.
It takes an array of SQL statements and executes them in a single transaction, returning an array of QueryResult vars.
*/
func (conn *Connection) Query(sqlStatements []string) (results []QueryResult, err error) {
results = make([]QueryResult, 0)
if conn.hasBeenClosed {
var errResult QueryResult
errResult.Err = errClosed
results = append(results, errResult)
return results, errClosed
}
trace("%s: Query() for %d statements", conn.ID, len(sqlStatements))
// if we get an error POSTing, that's a showstopper
response, err := conn.rqliteApiPost(api_QUERY, sqlStatements)
if err != nil {
trace("%s: rqliteApiCall() ERROR: %s", conn.ID, err.Error())
var errResult QueryResult
errResult.Err = err
results = append(results, errResult)
return results, err
}
trace("%s: rqliteApiCall() OK", conn.ID)
// if we get an error Unmarshaling, that's a showstopper
var sections map[string]interface{}
err = json.Unmarshal(response, &sections)
if err != nil {
trace("%s: json.Unmarshal() ERROR: %s", conn.ID, err.Error())
var errResult QueryResult
errResult.Err = err
results = append(results, errResult)
return results, err
}
/*
at this point, we have a "results" section and
a "time" section. we can igore the latter.
*/
resultsArray := sections["results"].([]interface{})
trace("%s: I have %d result(s) to parse", conn.ID, len(resultsArray))
numStatementErrors := 0
for n, r := range resultsArray {
trace("%s: parsing result %d", conn.ID, n)
var thisQR QueryResult
thisQR.conn = conn
// r is a hash with columns, types, values, and time
thisResult := r.(map[string]interface{})
// did we get an error?
_, ok := thisResult["error"]
if ok {
trace("%s: have an error on this result: %s", conn.ID, thisResult["error"].(string))
thisQR.Err = errors.New(thisResult["error"].(string))
results = append(results, thisQR)
numStatementErrors++
continue
}
// time is a float64
thisQR.Timing = thisResult["time"].(float64)
// column & type are an array of strings
c := thisResult["columns"].([]interface{})
t := thisResult["types"].([]interface{})
for i := 0; i < len(c); i++ {
thisQR.columns = append(thisQR.columns, c[i].(string))
thisQR.types = append(thisQR.types, t[i].(string))
}
// and values are an array of arrays
if thisResult["values"] != nil {
thisQR.values = thisResult["values"].([]interface{})
} else {
trace("%s: fyi, no values this query", conn.ID)
}
thisQR.rowNumber = -1
trace("%s: this result (#col,time) %d %f", conn.ID, len(thisQR.columns), thisQR.Timing)
results = append(results, thisQR)
}
trace("%s: finished parsing, returning %d results", conn.ID, len(results))
if numStatementErrors > 0 {
return results, errors.New(fmt.Sprintf("there were %d statement errors", numStatementErrors))
} else {
return results, nil
}
}
/* *****************************************************************
type: QueryResult
* *****************************************************************/
/*
A QueryResult type holds the results of a call to Query(). You could think of it as a rowset.
So if you were to query:
SELECT id, name FROM some_table;
then a QueryResult would hold any errors from that query, a list of columns and types, and the actual row values.
Query() returns an array of QueryResult vars, while QueryOne() returns a single variable.
*/
type QueryResult struct {
conn *Connection
Err error
columns []string
types []string
Timing float64
values []interface{}
rowNumber int64
}
// these are done as getters rather than as public
// variables to prevent monkey business by the user
// that would put us in an inconsistent state
/* *****************************************************************
method: QueryResult.Columns()
* *****************************************************************/
/*
Columns returns a list of the column names for this QueryResult.
*/
func (qr *QueryResult) Columns() []string {
return qr.columns
}
/* *****************************************************************
method: QueryResult.Map()
* *****************************************************************/
/*
Map() returns the current row (as advanced by Next()) as a map[string]interface{}
The key is a string corresponding to a column name.
The value is the corresponding column.
Note that only json values are supported, so you will need to type the interface{} accordingly.
*/
func (qr *QueryResult) Map() (map[string]interface{}, error) {
trace("%s: Map() called for row %d", qr.conn.ID, qr.rowNumber)
ans := make(map[string]interface{})
if qr.rowNumber == -1 {
return ans, errors.New("you need to Next() before you Map(), sorry, it's complicated")
}
thisRowValues := qr.values[qr.rowNumber].([]interface{})
for i := 0; i < len(qr.columns); i++ {
ans[qr.columns[i]] = thisRowValues[i]
}
return ans, nil
}
/* *****************************************************************
method: QueryResult.Next()
* *****************************************************************/
/*
Next() positions the QueryResult result pointer so that Scan() or Map() is ready.
You should call Next() first, but gorqlite will fix it if you call Map() or Scan() before
the initial Next().
A common idiom:
rows := conn.Write(something)
for rows.Next() {
// your Scan/Map and processing here.
}
*/
func (qr *QueryResult) Next() bool {
if qr.rowNumber >= int64(len(qr.values)-1) {
return false
}
qr.rowNumber += 1
return true
}
/* *****************************************************************
method: QueryResult.NumRows()
* *****************************************************************/
/*
NumRows() returns the number of rows returned by the query.
*/
func (qr *QueryResult) NumRows() int64 {
return int64(len(qr.values))
}
/* *****************************************************************
method: QueryResult.RowNumber()
* *****************************************************************/
/*
RowNumber() returns the current row number as Next() iterates through the result's rows.
*/
func (qr *QueryResult) RowNumber() int64 {
return qr.rowNumber
}
/* *****************************************************************
method: QueryResult.Scan()
* *****************************************************************/
/*
Scan() takes a list of pointers and then updates them to reflect he current row's data.
Note that only the following data types are used, and they
are a subset of the types JSON uses:
string, for JSON strings
float64, for JSON numbers
int64, as a convenient extension
nil for JSON null
booleans, JSON arrays, and JSON objects are not supported,
since sqlite does not support them.
*/
func (qr *QueryResult) Scan(dest ...interface{}) error {
trace("%s: Scan() called for %d vars", qr.conn.ID, len(dest))
if qr.rowNumber == -1 {
return errors.New("you need to Next() before you Scan(), sorry, it's complicated")
}
if len(dest) != len(qr.columns) {
return errors.New(fmt.Sprintf("expected %d columns but got %d vars\n", len(qr.columns), len(dest)))
}
thisRowValues := qr.values[qr.rowNumber].([]interface{})
for n, d := range dest {
switch d.(type) {
case *int64:
f := int64(thisRowValues[n].(float64))
*d.(*int64) = f
case *float64:
f := float64(thisRowValues[n].(float64))
*d.(*float64) = f
case *string:
s := string(thisRowValues[n].(string))
*d.(*string) = s
default:
return errors.New(fmt.Sprintf("unknown destination type to scan into in variable #%d", n))
}
}
return nil
}
/* *****************************************************************
method: QueryResult.Types()
* *****************************************************************/
/*
Types() returns an array of the column's types.
Note that sqlite will repeat the type you tell it, but in many cases, it's ignored. So you can initialize a column as CHAR(3) but it's really TEXT. See https://www.sqlite.org/datatype3.html
This info may additionally conflict with the reality that your data is being JSON encoded/decoded.
*/
func (qr *QueryResult) Types() []string {
return qr.types
}

View File

@ -0,0 +1,178 @@
package gorqlite
/*
this file has
Write()
WriteResult and its methods
*/
import "errors"
import "encoding/json"
import "fmt"
/* *****************************************************************
method: Connection.Write()
This is the JSON we get back:
{
"results": [
{
"last_insert_id": 1,
"rows_affected": 1,
"time": 0.00759015
},
{
"last_insert_id": 2,
"rows_affected": 1,
"time": 0.00669015
}
],
"time": 0.869015
}
or
{
"results": [
{
"error": "table foo already exists"
}
],
"time": 0.18472685400000002
}
We don't care about the overall time. We just want the results,
so we'll take those and put each into a WriteResult
Because the results themselves are smaller than the JSON
(which repeats strings like "last_insert_id" frequently),
we'll just parse everything at once.
* *****************************************************************/
/*
WriteOne() is a convenience method that wraps Write() into a single-statement
method.
*/
func (conn *Connection) WriteOne(sqlStatement string) (wr WriteResult, err error) {
if conn.hasBeenClosed {
wr.Err = errClosed
return wr, errClosed
}
sqlStatements := make([]string, 0)
sqlStatements = append(sqlStatements, sqlStatement)
wra, err := conn.Write(sqlStatements)
return wra[0], err
}
/*
Write() is used to perform DDL/DML in the database. ALTER, CREATE, DELETE, DROP, INSERT, UPDATE, etc. all go through Write().
Write() takes an array of SQL statements, and returns an equal-sized array of WriteResults, each corresponding to the SQL statement that produced it.
All statements are executed as a single transaction.
Write() returns an error if one is encountered during its operation. If it's something like a call to the rqlite API, then it'll return that error. If one statement out of several has an error, it will return a generic "there were %d statement errors" and you'll have to look at the individual statement's Err for more info.
*/
func (conn *Connection) Write(sqlStatements []string) (results []WriteResult, err error) {
results = make([]WriteResult, 0)
if conn.hasBeenClosed {
var errResult WriteResult
errResult.Err = errClosed
results = append(results, errResult)
return results, errClosed
}
trace("%s: Write() for %d statements", conn.ID, len(sqlStatements))
response, err := conn.rqliteApiPost(api_WRITE, sqlStatements)
if err != nil {
trace("%s: rqliteApiCall() ERROR: %s", conn.ID, err.Error())
var errResult WriteResult
errResult.Err = err
results = append(results, errResult)
return results, err
}
trace("%s: rqliteApiCall() OK", conn.ID)
var sections map[string]interface{}
err = json.Unmarshal(response, &sections)
if err != nil {
trace("%s: json.Unmarshal() ERROR: %s", conn.ID, err.Error())
var errResult WriteResult
errResult.Err = err
results = append(results, errResult)
return results, err
}
/*
at this point, we have a "results" section and
a "time" section. we can igore the latter.
*/
resultsArray := sections["results"].([]interface{})
trace("%s: I have %d result(s) to parse", conn.ID, len(resultsArray))
numStatementErrors := 0
for n, k := range resultsArray {
trace("%s: starting on result %d", conn.ID, n)
thisResult := k.(map[string]interface{})
var thisWR WriteResult
thisWR.conn = conn
// did we get an error?
_, ok := thisResult["error"]
if ok {
trace("%s: have an error on this result: %s", conn.ID, thisResult["error"].(string))
thisWR.Err = errors.New(thisResult["error"].(string))
results = append(results, thisWR)
numStatementErrors += 1
continue
}
_, ok = thisResult["last_insert_id"]
if ok {
thisWR.LastInsertID = int64(thisResult["last_insert_id"].(float64))
}
_, ok = thisResult["rows_affected"] // could be zero for a CREATE
if ok {
thisWR.RowsAffected = int64(thisResult["rows_affected"].(float64))
}
thisWR.Timing = thisResult["time"].(float64)
trace("%s: this result (LII,RA,T): %d %d %f", conn.ID, thisWR.LastInsertID, thisWR.RowsAffected, thisWR.Timing)
results = append(results, thisWR)
}
trace("%s: finished parsing, returning %d results", conn.ID, len(results))
if numStatementErrors > 0 {
return results, errors.New(fmt.Sprintf("there were %d statement errors", numStatementErrors))
} else {
return results, nil
}
}
/* *****************************************************************
type: WriteResult
* *****************************************************************/
/*
A WriteResult holds the result of a single statement sent to Write().
Write() returns an array of WriteResult vars, while WriteOne() returns a single WriteResult.
*/
type WriteResult struct {
Err error // don't trust the rest if this isn't nil
Timing float64
RowsAffected int64 // affected by the change
LastInsertID int64 // if relevant, otherwise zero value
conn *Connection
}