vendor dependencies
This commit is contained in:
parent
86be40fea0
commit
2fb2ab8fa2
|
@ -11,10 +11,9 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/koding/logging"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
"github.com/koding/logging"
|
||||
)
|
||||
|
||||
//go:generate stringer -type ClientState
|
||||
|
|
|
@ -16,7 +16,6 @@ import (
|
|||
|
||||
"git.xeserv.us/xena/route/lib/tunnel"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
|
|
|
@ -17,10 +17,9 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/koding/logging"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
"github.com/koding/logging"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -4,8 +4,8 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/koding/logging"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
"github.com/koding/logging"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
|
||||
"git.xeserv.us/xena/route/lib/tunnel"
|
||||
"git.xeserv.us/xena/route/lib/tunnel/tunneltest"
|
||||
|
||||
"github.com/cenkalti/backoff"
|
||||
)
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"time"
|
||||
|
||||
"git.xeserv.us/xena/route/lib/tunnel/proto"
|
||||
|
||||
"github.com/cenkalti/backoff"
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
417badecf1ab14d0d6e38ad82397da2a59e2f6ca github.com/GoRethink/gorethink
|
||||
9b48ece7fc373043054858f8c0d362665e866004 github.com/Sirupsen/logrus
|
||||
62b230097e9c9534ca2074782b25d738c4b68964 (dirty) github.com/Xe/uuid
|
||||
38b46760280b5500edd530aa39a8075bf22f9630 github.com/Yawning/bulb
|
||||
b02f2bbce11d7ea6b97f282ef1771b0fe2f65ef3 github.com/cenk/backoff
|
||||
b02f2bbce11d7ea6b97f282ef1771b0fe2f65ef3 github.com/cenkalti/backoff
|
||||
fcd59fca7456889be7f2ad4515b7612fd6acef31 github.com/facebookgo/flagenv
|
||||
8ee79997227bf9b34611aee7946ae64735e6fd93 github.com/golang/protobuf/proto
|
||||
e80d13ce29ede4452c43dea11e79b9bc8a15b478 github.com/hailocab/go-hostpool
|
||||
d1caa6c97c9fc1cc9e83bbe34d0603f9ff0ce8bd github.com/hashicorp/yamux
|
||||
4ed13390c0acd2ff4e371e64d8b97c8954138243 github.com/joho/godotenv
|
||||
4ed13390c0acd2ff4e371e64d8b97c8954138243 github.com/joho/godotenv/autoload
|
||||
8b5a689ed69b1c7cd1e3595276fc2a352d7818e0 github.com/koding/logging
|
||||
1627eaec269965440f742a25a627910195ad1c7a github.com/sycamoreone/orc/tor
|
||||
38b46760280b5500edd530aa39a8075bf22f9630 github.com/yawning/bulb/utils
|
||||
38b46760280b5500edd530aa39a8075bf22f9630 github.com/yawning/bulb/utils/pkcs1
|
||||
b8a2a83acfe6e6770b75de42d5ff4c67596675c0 golang.org/x/crypto/pbkdf2
|
||||
f2499483f923065a842d38eb4c7f1927e6fc6e6d golang.org/x/net/proxy
|
||||
6e328e67893eb46323ad06f0e92cb9536babbabc gopkg.in/fatih/pool.v2
|
||||
016a1d3b4d15951ab2e39bd3596718ba94d298ba gopkg.in/gorethink/gorethink.v2/encoding
|
||||
016a1d3b4d15951ab2e39bd3596718ba94d298ba gopkg.in/gorethink/gorethink.v2/ql2
|
||||
016a1d3b4d15951ab2e39bd3596718ba94d298ba gopkg.in/gorethink/gorethink.v2/types
|
|
@ -0,0 +1,522 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/cenk/backoff"
|
||||
"github.com/hailocab/go-hostpool"
|
||||
)
|
||||
|
||||
// A Cluster represents a connection to a RethinkDB cluster, a cluster is created
|
||||
// by the Session and should rarely be created manually.
|
||||
//
|
||||
// The cluster keeps track of all nodes in the cluster and if requested can listen
|
||||
// for cluster changes and start tracking a new node if one appears. Currently
|
||||
// nodes are removed from the pool if they become unhealthy (100 failed queries).
|
||||
// This should hopefully soon be replaced by a backoff system.
|
||||
type Cluster struct {
|
||||
opts *ConnectOpts
|
||||
|
||||
mu sync.RWMutex
|
||||
seeds []Host // Initial host nodes specified by user.
|
||||
hp hostpool.HostPool
|
||||
nodes map[string]*Node // Active nodes in cluster.
|
||||
closed bool
|
||||
|
||||
nodeIndex int64
|
||||
}
|
||||
|
||||
// NewCluster creates a new cluster by connecting to the given hosts.
|
||||
func NewCluster(hosts []Host, opts *ConnectOpts) (*Cluster, error) {
|
||||
c := &Cluster{
|
||||
hp: hostpool.NewEpsilonGreedy([]string{}, opts.HostDecayDuration, &hostpool.LinearEpsilonValueCalculator{}),
|
||||
seeds: hosts,
|
||||
opts: opts,
|
||||
}
|
||||
|
||||
// Attempt to connect to each host and discover any additional hosts if host
|
||||
// discovery is enabled
|
||||
if err := c.connectNodes(c.getSeeds()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !c.IsConnected() {
|
||||
return nil, ErrNoConnectionsStarted
|
||||
}
|
||||
|
||||
if opts.DiscoverHosts {
|
||||
go c.discover()
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Query executes a ReQL query using the cluster to connect to the database
|
||||
func (c *Cluster) Query(q Query) (cursor *Cursor, err error) {
|
||||
for i := 0; i < c.numRetries(); i++ {
|
||||
var node *Node
|
||||
var hpr hostpool.HostPoolResponse
|
||||
|
||||
node, hpr, err = c.GetNextNode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cursor, err = node.Query(q)
|
||||
hpr.Mark(err)
|
||||
|
||||
if !shouldRetryQuery(q, err) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return cursor, err
|
||||
}
|
||||
|
||||
// Exec executes a ReQL query using the cluster to connect to the database
|
||||
func (c *Cluster) Exec(q Query) (err error) {
|
||||
for i := 0; i < c.numRetries(); i++ {
|
||||
var node *Node
|
||||
var hpr hostpool.HostPoolResponse
|
||||
|
||||
node, hpr, err = c.GetNextNode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = node.Exec(q)
|
||||
hpr.Mark(err)
|
||||
|
||||
if !shouldRetryQuery(q, err) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Server returns the server name and server UUID being used by a connection.
|
||||
func (c *Cluster) Server() (response ServerResponse, err error) {
|
||||
for i := 0; i < c.numRetries(); i++ {
|
||||
var node *Node
|
||||
var hpr hostpool.HostPoolResponse
|
||||
|
||||
node, hpr, err = c.GetNextNode()
|
||||
if err != nil {
|
||||
return ServerResponse{}, err
|
||||
}
|
||||
|
||||
response, err = node.Server()
|
||||
hpr.Mark(err)
|
||||
|
||||
// This query should not fail so retry if any error is detected
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
||||
|
||||
// SetInitialPoolCap sets the initial capacity of the connection pool.
|
||||
func (c *Cluster) SetInitialPoolCap(n int) {
|
||||
for _, node := range c.GetNodes() {
|
||||
node.SetInitialPoolCap(n)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxIdleConns sets the maximum number of connections in the idle
|
||||
// connection pool.
|
||||
func (c *Cluster) SetMaxIdleConns(n int) {
|
||||
for _, node := range c.GetNodes() {
|
||||
node.SetMaxIdleConns(n)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxOpenConns sets the maximum number of open connections to the database.
|
||||
func (c *Cluster) SetMaxOpenConns(n int) {
|
||||
for _, node := range c.GetNodes() {
|
||||
node.SetMaxOpenConns(n)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the cluster
|
||||
func (c *Cluster) Close(optArgs ...CloseOpts) error {
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, node := range c.GetNodes() {
|
||||
err := node.Close(optArgs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.hp.Close()
|
||||
c.closed = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// discover attempts to find new nodes in the cluster using the current nodes
|
||||
func (c *Cluster) discover() {
|
||||
// Keep retrying with exponential backoff.
|
||||
b := backoff.NewExponentialBackOff()
|
||||
// Never finish retrying (max interval is still 60s)
|
||||
b.MaxElapsedTime = 0
|
||||
|
||||
// Keep trying to discover new nodes
|
||||
for {
|
||||
backoff.RetryNotify(func() error {
|
||||
// If no hosts try seeding nodes
|
||||
if len(c.GetNodes()) == 0 {
|
||||
c.connectNodes(c.getSeeds())
|
||||
}
|
||||
|
||||
return c.listenForNodeChanges()
|
||||
}, b, func(err error, wait time.Duration) {
|
||||
Log.Debugf("Error discovering hosts %s, waiting: %s", err, wait)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// listenForNodeChanges listens for changes to node status using change feeds.
|
||||
// This function will block until the query fails
|
||||
func (c *Cluster) listenForNodeChanges() error {
|
||||
// Start listening to changes from a random active node
|
||||
node, hpr, err := c.GetNextNode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q, err := newQuery(
|
||||
DB("rethinkdb").Table("server_status").Changes(),
|
||||
map[string]interface{}{},
|
||||
c.opts,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error building query: %s", err)
|
||||
}
|
||||
|
||||
cursor, err := node.Query(q)
|
||||
if err != nil {
|
||||
hpr.Mark(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Keep reading node status updates from changefeed
|
||||
var result struct {
|
||||
NewVal nodeStatus `gorethink:"new_val"`
|
||||
OldVal nodeStatus `gorethink:"old_val"`
|
||||
}
|
||||
for cursor.Next(&result) {
|
||||
addr := fmt.Sprintf("%s:%d", result.NewVal.Network.Hostname, result.NewVal.Network.ReqlPort)
|
||||
addr = strings.ToLower(addr)
|
||||
|
||||
switch result.NewVal.Status {
|
||||
case "connected":
|
||||
// Connect to node using exponential backoff (give up after waiting 5s)
|
||||
// to give the node time to start-up.
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.MaxElapsedTime = time.Second * 5
|
||||
|
||||
backoff.Retry(func() error {
|
||||
node, err := c.connectNodeWithStatus(result.NewVal)
|
||||
if err == nil {
|
||||
if !c.nodeExists(node) {
|
||||
c.addNode(node)
|
||||
|
||||
Log.WithFields(logrus.Fields{
|
||||
"id": node.ID,
|
||||
"host": node.Host.String(),
|
||||
}).Debug("Connected to node")
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}, b)
|
||||
}
|
||||
}
|
||||
|
||||
err = cursor.Err()
|
||||
hpr.Mark(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Cluster) connectNodes(hosts []Host) error {
|
||||
// Add existing nodes to map
|
||||
nodeSet := map[string]*Node{}
|
||||
for _, node := range c.GetNodes() {
|
||||
nodeSet[node.ID] = node
|
||||
}
|
||||
|
||||
var attemptErr error
|
||||
|
||||
// Attempt to connect to each seed host
|
||||
for _, host := range hosts {
|
||||
conn, err := NewConnection(host.String(), c.opts)
|
||||
if err != nil {
|
||||
attemptErr = err
|
||||
Log.Warnf("Error creating connection: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if c.opts.DiscoverHosts {
|
||||
q, err := newQuery(
|
||||
DB("rethinkdb").Table("server_status"),
|
||||
map[string]interface{}{},
|
||||
c.opts,
|
||||
)
|
||||
if err != nil {
|
||||
Log.Warnf("Error building query: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, cursor, err := conn.Query(q)
|
||||
if err != nil {
|
||||
attemptErr = err
|
||||
Log.Warnf("Error fetching cluster status: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var results []nodeStatus
|
||||
err = cursor.All(&results)
|
||||
if err != nil {
|
||||
attemptErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
node, err := c.connectNodeWithStatus(result)
|
||||
if err == nil {
|
||||
if _, ok := nodeSet[node.ID]; !ok {
|
||||
Log.WithFields(logrus.Fields{
|
||||
"id": node.ID,
|
||||
"host": node.Host.String(),
|
||||
}).Debug("Connected to node")
|
||||
nodeSet[node.ID] = node
|
||||
}
|
||||
} else {
|
||||
attemptErr = err
|
||||
Log.Warnf("Error connecting to node: %s", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
svrRsp, err := conn.Server()
|
||||
if err != nil {
|
||||
attemptErr = err
|
||||
Log.Warnf("Error fetching server ID: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
node, err := c.connectNode(svrRsp.ID, []Host{host})
|
||||
if err == nil {
|
||||
if _, ok := nodeSet[node.ID]; !ok {
|
||||
Log.WithFields(logrus.Fields{
|
||||
"id": node.ID,
|
||||
"host": node.Host.String(),
|
||||
}).Debug("Connected to node")
|
||||
|
||||
nodeSet[node.ID] = node
|
||||
}
|
||||
} else {
|
||||
attemptErr = err
|
||||
Log.Warnf("Error connecting to node: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no nodes were contactable then return the last error, this does not
|
||||
// include driver errors such as if there was an issue building the
|
||||
// query
|
||||
if len(nodeSet) == 0 {
|
||||
return attemptErr
|
||||
}
|
||||
|
||||
nodes := []*Node{}
|
||||
for _, node := range nodeSet {
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
c.setNodes(nodes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cluster) connectNodeWithStatus(s nodeStatus) (*Node, error) {
|
||||
aliases := make([]Host, len(s.Network.CanonicalAddresses))
|
||||
for i, aliasAddress := range s.Network.CanonicalAddresses {
|
||||
aliases[i] = NewHost(aliasAddress.Host, int(s.Network.ReqlPort))
|
||||
}
|
||||
|
||||
return c.connectNode(s.ID, aliases)
|
||||
}
|
||||
|
||||
func (c *Cluster) connectNode(id string, aliases []Host) (*Node, error) {
|
||||
var pool *Pool
|
||||
var err error
|
||||
|
||||
for len(aliases) > 0 {
|
||||
pool, err = NewPool(aliases[0], c.opts)
|
||||
if err != nil {
|
||||
aliases = aliases[1:]
|
||||
continue
|
||||
}
|
||||
|
||||
err = pool.Ping()
|
||||
if err != nil {
|
||||
aliases = aliases[1:]
|
||||
continue
|
||||
}
|
||||
|
||||
// Ping successful so break out of loop
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(aliases) == 0 {
|
||||
return nil, ErrInvalidNode
|
||||
}
|
||||
|
||||
return newNode(id, aliases, c, pool), nil
|
||||
}
|
||||
|
||||
// IsConnected returns true if cluster has nodes and is not already closed.
|
||||
func (c *Cluster) IsConnected() bool {
|
||||
c.mu.RLock()
|
||||
closed := c.closed
|
||||
c.mu.RUnlock()
|
||||
|
||||
return (len(c.GetNodes()) > 0) && !closed
|
||||
}
|
||||
|
||||
// AddSeeds adds new seed hosts to the cluster.
|
||||
func (c *Cluster) AddSeeds(hosts []Host) {
|
||||
c.mu.Lock()
|
||||
c.seeds = append(c.seeds, hosts...)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Cluster) getSeeds() []Host {
|
||||
c.mu.RLock()
|
||||
seeds := c.seeds
|
||||
c.mu.RUnlock()
|
||||
|
||||
return seeds
|
||||
}
|
||||
|
||||
// GetNextNode returns a random node on the cluster
|
||||
func (c *Cluster) GetNextNode() (*Node, hostpool.HostPoolResponse, error) {
|
||||
if !c.IsConnected() {
|
||||
return nil, nil, ErrNoConnections
|
||||
}
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
nodes := c.nodes
|
||||
hpr := c.hp.Get()
|
||||
if n, ok := nodes[hpr.Host()]; ok {
|
||||
if !n.Closed() {
|
||||
return n, hpr, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, ErrNoConnections
|
||||
}
|
||||
|
||||
// GetNodes returns a list of all nodes in the cluster
|
||||
func (c *Cluster) GetNodes() []*Node {
|
||||
c.mu.RLock()
|
||||
nodes := make([]*Node, 0, len(c.nodes))
|
||||
for _, n := range c.nodes {
|
||||
nodes = append(nodes, n)
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
func (c *Cluster) nodeExists(search *Node) bool {
|
||||
for _, node := range c.GetNodes() {
|
||||
if node.ID == search.ID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Cluster) addNode(node *Node) {
|
||||
c.mu.RLock()
|
||||
nodes := append(c.GetNodes(), node)
|
||||
c.mu.RUnlock()
|
||||
|
||||
c.setNodes(nodes)
|
||||
}
|
||||
|
||||
func (c *Cluster) addNodes(nodesToAdd []*Node) {
|
||||
c.mu.RLock()
|
||||
nodes := append(c.GetNodes(), nodesToAdd...)
|
||||
c.mu.RUnlock()
|
||||
|
||||
c.setNodes(nodes)
|
||||
}
|
||||
|
||||
func (c *Cluster) setNodes(nodes []*Node) {
|
||||
nodesMap := make(map[string]*Node, len(nodes))
|
||||
hosts := make([]string, len(nodes))
|
||||
for i, node := range nodes {
|
||||
host := node.Host.String()
|
||||
|
||||
nodesMap[host] = node
|
||||
hosts[i] = host
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.nodes = nodesMap
|
||||
c.hp.SetHosts(hosts)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Cluster) removeNode(nodeID string) {
|
||||
nodes := c.GetNodes()
|
||||
nodeArray := make([]*Node, len(nodes)-1)
|
||||
count := 0
|
||||
|
||||
// Add nodes that are not in remove list.
|
||||
for _, n := range nodes {
|
||||
if n.ID != nodeID {
|
||||
nodeArray[count] = n
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
// Do sanity check to make sure assumptions are correct.
|
||||
if count < len(nodeArray) {
|
||||
// Resize array.
|
||||
nodeArray2 := make([]*Node, count)
|
||||
copy(nodeArray2, nodeArray)
|
||||
nodeArray = nodeArray2
|
||||
}
|
||||
|
||||
c.setNodes(nodeArray)
|
||||
}
|
||||
|
||||
func (c *Cluster) nextNodeIndex() int64 {
|
||||
return atomic.AddInt64(&c.nodeIndex, 1)
|
||||
}
|
||||
|
||||
func (c *Cluster) numRetries() int {
|
||||
if n := c.opts.NumRetries; n > 0 {
|
||||
return n
|
||||
}
|
||||
|
||||
return 3
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
// +build cluster
|
||||
// +build integration
|
||||
|
||||
package gorethink
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
test "gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func (s *RethinkSuite) TestClusterDetectNewNode(c *test.C) {
|
||||
session, err := Connect(ConnectOpts{
|
||||
Addresses: []string{url, url2},
|
||||
DiscoverHosts: true,
|
||||
NodeRefreshInterval: time.Second,
|
||||
})
|
||||
c.Assert(err, test.IsNil)
|
||||
|
||||
t := time.NewTimer(time.Second * 30)
|
||||
for {
|
||||
select {
|
||||
// Fail if deadline has passed
|
||||
case <-t.C:
|
||||
c.Fatal("No node was added to the cluster")
|
||||
default:
|
||||
// Pass if another node was added
|
||||
if len(session.cluster.GetNodes()) >= 3 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RethinkSuite) TestClusterRecoverAfterNoNodes(c *test.C) {
|
||||
session, err := Connect(ConnectOpts{
|
||||
Addresses: []string{url, url2},
|
||||
DiscoverHosts: true,
|
||||
NodeRefreshInterval: time.Second,
|
||||
})
|
||||
c.Assert(err, test.IsNil)
|
||||
|
||||
t := time.NewTimer(time.Second * 30)
|
||||
hasHadZeroNodes := false
|
||||
for {
|
||||
select {
|
||||
// Fail if deadline has passed
|
||||
case <-t.C:
|
||||
c.Fatal("No node was added to the cluster")
|
||||
default:
|
||||
// Check if there are no nodes
|
||||
if len(session.cluster.GetNodes()) == 0 {
|
||||
hasHadZeroNodes = true
|
||||
}
|
||||
|
||||
// Pass if another node was added
|
||||
if len(session.cluster.GetNodes()) >= 1 && hasHadZeroNodes {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RethinkSuite) TestClusterNodeHealth(c *test.C) {
|
||||
session, err := Connect(ConnectOpts{
|
||||
Addresses: []string{url1, url2, url3},
|
||||
DiscoverHosts: true,
|
||||
NodeRefreshInterval: time.Second,
|
||||
InitialCap: 50,
|
||||
MaxOpen: 200,
|
||||
})
|
||||
c.Assert(err, test.IsNil)
|
||||
|
||||
attempts := 0
|
||||
failed := 0
|
||||
seconds := 0
|
||||
|
||||
t := time.NewTimer(time.Second * 30)
|
||||
tick := time.NewTicker(time.Second)
|
||||
for {
|
||||
select {
|
||||
// Fail if deadline has passed
|
||||
case <-tick.C:
|
||||
seconds++
|
||||
c.Logf("%ds elapsed", seconds)
|
||||
case <-t.C:
|
||||
// Execute queries for 10s and check that at most 5% of the queries fail
|
||||
c.Logf("%d of the %d(%d%%) queries failed", failed, attempts, (failed / attempts))
|
||||
c.Assert(failed <= 100, test.Equals, true)
|
||||
return
|
||||
default:
|
||||
attempts++
|
||||
if err := Expr(1).Exec(session); err != nil {
|
||||
c.Logf("Query failed, %s", err)
|
||||
failed++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
// +build cluster
|
||||
|
||||
package gorethink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
test "gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func (s *RethinkSuite) TestClusterConnect(c *test.C) {
|
||||
session, err := Connect(ConnectOpts{
|
||||
Addresses: []string{url1, url2, url3},
|
||||
})
|
||||
c.Assert(err, test.IsNil)
|
||||
|
||||
row, err := Expr("Hello World").Run(session)
|
||||
c.Assert(err, test.IsNil)
|
||||
|
||||
var response string
|
||||
err = row.One(&response)
|
||||
c.Assert(err, test.IsNil)
|
||||
c.Assert(response, test.Equals, "Hello World")
|
||||
}
|
||||
|
||||
func (s *RethinkSuite) TestClusterMultipleQueries(c *test.C) {
|
||||
session, err := Connect(ConnectOpts{
|
||||
Addresses: []string{url1, url2, url3},
|
||||
})
|
||||
c.Assert(err, test.IsNil)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
row, err := Expr(fmt.Sprintf("Hello World", i)).Run(session)
|
||||
c.Assert(err, test.IsNil)
|
||||
|
||||
var response string
|
||||
err = row.One(&response)
|
||||
c.Assert(err, test.IsNil)
|
||||
c.Assert(response, test.Equals, fmt.Sprintf("Hello World", i))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RethinkSuite) TestClusterConnectError(c *test.C) {
|
||||
var err error
|
||||
_, err = Connect(ConnectOpts{
|
||||
Addresses: []string{"nonexistanturl"},
|
||||
Timeout: time.Second,
|
||||
})
|
||||
c.Assert(err, test.NotNil)
|
||||
}
|
||||
|
||||
func (s *RethinkSuite) TestClusterConnectDatabase(c *test.C) {
|
||||
session, err := Connect(ConnectOpts{
|
||||
Addresses: []string{url1, url2, url3},
|
||||
Database: "test2",
|
||||
})
|
||||
c.Assert(err, test.IsNil)
|
||||
|
||||
_, err = Table("test2").Run(session)
|
||||
c.Assert(err, test.NotNil)
|
||||
c.Assert(err.Error(), test.Equals, "gorethink: Database `test2` does not exist. in:\nr.Table(\"test2\")")
|
||||
}
|
|
@ -0,0 +1,381 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
const (
|
||||
respHeaderLen = 12
|
||||
defaultKeepAlivePeriod = time.Second * 30
|
||||
)
|
||||
|
||||
// Response represents the raw response from a query, most of the time you
|
||||
// should instead use a Cursor when reading from the database.
|
||||
type Response struct {
|
||||
Token int64
|
||||
Type p.Response_ResponseType `json:"t"`
|
||||
ErrorType p.Response_ErrorType `json:"e"`
|
||||
Notes []p.Response_ResponseNote `json:"n"`
|
||||
Responses []json.RawMessage `json:"r"`
|
||||
Backtrace []interface{} `json:"b"`
|
||||
Profile interface{} `json:"p"`
|
||||
}
|
||||
|
||||
// Connection is a connection to a rethinkdb database. Connection is not thread
|
||||
// safe and should only be accessed be a single goroutine
|
||||
type Connection struct {
|
||||
net.Conn
|
||||
|
||||
address string
|
||||
opts *ConnectOpts
|
||||
|
||||
_ [4]byte
|
||||
mu sync.Mutex
|
||||
token int64
|
||||
cursors map[int64]*Cursor
|
||||
bad bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection to the database server
|
||||
func NewConnection(address string, opts *ConnectOpts) (*Connection, error) {
|
||||
var err error
|
||||
c := &Connection{
|
||||
address: address,
|
||||
opts: opts,
|
||||
cursors: make(map[int64]*Cursor),
|
||||
}
|
||||
|
||||
keepAlivePeriod := defaultKeepAlivePeriod
|
||||
if opts.KeepAlivePeriod > 0 {
|
||||
keepAlivePeriod = opts.KeepAlivePeriod
|
||||
}
|
||||
|
||||
// Connect to Server
|
||||
nd := net.Dialer{Timeout: c.opts.Timeout, KeepAlive: keepAlivePeriod}
|
||||
if c.opts.TLSConfig == nil {
|
||||
c.Conn, err = nd.Dial("tcp", address)
|
||||
} else {
|
||||
c.Conn, err = tls.DialWithDialer(&nd, "tcp", address, c.opts.TLSConfig)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, RQLConnectionError{rqlError(err.Error())}
|
||||
}
|
||||
|
||||
// Send handshake
|
||||
handshake, err := c.handshake(opts.HandshakeVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = handshake.Send(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying net.Conn
|
||||
func (c *Connection) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
var err error
|
||||
|
||||
if !c.closed {
|
||||
err = c.Conn.Close()
|
||||
c.closed = true
|
||||
c.cursors = make(map[int64]*Cursor)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Query sends a Query to the database, returning both the raw Response and a
|
||||
// Cursor which should be used to view the query's response.
|
||||
//
|
||||
// This function is used internally by Run which should be used for most queries.
|
||||
func (c *Connection) Query(q Query) (*Response, *Cursor, error) {
|
||||
if c == nil {
|
||||
return nil, nil, ErrConnectionClosed
|
||||
}
|
||||
c.mu.Lock()
|
||||
if c.Conn == nil {
|
||||
c.bad = true
|
||||
c.mu.Unlock()
|
||||
return nil, nil, ErrConnectionClosed
|
||||
}
|
||||
|
||||
// Add token if query is a START/NOREPLY_WAIT
|
||||
if q.Type == p.Query_START || q.Type == p.Query_NOREPLY_WAIT || q.Type == p.Query_SERVER_INFO {
|
||||
q.Token = c.nextToken()
|
||||
}
|
||||
if q.Type == p.Query_START || q.Type == p.Query_NOREPLY_WAIT {
|
||||
if c.opts.Database != "" {
|
||||
var err error
|
||||
q.Opts["db"], err = DB(c.opts.Database).Build()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return nil, nil, RQLDriverError{rqlError(err.Error())}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
err := c.sendQuery(q)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if noreply, ok := q.Opts["noreply"]; ok && noreply.(bool) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
for {
|
||||
response, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if response.Token == q.Token {
|
||||
// If this was the requested response process and return
|
||||
return c.processResponse(q, response)
|
||||
} else if _, ok := c.cursors[response.Token]; ok {
|
||||
// If the token is in the cursor cache then process the response
|
||||
c.processResponse(q, response)
|
||||
} else {
|
||||
putResponse(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ServerResponse struct {
|
||||
ID string `gorethink:"id"`
|
||||
Name string `gorethink:"name"`
|
||||
}
|
||||
|
||||
// Server returns the server name and server UUID being used by a connection.
|
||||
func (c *Connection) Server() (ServerResponse, error) {
|
||||
var response ServerResponse
|
||||
|
||||
_, cur, err := c.Query(Query{
|
||||
Type: p.Query_SERVER_INFO,
|
||||
})
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
if err = cur.One(&response); err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
if err = cur.Close(); err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// sendQuery marshals the Query and sends the JSON to the server.
|
||||
func (c *Connection) sendQuery(q Query) error {
|
||||
// Build query
|
||||
b, err := json.Marshal(q.Build())
|
||||
if err != nil {
|
||||
return RQLDriverError{rqlError("Error building query")}
|
||||
}
|
||||
|
||||
// Set timeout
|
||||
if c.opts.WriteTimeout == 0 {
|
||||
c.Conn.SetWriteDeadline(time.Time{})
|
||||
} else {
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(c.opts.WriteTimeout))
|
||||
}
|
||||
|
||||
// Send the JSON encoding of the query itself.
|
||||
if err = c.writeQuery(q.Token, b); err != nil {
|
||||
c.bad = true
|
||||
return RQLConnectionError{rqlError(err.Error())}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getToken generates the next query token, used to number requests and match
|
||||
// responses with requests.
|
||||
func (c *Connection) nextToken() int64 {
|
||||
// requires c.token to be 64-bit aligned on ARM
|
||||
return atomic.AddInt64(&c.token, 1)
|
||||
}
|
||||
|
||||
// readResponse attempts to read a Response from the server, if no response
|
||||
// could be read then an error is returned.
|
||||
func (c *Connection) readResponse() (*Response, error) {
|
||||
// Set timeout
|
||||
if c.opts.ReadTimeout == 0 {
|
||||
c.Conn.SetReadDeadline(time.Time{})
|
||||
} else {
|
||||
c.Conn.SetReadDeadline(time.Now().Add(c.opts.ReadTimeout))
|
||||
}
|
||||
|
||||
// Read response header (token+length)
|
||||
headerBuf := [respHeaderLen]byte{}
|
||||
if _, err := c.read(headerBuf[:], respHeaderLen); err != nil {
|
||||
c.bad = true
|
||||
return nil, RQLConnectionError{rqlError(err.Error())}
|
||||
}
|
||||
|
||||
responseToken := int64(binary.LittleEndian.Uint64(headerBuf[:8]))
|
||||
messageLength := binary.LittleEndian.Uint32(headerBuf[8:])
|
||||
|
||||
// Read the JSON encoding of the Response itself.
|
||||
b := make([]byte, int(messageLength))
|
||||
|
||||
if _, err := c.read(b, int(messageLength)); err != nil {
|
||||
c.bad = true
|
||||
return nil, RQLConnectionError{rqlError(err.Error())}
|
||||
}
|
||||
|
||||
// Decode the response
|
||||
var response = newCachedResponse()
|
||||
if err := json.Unmarshal(b, response); err != nil {
|
||||
c.bad = true
|
||||
return nil, RQLDriverError{rqlError(err.Error())}
|
||||
}
|
||||
response.Token = responseToken
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (c *Connection) processResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||||
switch response.Type {
|
||||
case p.Response_CLIENT_ERROR:
|
||||
return c.processErrorResponse(q, response, RQLClientError{rqlServerError{response, q.Term}})
|
||||
case p.Response_COMPILE_ERROR:
|
||||
return c.processErrorResponse(q, response, RQLCompileError{rqlServerError{response, q.Term}})
|
||||
case p.Response_RUNTIME_ERROR:
|
||||
return c.processErrorResponse(q, response, createRuntimeError(response.ErrorType, response, q.Term))
|
||||
case p.Response_SUCCESS_ATOM, p.Response_SERVER_INFO:
|
||||
return c.processAtomResponse(q, response)
|
||||
case p.Response_SUCCESS_PARTIAL:
|
||||
return c.processPartialResponse(q, response)
|
||||
case p.Response_SUCCESS_SEQUENCE:
|
||||
return c.processSequenceResponse(q, response)
|
||||
case p.Response_WAIT_COMPLETE:
|
||||
return c.processWaitResponse(q, response)
|
||||
default:
|
||||
putResponse(response)
|
||||
return nil, nil, RQLDriverError{rqlError("Unexpected response type")}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) processErrorResponse(q Query, response *Response, err error) (*Response, *Cursor, error) {
|
||||
c.mu.Lock()
|
||||
cursor := c.cursors[response.Token]
|
||||
|
||||
delete(c.cursors, response.Token)
|
||||
c.mu.Unlock()
|
||||
|
||||
return response, cursor, err
|
||||
}
|
||||
|
||||
func (c *Connection) processAtomResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||||
// Create cursor
|
||||
cursor := newCursor(c, "Cursor", response.Token, q.Term, q.Opts)
|
||||
cursor.profile = response.Profile
|
||||
|
||||
cursor.extend(response)
|
||||
|
||||
return response, cursor, nil
|
||||
}
|
||||
|
||||
func (c *Connection) processPartialResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||||
cursorType := "Cursor"
|
||||
if len(response.Notes) > 0 {
|
||||
switch response.Notes[0] {
|
||||
case p.Response_SEQUENCE_FEED:
|
||||
cursorType = "Feed"
|
||||
case p.Response_ATOM_FEED:
|
||||
cursorType = "AtomFeed"
|
||||
case p.Response_ORDER_BY_LIMIT_FEED:
|
||||
cursorType = "OrderByLimitFeed"
|
||||
case p.Response_UNIONED_FEED:
|
||||
cursorType = "UnionedFeed"
|
||||
case p.Response_INCLUDES_STATES:
|
||||
cursorType = "IncludesFeed"
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
cursor, ok := c.cursors[response.Token]
|
||||
if !ok {
|
||||
// Create a new cursor if needed
|
||||
cursor = newCursor(c, cursorType, response.Token, q.Term, q.Opts)
|
||||
cursor.profile = response.Profile
|
||||
|
||||
c.cursors[response.Token] = cursor
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
cursor.extend(response)
|
||||
|
||||
return response, cursor, nil
|
||||
}
|
||||
|
||||
func (c *Connection) processSequenceResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||||
c.mu.Lock()
|
||||
cursor, ok := c.cursors[response.Token]
|
||||
if !ok {
|
||||
// Create a new cursor if needed
|
||||
cursor = newCursor(c, "Cursor", response.Token, q.Term, q.Opts)
|
||||
cursor.profile = response.Profile
|
||||
}
|
||||
|
||||
delete(c.cursors, response.Token)
|
||||
c.mu.Unlock()
|
||||
|
||||
cursor.extend(response)
|
||||
|
||||
return response, cursor, nil
|
||||
}
|
||||
|
||||
func (c *Connection) processWaitResponse(q Query, response *Response) (*Response, *Cursor, error) {
|
||||
c.mu.Lock()
|
||||
delete(c.cursors, response.Token)
|
||||
c.mu.Unlock()
|
||||
|
||||
return response, nil, nil
|
||||
}
|
||||
|
||||
func (c *Connection) isBad() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.bad
|
||||
}
|
||||
|
||||
var responseCache = make(chan *Response, 16)
|
||||
|
||||
func newCachedResponse() *Response {
|
||||
select {
|
||||
case r := <-responseCache:
|
||||
return r
|
||||
default:
|
||||
return new(Response)
|
||||
}
|
||||
}
|
||||
|
||||
func putResponse(r *Response) {
|
||||
*r = Response{} // zero it
|
||||
select {
|
||||
case responseCache <- r:
|
||||
default:
|
||||
}
|
||||
}
|
|
@ -0,0 +1,450 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
type HandshakeVersion int
|
||||
|
||||
const (
|
||||
HandshakeV1_0 HandshakeVersion = iota
|
||||
HandshakeV0_4
|
||||
)
|
||||
|
||||
type connectionHandshake interface {
|
||||
Send() error
|
||||
}
|
||||
|
||||
func (c *Connection) handshake(version HandshakeVersion) (connectionHandshake, error) {
|
||||
switch version {
|
||||
case HandshakeV0_4:
|
||||
return &connectionHandshakeV0_4{conn: c}, nil
|
||||
case HandshakeV1_0:
|
||||
return &connectionHandshakeV1_0{conn: c}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Unrecognised handshake version")
|
||||
}
|
||||
}
|
||||
|
||||
type connectionHandshakeV0_4 struct {
|
||||
conn *Connection
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV0_4) Send() error {
|
||||
// Send handshake request
|
||||
if err := c.writeHandshakeReq(); err != nil {
|
||||
c.conn.Close()
|
||||
return RQLConnectionError{rqlError(err.Error())}
|
||||
}
|
||||
// Read handshake response
|
||||
if err := c.readHandshakeSuccess(); err != nil {
|
||||
c.conn.Close()
|
||||
return RQLConnectionError{rqlError(err.Error())}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV0_4) writeHandshakeReq() error {
|
||||
pos := 0
|
||||
dataLen := 4 + 4 + len(c.conn.opts.AuthKey) + 4
|
||||
data := make([]byte, dataLen)
|
||||
|
||||
// Send the protocol version to the server as a 4-byte little-endian-encoded integer
|
||||
binary.LittleEndian.PutUint32(data[pos:], uint32(p.VersionDummy_V0_4))
|
||||
pos += 4
|
||||
|
||||
// Send the length of the auth key to the server as a 4-byte little-endian-encoded integer
|
||||
binary.LittleEndian.PutUint32(data[pos:], uint32(len(c.conn.opts.AuthKey)))
|
||||
pos += 4
|
||||
|
||||
// Send the auth key as an ASCII string
|
||||
if len(c.conn.opts.AuthKey) > 0 {
|
||||
pos += copy(data[pos:], c.conn.opts.AuthKey)
|
||||
}
|
||||
|
||||
// Send the protocol type as a 4-byte little-endian-encoded integer
|
||||
binary.LittleEndian.PutUint32(data[pos:], uint32(p.VersionDummy_JSON))
|
||||
pos += 4
|
||||
|
||||
return c.conn.writeData(data)
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV0_4) readHandshakeSuccess() error {
|
||||
reader := bufio.NewReader(c.conn.Conn)
|
||||
line, err := reader.ReadBytes('\x00')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return fmt.Errorf("Unexpected EOF: %s", string(line))
|
||||
}
|
||||
return err
|
||||
}
|
||||
// convert to string and remove trailing NUL byte
|
||||
response := string(line[:len(line)-1])
|
||||
if response != "SUCCESS" {
|
||||
response = strings.TrimSpace(response)
|
||||
// we failed authorization or something else terrible happened
|
||||
return RQLDriverError{rqlError(fmt.Sprintf("Server dropped connection with message: \"%s\"", response))}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
handshakeV1_0_protocolVersionNumber = 0
|
||||
handshakeV1_0_authenticationMethod = "SCRAM-SHA-256"
|
||||
)
|
||||
|
||||
type connectionHandshakeV1_0 struct {
|
||||
conn *Connection
|
||||
reader *bufio.Reader
|
||||
|
||||
authMsg string
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) Send() error {
|
||||
c.reader = bufio.NewReader(c.conn.Conn)
|
||||
|
||||
// Generate client nonce
|
||||
clientNonce, err := c.generateNonce()
|
||||
if err != nil {
|
||||
c.conn.Close()
|
||||
return RQLDriverError{rqlError(fmt.Sprintf("Failed to generate client nonce: %s", err))}
|
||||
}
|
||||
// Send client first message
|
||||
if err := c.writeFirstMessage(clientNonce); err != nil {
|
||||
c.conn.Close()
|
||||
return err
|
||||
}
|
||||
// Read status
|
||||
if err := c.checkServerVersions(); err != nil {
|
||||
c.conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Read server first message
|
||||
i, salt, serverNonce, err := c.readFirstMessage()
|
||||
if err != nil {
|
||||
c.conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Check server nonce
|
||||
if !strings.HasPrefix(serverNonce, clientNonce) {
|
||||
return RQLAuthError{RQLDriverError{rqlError("Invalid nonce from server")}}
|
||||
}
|
||||
|
||||
// Generate proof
|
||||
saltedPass := c.saltPassword(i, salt)
|
||||
clientProof := c.calculateProof(saltedPass, clientNonce, serverNonce)
|
||||
serverSignature := c.serverSignature(saltedPass)
|
||||
|
||||
// Send client final message
|
||||
if err := c.writeFinalMessage(serverNonce, clientProof); err != nil {
|
||||
c.conn.Close()
|
||||
return err
|
||||
}
|
||||
// Read server final message
|
||||
if err := c.readFinalMessage(serverSignature); err != nil {
|
||||
c.conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) writeFirstMessage(clientNonce string) error {
|
||||
// Default username to admin if not set
|
||||
username := "admin"
|
||||
if c.conn.opts.Username != "" {
|
||||
username = c.conn.opts.Username
|
||||
}
|
||||
|
||||
c.authMsg = fmt.Sprintf("n=%s,r=%s", username, clientNonce)
|
||||
msg := fmt.Sprintf(
|
||||
`{"protocol_version": %d,"authentication": "n,,%s","authentication_method": "%s"}`,
|
||||
handshakeV1_0_protocolVersionNumber, c.authMsg, handshakeV1_0_authenticationMethod,
|
||||
)
|
||||
|
||||
pos := 0
|
||||
dataLen := 4 + len(msg) + 1
|
||||
data := make([]byte, dataLen)
|
||||
|
||||
// Send the protocol version to the server as a 4-byte little-endian-encoded integer
|
||||
binary.LittleEndian.PutUint32(data[pos:], uint32(p.VersionDummy_V1_0))
|
||||
pos += 4
|
||||
|
||||
// Send the auth message as an ASCII string
|
||||
pos += copy(data[pos:], msg)
|
||||
|
||||
// Add null terminating byte
|
||||
data[pos] = '\x00'
|
||||
|
||||
return c.writeData(data)
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) checkServerVersions() error {
|
||||
b, err := c.readResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read status
|
||||
type versionsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
MinProtocolVersion int `json:"min_protocol_version"`
|
||||
MaxProtocolVersion int `json:"max_protocol_version"`
|
||||
ServerVersion string `json:"server_version"`
|
||||
ErrorCode int `json:"error_code"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
var rsp *versionsResponse
|
||||
statusStr := string(b)
|
||||
|
||||
if err := json.Unmarshal(b, &rsp); err != nil {
|
||||
if strings.HasPrefix(statusStr, "ERROR: ") {
|
||||
statusStr = strings.TrimPrefix(statusStr, "ERROR: ")
|
||||
return RQLConnectionError{rqlError(statusStr)}
|
||||
}
|
||||
|
||||
return RQLDriverError{rqlError(fmt.Sprintf("Error reading versions: %s", err))}
|
||||
}
|
||||
|
||||
if !rsp.Success {
|
||||
return c.handshakeError(rsp.ErrorCode, rsp.Error)
|
||||
}
|
||||
if rsp.MinProtocolVersion > handshakeV1_0_protocolVersionNumber ||
|
||||
rsp.MaxProtocolVersion < handshakeV1_0_protocolVersionNumber {
|
||||
return RQLDriverError{rqlError(
|
||||
fmt.Sprintf(
|
||||
"Unsupported protocol version %d, expected between %d and %d.",
|
||||
handshakeV1_0_protocolVersionNumber,
|
||||
rsp.MinProtocolVersion,
|
||||
rsp.MaxProtocolVersion,
|
||||
),
|
||||
)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) readFirstMessage() (i int64, salt []byte, serverNonce string, err error) {
|
||||
b, err2 := c.readResponse()
|
||||
if err2 != nil {
|
||||
err = err2
|
||||
return
|
||||
}
|
||||
|
||||
// Read server message
|
||||
type firstMessageResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Authentication string `json:"authentication"`
|
||||
ErrorCode int `json:"error_code"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
var rsp *firstMessageResponse
|
||||
|
||||
if err2 := json.Unmarshal(b, &rsp); err2 != nil {
|
||||
err = RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err2))}
|
||||
return
|
||||
}
|
||||
if !rsp.Success {
|
||||
err = c.handshakeError(rsp.ErrorCode, rsp.Error)
|
||||
return
|
||||
}
|
||||
|
||||
c.authMsg += ","
|
||||
c.authMsg += rsp.Authentication
|
||||
|
||||
// Parse authentication field
|
||||
auth := map[string]string{}
|
||||
parts := strings.Split(rsp.Authentication, ",")
|
||||
for _, part := range parts {
|
||||
i := strings.Index(part, "=")
|
||||
if i != -1 {
|
||||
auth[part[:i]] = part[i+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Extract return values
|
||||
if v, ok := auth["i"]; ok {
|
||||
i, err = strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if v, ok := auth["s"]; ok {
|
||||
salt, err = base64.StdEncoding.DecodeString(v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if v, ok := auth["r"]; ok {
|
||||
serverNonce = v
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) writeFinalMessage(serverNonce, clientProof string) error {
|
||||
authMsg := "c=biws,r="
|
||||
authMsg += serverNonce
|
||||
authMsg += ",p="
|
||||
authMsg += clientProof
|
||||
|
||||
msg := fmt.Sprintf(`{"authentication": "%s"}`, authMsg)
|
||||
|
||||
pos := 0
|
||||
dataLen := len(msg) + 1
|
||||
data := make([]byte, dataLen)
|
||||
|
||||
// Send the auth message as an ASCII string
|
||||
pos += copy(data[pos:], msg)
|
||||
|
||||
// Add null terminating byte
|
||||
data[pos] = '\x00'
|
||||
|
||||
return c.writeData(data)
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) readFinalMessage(serverSignature string) error {
|
||||
b, err := c.readResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read server message
|
||||
type finalMessageResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Authentication string `json:"authentication"`
|
||||
ErrorCode int `json:"error_code"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
var rsp *finalMessageResponse
|
||||
|
||||
if err := json.Unmarshal(b, &rsp); err != nil {
|
||||
return RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err))}
|
||||
}
|
||||
if !rsp.Success {
|
||||
return c.handshakeError(rsp.ErrorCode, rsp.Error)
|
||||
}
|
||||
|
||||
// Parse authentication field
|
||||
auth := map[string]string{}
|
||||
parts := strings.Split(rsp.Authentication, ",")
|
||||
for _, part := range parts {
|
||||
i := strings.Index(part, "=")
|
||||
if i != -1 {
|
||||
auth[part[:i]] = part[i+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Validate server response
|
||||
if serverSignature != auth["v"] {
|
||||
return RQLAuthError{RQLDriverError{rqlError("Invalid server signature")}}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) writeData(data []byte) error {
|
||||
|
||||
if err := c.conn.writeData(data); err != nil {
|
||||
return RQLConnectionError{rqlError(err.Error())}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) readResponse() ([]byte, error) {
|
||||
line, err := c.reader.ReadBytes('\x00')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil, RQLConnectionError{rqlError(fmt.Sprintf("Unexpected EOF: %s", string(line)))}
|
||||
}
|
||||
return nil, RQLConnectionError{rqlError(err.Error())}
|
||||
}
|
||||
|
||||
// Strip null byte and return
|
||||
return line[:len(line)-1], nil
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) generateNonce() (string, error) {
|
||||
const nonceSize = 24
|
||||
|
||||
b := make([]byte, nonceSize)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) saltPassword(iter int64, salt []byte) []byte {
|
||||
pass := []byte(c.conn.opts.Password)
|
||||
|
||||
return pbkdf2.Key(pass, salt, int(iter), sha256.Size, sha256.New)
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) calculateProof(saltedPass []byte, clientNonce, serverNonce string) string {
|
||||
// Generate proof
|
||||
c.authMsg += ",c=biws,r=" + serverNonce
|
||||
|
||||
mac := hmac.New(c.hashFunc(), saltedPass)
|
||||
mac.Write([]byte("Client Key"))
|
||||
clientKey := mac.Sum(nil)
|
||||
|
||||
hash := c.hashFunc()()
|
||||
hash.Write(clientKey)
|
||||
storedKey := hash.Sum(nil)
|
||||
|
||||
mac = hmac.New(c.hashFunc(), storedKey)
|
||||
mac.Write([]byte(c.authMsg))
|
||||
clientSignature := mac.Sum(nil)
|
||||
clientProof := make([]byte, len(clientKey))
|
||||
for i, _ := range clientKey {
|
||||
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(clientProof)
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) serverSignature(saltedPass []byte) string {
|
||||
mac := hmac.New(c.hashFunc(), saltedPass)
|
||||
mac.Write([]byte("Server Key"))
|
||||
serverKey := mac.Sum(nil)
|
||||
|
||||
mac = hmac.New(c.hashFunc(), serverKey)
|
||||
mac.Write([]byte(c.authMsg))
|
||||
serverSignature := mac.Sum(nil)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(serverSignature)
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) handshakeError(code int, message string) error {
|
||||
if code >= 10 || code <= 20 {
|
||||
return RQLAuthError{RQLDriverError{rqlError(message)}}
|
||||
}
|
||||
|
||||
return RQLDriverError{rqlError(message)}
|
||||
}
|
||||
|
||||
func (c *connectionHandshakeV1_0) hashFunc() func() hash.Hash {
|
||||
return sha256.New
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package gorethink
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// Write 'data' to conn
|
||||
func (c *Connection) writeData(data []byte) error {
|
||||
_, err := c.Conn.Write(data[:])
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Connection) read(buf []byte, length int) (total int, err error) {
|
||||
var n int
|
||||
for total < length {
|
||||
if n, err = c.Conn.Read(buf[total:length]); err != nil {
|
||||
break
|
||||
}
|
||||
total += n
|
||||
}
|
||||
|
||||
return total, err
|
||||
}
|
||||
|
||||
func (c *Connection) writeQuery(token int64, q []byte) error {
|
||||
pos := 0
|
||||
dataLen := 8 + 4 + len(q)
|
||||
data := make([]byte, dataLen)
|
||||
|
||||
// Send the protocol version to the server as a 4-byte little-endian-encoded integer
|
||||
binary.LittleEndian.PutUint64(data[pos:], uint64(token))
|
||||
pos += 8
|
||||
|
||||
// Send the length of the auth key to the server as a 4-byte little-endian-encoded integer
|
||||
binary.LittleEndian.PutUint32(data[pos:], uint32(len(q)))
|
||||
pos += 4
|
||||
|
||||
// Send the auth key as an ASCII string
|
||||
pos += copy(data[pos:], q)
|
||||
|
||||
return c.writeData(data)
|
||||
}
|
|
@ -0,0 +1,710 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/gorethink/gorethink.v2/encoding"
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
var (
|
||||
errNilCursor = errors.New("cursor is nil")
|
||||
errCursorClosed = errors.New("connection closed, cannot read cursor")
|
||||
)
|
||||
|
||||
func newCursor(conn *Connection, cursorType string, token int64, term *Term, opts map[string]interface{}) *Cursor {
|
||||
if cursorType == "" {
|
||||
cursorType = "Cursor"
|
||||
}
|
||||
|
||||
connOpts := &ConnectOpts{}
|
||||
if conn != nil {
|
||||
connOpts = conn.opts
|
||||
}
|
||||
|
||||
cursor := &Cursor{
|
||||
conn: conn,
|
||||
connOpts: connOpts,
|
||||
token: token,
|
||||
cursorType: cursorType,
|
||||
term: term,
|
||||
opts: opts,
|
||||
buffer: make([]interface{}, 0),
|
||||
responses: make([]json.RawMessage, 0),
|
||||
}
|
||||
|
||||
return cursor
|
||||
}
|
||||
|
||||
// Cursor is the result of a query. Its cursor starts before the first row
|
||||
// of the result set. A Cursor is not thread safe and should only be accessed
|
||||
// by a single goroutine at any given time. Use Next to advance through the
|
||||
// rows:
|
||||
//
|
||||
// cursor, err := query.Run(session)
|
||||
// ...
|
||||
// defer cursor.Close()
|
||||
//
|
||||
// var response interface{}
|
||||
// for cursor.Next(&response) {
|
||||
// ...
|
||||
// }
|
||||
// err = cursor.Err() // get any error encountered during iteration
|
||||
// ...
|
||||
type Cursor struct {
|
||||
releaseConn func() error
|
||||
|
||||
conn *Connection
|
||||
connOpts *ConnectOpts
|
||||
token int64
|
||||
cursorType string
|
||||
term *Term
|
||||
opts map[string]interface{}
|
||||
|
||||
mu sync.RWMutex
|
||||
lastErr error
|
||||
fetching bool
|
||||
closed bool
|
||||
finished bool
|
||||
isAtom bool
|
||||
isSingleValue bool
|
||||
pendingSkips int
|
||||
buffer []interface{}
|
||||
responses []json.RawMessage
|
||||
profile interface{}
|
||||
}
|
||||
|
||||
// Profile returns the information returned from the query profiler.
|
||||
func (c *Cursor) Profile() interface{} {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return c.profile
|
||||
}
|
||||
|
||||
// Type returns the cursor type (by default "Cursor")
|
||||
func (c *Cursor) Type() string {
|
||||
if c == nil {
|
||||
return "Cursor"
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return c.cursorType
|
||||
}
|
||||
|
||||
// Err returns nil if no errors happened during iteration, or the actual
|
||||
// error otherwise.
|
||||
func (c *Cursor) Err() error {
|
||||
if c == nil {
|
||||
return errNilCursor
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return c.lastErr
|
||||
}
|
||||
|
||||
// Close closes the cursor, preventing further enumeration. If the end is
|
||||
// encountered, the cursor is closed automatically. Close is idempotent.
|
||||
func (c *Cursor) Close() error {
|
||||
if c == nil {
|
||||
return errNilCursor
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
var err error
|
||||
|
||||
// If cursor is already closed return immediately
|
||||
closed := c.closed
|
||||
if closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get connection and check its valid, don't need to lock as this is only
|
||||
// set when the cursor is created
|
||||
conn := c.conn
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
if conn.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop any unfinished queries
|
||||
if !c.finished {
|
||||
q := Query{
|
||||
Type: p.Query_STOP,
|
||||
Token: c.token,
|
||||
Opts: map[string]interface{}{
|
||||
"noreply": true,
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err = conn.Query(q)
|
||||
}
|
||||
|
||||
if c.releaseConn != nil {
|
||||
if err := c.releaseConn(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.closed = true
|
||||
c.conn = nil
|
||||
c.buffer = nil
|
||||
c.responses = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Next retrieves the next document from the result set, blocking if necessary.
|
||||
// This method will also automatically retrieve another batch of documents from
|
||||
// the server when the current one is exhausted, or before that in background
|
||||
// if possible.
|
||||
//
|
||||
// Next returns true if a document was successfully unmarshalled onto result,
|
||||
// and false at the end of the result set or if an error happened.
|
||||
// When Next returns false, the Err method should be called to verify if
|
||||
// there was an error during iteration.
|
||||
//
|
||||
// Also note that you are able to reuse the same variable multiple times as
|
||||
// `Next` zeroes the value before scanning in the result.
|
||||
func (c *Cursor) Next(dest interface{}) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
hasMore, err := c.nextLocked(dest, true)
|
||||
if c.handleErrorLocked(err) != nil {
|
||||
c.mu.Unlock()
|
||||
c.Close()
|
||||
return false
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if !hasMore {
|
||||
c.Close()
|
||||
}
|
||||
|
||||
return hasMore
|
||||
}
|
||||
|
||||
func (c *Cursor) nextLocked(dest interface{}, progressCursor bool) (bool, error) {
|
||||
for {
|
||||
if err := c.seekCursor(true); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if c.closed {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(c.buffer) == 0 && c.finished {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(c.buffer) > 0 {
|
||||
data := c.buffer[0]
|
||||
if progressCursor {
|
||||
c.buffer = c.buffer[1:]
|
||||
}
|
||||
|
||||
err := encoding.Decode(dest, data)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Peek behaves similarly to Next, retreiving the next document from the result set
|
||||
// and blocking if necessary. Peek, however, does not progress the position of the cursor.
|
||||
// This can be useful for expressions which can return different types to attempt to
|
||||
// decode them into different interfaces.
|
||||
//
|
||||
// Like Next, it will also automatically retrieve another batch of documents from
|
||||
// the server when the current one is exhausted, or before that in background
|
||||
// if possible.
|
||||
//
|
||||
// Unlike Next, Peek does not progress the position of the cursor. Peek
|
||||
// will return errors from decoding, but they will not be persisted in the cursor
|
||||
// and therefore will not be available on cursor.Err(). This can be useful for
|
||||
// expressions that can return different types to attempt to decode them into
|
||||
// different interfaces.
|
||||
//
|
||||
// Peek returns true if a document was successfully unmarshalled onto result,
|
||||
// and false at the end of the result set or if an error happened. Peek also
|
||||
// returns the error (if any) that occured
|
||||
func (c *Cursor) Peek(dest interface{}) (bool, error) {
|
||||
if c == nil {
|
||||
return false, errNilCursor
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
hasMore, err := c.nextLocked(dest, false)
|
||||
if _, isDecodeErr := err.(*encoding.DecodeTypeError); isDecodeErr {
|
||||
c.mu.Unlock()
|
||||
return false, err
|
||||
}
|
||||
|
||||
if c.handleErrorLocked(err) != nil {
|
||||
c.mu.Unlock()
|
||||
c.Close()
|
||||
return false, err
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return hasMore, nil
|
||||
}
|
||||
|
||||
// Skip progresses the cursor by one record. It is useful after a successful
|
||||
// Peek to avoid duplicate decoding work.
|
||||
func (c *Cursor) Skip() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.pendingSkips++
|
||||
}
|
||||
|
||||
// NextResponse retrieves the next raw response from the result set, blocking if necessary.
|
||||
// Unlike Next the returned response is the raw JSON document returned from the
|
||||
// database.
|
||||
//
|
||||
// NextResponse returns false (and a nil byte slice) at the end of the result
|
||||
// set or if an error happened.
|
||||
func (c *Cursor) NextResponse() ([]byte, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
b, hasMore, err := c.nextResponseLocked()
|
||||
if c.handleErrorLocked(err) != nil {
|
||||
c.mu.Unlock()
|
||||
c.Close()
|
||||
return nil, false
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if !hasMore {
|
||||
c.Close()
|
||||
}
|
||||
|
||||
return b, hasMore
|
||||
}
|
||||
|
||||
func (c *Cursor) nextResponseLocked() ([]byte, bool, error) {
|
||||
for {
|
||||
if err := c.seekCursor(false); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if len(c.responses) == 0 && c.finished {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
if len(c.responses) > 0 {
|
||||
var response json.RawMessage
|
||||
response, c.responses = c.responses[0], c.responses[1:]
|
||||
|
||||
return []byte(response), true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All retrieves all documents from the result set into the provided slice
|
||||
// and closes the cursor.
|
||||
//
|
||||
// The result argument must necessarily be the address for a slice. The slice
|
||||
// may be nil or previously allocated.
|
||||
//
|
||||
// Also note that you are able to reuse the same variable multiple times as
|
||||
// `All` zeroes the value before scanning in the result. It also attempts
|
||||
// to reuse the existing slice without allocating any more space by either
|
||||
// resizing or returning a selection of the slice if necessary.
|
||||
func (c *Cursor) All(result interface{}) error {
|
||||
if c == nil {
|
||||
return errNilCursor
|
||||
}
|
||||
|
||||
resultv := reflect.ValueOf(result)
|
||||
if resultv.Kind() != reflect.Ptr || resultv.Elem().Kind() != reflect.Slice {
|
||||
panic("result argument must be a slice address")
|
||||
}
|
||||
slicev := resultv.Elem()
|
||||
slicev = slicev.Slice(0, slicev.Cap())
|
||||
elemt := slicev.Type().Elem()
|
||||
i := 0
|
||||
for {
|
||||
if slicev.Len() == i {
|
||||
elemp := reflect.New(elemt)
|
||||
if !c.Next(elemp.Interface()) {
|
||||
break
|
||||
}
|
||||
slicev = reflect.Append(slicev, elemp.Elem())
|
||||
slicev = slicev.Slice(0, slicev.Cap())
|
||||
} else {
|
||||
if !c.Next(slicev.Index(i).Addr().Interface()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
i++
|
||||
}
|
||||
resultv.Elem().Set(slicev.Slice(0, i))
|
||||
|
||||
if err := c.Err(); err != nil {
|
||||
c.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// One retrieves a single document from the result set into the provided
|
||||
// slice and closes the cursor.
|
||||
//
|
||||
// Also note that you are able to reuse the same variable multiple times as
|
||||
// `One` zeroes the value before scanning in the result.
|
||||
func (c *Cursor) One(result interface{}) error {
|
||||
if c == nil {
|
||||
return errNilCursor
|
||||
}
|
||||
|
||||
if c.IsNil() {
|
||||
c.Close()
|
||||
return ErrEmptyResult
|
||||
}
|
||||
|
||||
hasResult := c.Next(result)
|
||||
|
||||
if err := c.Err(); err != nil {
|
||||
c.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !hasResult {
|
||||
return ErrEmptyResult
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Interface retrieves all documents from the result set and returns the data
|
||||
// as an interface{} and closes the cursor.
|
||||
//
|
||||
// If the query returns multiple documents then a slice will be returned,
|
||||
// otherwise a single value will be returned.
|
||||
func (c *Cursor) Interface() (interface{}, error) {
|
||||
if c == nil {
|
||||
return nil, errNilCursor
|
||||
}
|
||||
|
||||
var results []interface{}
|
||||
var result interface{}
|
||||
for c.Next(&result) {
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
if err := c.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
isSingleValue := c.isSingleValue
|
||||
c.mu.RUnlock()
|
||||
|
||||
if isSingleValue {
|
||||
if len(results) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Listen listens for rows from the database and sends the result onto the given
|
||||
// channel. The type that the row is scanned into is determined by the element
|
||||
// type of the channel.
|
||||
//
|
||||
// Also note that this function returns immediately.
|
||||
//
|
||||
// cursor, err := r.Expr([]int{1,2,3}).Run(session)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
//
|
||||
// ch := make(chan int)
|
||||
// cursor.Listen(ch)
|
||||
// <- ch // 1
|
||||
// <- ch // 2
|
||||
// <- ch // 3
|
||||
func (c *Cursor) Listen(channel interface{}) {
|
||||
go func() {
|
||||
channelv := reflect.ValueOf(channel)
|
||||
if channelv.Kind() != reflect.Chan {
|
||||
panic("input argument must be a channel")
|
||||
}
|
||||
elemt := channelv.Type().Elem()
|
||||
for {
|
||||
elemp := reflect.New(elemt)
|
||||
if !c.Next(elemp.Interface()) {
|
||||
break
|
||||
}
|
||||
|
||||
channelv.Send(elemp.Elem())
|
||||
}
|
||||
|
||||
c.Close()
|
||||
channelv.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
// IsNil tests if the current row is nil.
|
||||
func (c *Cursor) IsNil() bool {
|
||||
if c == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if len(c.buffer) > 0 {
|
||||
return c.buffer[0] == nil
|
||||
}
|
||||
|
||||
if len(c.responses) > 0 {
|
||||
response := c.responses[0]
|
||||
if response == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if string(response) == "null" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// fetchMore fetches more rows from the database.
|
||||
//
|
||||
// If wait is true then it will wait for the database to reply otherwise it
|
||||
// will return after sending the continue query.
|
||||
func (c *Cursor) fetchMore() error {
|
||||
var err error
|
||||
|
||||
if !c.fetching {
|
||||
c.fetching = true
|
||||
|
||||
if c.closed {
|
||||
return errCursorClosed
|
||||
}
|
||||
|
||||
q := Query{
|
||||
Type: p.Query_CONTINUE,
|
||||
Token: c.token,
|
||||
}
|
||||
|
||||
c.mu.Unlock()
|
||||
_, _, err = c.conn.Query(q)
|
||||
c.mu.Lock()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// handleError sets the value of lastErr to err if lastErr is not yet set.
|
||||
func (c *Cursor) handleError(err error) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.handleErrorLocked(err)
|
||||
}
|
||||
|
||||
func (c *Cursor) handleErrorLocked(err error) error {
|
||||
if c.lastErr == nil {
|
||||
c.lastErr = err
|
||||
}
|
||||
|
||||
return c.lastErr
|
||||
}
|
||||
|
||||
// extend adds the result of a continue query to the cursor.
|
||||
func (c *Cursor) extend(response *Response) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.extendLocked(response)
|
||||
}
|
||||
|
||||
func (c *Cursor) extendLocked(response *Response) {
|
||||
c.responses = append(c.responses, response.Responses...)
|
||||
c.finished = response.Type != p.Response_SUCCESS_PARTIAL
|
||||
c.fetching = false
|
||||
c.isAtom = response.Type == p.Response_SUCCESS_ATOM
|
||||
|
||||
putResponse(response)
|
||||
}
|
||||
|
||||
// seekCursor takes care of loading more data if needed and applying pending skips
|
||||
//
|
||||
// bufferResponse determines whether the response will be parsed into the buffer
|
||||
func (c *Cursor) seekCursor(bufferResponse bool) error {
|
||||
if c.lastErr != nil {
|
||||
return c.lastErr
|
||||
}
|
||||
|
||||
if len(c.buffer) == 0 && len(c.responses) == 0 && c.closed {
|
||||
return errCursorClosed
|
||||
}
|
||||
|
||||
// Loop over loading data, applying skips as necessary and loading more data as needed
|
||||
// until either the cursor is closed or finished, or we have applied all outstanding
|
||||
// skips and data is available
|
||||
for {
|
||||
c.applyPendingSkips(bufferResponse) // if we are buffering the responses, skip can drain from the buffer
|
||||
|
||||
if bufferResponse && len(c.buffer) == 0 && len(c.responses) > 0 {
|
||||
if err := c.bufferNextResponse(); err != nil {
|
||||
return err
|
||||
}
|
||||
continue // go around the loop again to re-apply pending skips
|
||||
} else if len(c.buffer) == 0 && len(c.responses) == 0 && !c.finished {
|
||||
// We skipped all of our data, load some more
|
||||
if err := c.fetchMore(); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
continue // go around the loop again to re-apply pending skips
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// applyPendingSkips applies all pending skips to the buffer and
|
||||
// returns whether there are more pending skips to be applied
|
||||
//
|
||||
// if drainFromBuffer is true, we will drain from the buffer, otherwise
|
||||
// we drain from the responses
|
||||
func (c *Cursor) applyPendingSkips(drainFromBuffer bool) (stillPending bool) {
|
||||
if c.pendingSkips == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if drainFromBuffer {
|
||||
if len(c.buffer) > c.pendingSkips {
|
||||
c.buffer = c.buffer[c.pendingSkips:]
|
||||
c.pendingSkips = 0
|
||||
return false
|
||||
}
|
||||
|
||||
c.pendingSkips -= len(c.buffer)
|
||||
c.buffer = c.buffer[:0]
|
||||
return c.pendingSkips > 0
|
||||
}
|
||||
|
||||
if len(c.responses) > c.pendingSkips {
|
||||
c.responses = c.responses[c.pendingSkips:]
|
||||
c.pendingSkips = 0
|
||||
return false
|
||||
}
|
||||
|
||||
c.pendingSkips -= len(c.responses)
|
||||
c.responses = c.responses[:0]
|
||||
return c.pendingSkips > 0
|
||||
}
|
||||
|
||||
// bufferResponse reads a single response and stores the result into the buffer
|
||||
// if the response is from an atomic response, it will check if the
|
||||
// response contains multiple records and store them all into the buffer
|
||||
func (c *Cursor) bufferNextResponse() error {
|
||||
if c.closed {
|
||||
return errCursorClosed
|
||||
}
|
||||
// If there are no responses, nothing to do
|
||||
if len(c.responses) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
response := c.responses[0]
|
||||
c.responses = c.responses[1:]
|
||||
|
||||
var value interface{}
|
||||
decoder := json.NewDecoder(bytes.NewBuffer(response))
|
||||
if c.connOpts.UseJSONNumber {
|
||||
decoder.UseNumber()
|
||||
}
|
||||
err := decoder.Decode(&value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
value, err = recursivelyConvertPseudotype(value, c.opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If response is an ATOM then try and convert to an array
|
||||
if data, ok := value.([]interface{}); ok && c.isAtom {
|
||||
c.buffer = append(c.buffer, data...)
|
||||
} else if value == nil {
|
||||
c.buffer = append(c.buffer, nil)
|
||||
} else {
|
||||
c.buffer = append(c.buffer, value)
|
||||
|
||||
// If this is the only value in the response and the response was an
|
||||
// atom then set the single value flag
|
||||
if c.isAtom {
|
||||
c.isSingleValue = true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
// Package gorethink implements a Go driver for RethinkDB
|
||||
//
|
||||
// Current version: v3.0.0 (RethinkDB v2.3)
|
||||
// For more in depth information on how to use RethinkDB check out the API docs
|
||||
// at http://rethinkdb.com/api
|
||||
package gorethink
|
|
@ -0,0 +1,182 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoHosts is returned when no hosts to the Connect method.
|
||||
ErrNoHosts = errors.New("no hosts provided")
|
||||
// ErrNoConnectionsStarted is returned when the driver couldn't to any of
|
||||
// the provided hosts.
|
||||
ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")
|
||||
// ErrInvalidNode is returned when attempting to connect to a node which
|
||||
// returns an invalid response.
|
||||
ErrInvalidNode = errors.New("invalid node")
|
||||
// ErrNoConnections is returned when there are no active connections in the
|
||||
// clusters connection pool.
|
||||
ErrNoConnections = errors.New("gorethink: no connections were available")
|
||||
// ErrConnectionClosed is returned when trying to send a query with a closed
|
||||
// connection.
|
||||
ErrConnectionClosed = errors.New("gorethink: the connection is closed")
|
||||
)
|
||||
|
||||
func printCarrots(t Term, frames []*p.Frame) string {
|
||||
var frame *p.Frame
|
||||
if len(frames) > 1 {
|
||||
frame, frames = frames[0], frames[1:]
|
||||
} else if len(frames) == 1 {
|
||||
frame, frames = frames[0], []*p.Frame{}
|
||||
}
|
||||
|
||||
for i, arg := range t.args {
|
||||
if frame.GetPos() == int64(i) {
|
||||
t.args[i] = Term{
|
||||
termType: p.Term_DATUM,
|
||||
data: printCarrots(arg, frames),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, arg := range t.optArgs {
|
||||
if frame.GetOpt() == k {
|
||||
t.optArgs[k] = Term{
|
||||
termType: p.Term_DATUM,
|
||||
data: printCarrots(arg, frames),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
for _, c := range t.String() {
|
||||
if c != '^' {
|
||||
b.WriteString(" ")
|
||||
} else {
|
||||
b.WriteString("^")
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Error constants
|
||||
var ErrEmptyResult = errors.New("The result does not contain any more rows")
|
||||
|
||||
// Connection/Response errors
|
||||
|
||||
// rqlResponseError is the base type for all errors, it formats both
|
||||
// for the response and query if set.
|
||||
type rqlServerError struct {
|
||||
response *Response
|
||||
term *Term
|
||||
}
|
||||
|
||||
func (e rqlServerError) Error() string {
|
||||
var err = "An error occurred"
|
||||
if e.response != nil {
|
||||
json.Unmarshal(e.response.Responses[0], &err)
|
||||
}
|
||||
|
||||
if e.term == nil {
|
||||
return fmt.Sprintf("gorethink: %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("gorethink: %s in:\n%s", err, e.term.String())
|
||||
|
||||
}
|
||||
|
||||
func (e rqlServerError) String() string {
|
||||
return e.Error()
|
||||
}
|
||||
|
||||
type rqlError string
|
||||
|
||||
func (e rqlError) Error() string {
|
||||
return fmt.Sprintf("gorethink: %s", string(e))
|
||||
}
|
||||
|
||||
func (e rqlError) String() string {
|
||||
return e.Error()
|
||||
}
|
||||
|
||||
// Exported Error "Implementations"
|
||||
|
||||
type RQLClientError struct{ rqlServerError }
|
||||
type RQLCompileError struct{ rqlServerError }
|
||||
type RQLDriverCompileError struct{ RQLCompileError }
|
||||
type RQLServerCompileError struct{ RQLCompileError }
|
||||
type RQLAuthError struct{ RQLDriverError }
|
||||
type RQLRuntimeError struct{ rqlServerError }
|
||||
|
||||
type RQLQueryLogicError struct{ RQLRuntimeError }
|
||||
type RQLNonExistenceError struct{ RQLQueryLogicError }
|
||||
type RQLResourceLimitError struct{ RQLRuntimeError }
|
||||
type RQLUserError struct{ RQLRuntimeError }
|
||||
type RQLInternalError struct{ RQLRuntimeError }
|
||||
type RQLTimeoutError struct{ rqlServerError }
|
||||
type RQLAvailabilityError struct{ RQLRuntimeError }
|
||||
type RQLOpFailedError struct{ RQLAvailabilityError }
|
||||
type RQLOpIndeterminateError struct{ RQLAvailabilityError }
|
||||
|
||||
// RQLDriverError represents an unexpected error with the driver, if this error
|
||||
// persists please create an issue.
|
||||
type RQLDriverError struct {
|
||||
rqlError
|
||||
}
|
||||
|
||||
// RQLConnectionError represents an error when communicating with the database
|
||||
// server.
|
||||
type RQLConnectionError struct {
|
||||
rqlError
|
||||
}
|
||||
|
||||
func createRuntimeError(errorType p.Response_ErrorType, response *Response, term *Term) error {
|
||||
serverErr := rqlServerError{response, term}
|
||||
|
||||
switch errorType {
|
||||
case p.Response_QUERY_LOGIC:
|
||||
return RQLQueryLogicError{RQLRuntimeError{serverErr}}
|
||||
case p.Response_NON_EXISTENCE:
|
||||
return RQLNonExistenceError{RQLQueryLogicError{RQLRuntimeError{serverErr}}}
|
||||
case p.Response_RESOURCE_LIMIT:
|
||||
return RQLResourceLimitError{RQLRuntimeError{serverErr}}
|
||||
case p.Response_USER:
|
||||
return RQLUserError{RQLRuntimeError{serverErr}}
|
||||
case p.Response_INTERNAL:
|
||||
return RQLInternalError{RQLRuntimeError{serverErr}}
|
||||
case p.Response_OP_FAILED:
|
||||
return RQLOpFailedError{RQLAvailabilityError{RQLRuntimeError{serverErr}}}
|
||||
case p.Response_OP_INDETERMINATE:
|
||||
return RQLOpIndeterminateError{RQLAvailabilityError{RQLRuntimeError{serverErr}}}
|
||||
default:
|
||||
return RQLRuntimeError{serverErr}
|
||||
}
|
||||
}
|
||||
|
||||
// Error type helpers
|
||||
|
||||
// IsConflictErr returns true if the error is non-nil and the query failed
|
||||
// due to a duplicate primary key.
|
||||
func IsConflictErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.HasPrefix(err.Error(), "Duplicate primary key")
|
||||
}
|
||||
|
||||
// IsTypeErr returns true if the error is non-nil and the query failed due
|
||||
// to a type error.
|
||||
func IsTypeErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.HasPrefix(err.Error(), "Expected type")
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
|
||||
"gopkg.in/gorethink/gorethink.v2/encoding"
|
||||
)
|
||||
|
||||
var (
|
||||
Log *logrus.Logger
|
||||
)
|
||||
|
||||
const (
|
||||
SystemDatabase = "rethinkdb"
|
||||
|
||||
TableConfigSystemTable = "table_config"
|
||||
ServerConfigSystemTable = "server_config"
|
||||
DBConfigSystemTable = "db_config"
|
||||
ClusterConfigSystemTable = "cluster_config"
|
||||
TableStatusSystemTable = "table_status"
|
||||
ServerStatusSystemTable = "server_status"
|
||||
CurrentIssuesSystemTable = "current_issues"
|
||||
UsersSystemTable = "users"
|
||||
PermissionsSystemTable = "permissions"
|
||||
JobsSystemTable = "jobs"
|
||||
StatsSystemTable = "stats"
|
||||
LogsSystemTable = "logs"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Set encoding package
|
||||
encoding.IgnoreType(reflect.TypeOf(Term{}))
|
||||
|
||||
Log = logrus.New()
|
||||
Log.Out = ioutil.Discard // By default don't log anything
|
||||
}
|
||||
|
||||
// SetVerbose allows the driver logging level to be set. If true is passed then
|
||||
// the log level is set to Debug otherwise it defaults to Info.
|
||||
func SetVerbose(verbose bool) {
|
||||
if verbose {
|
||||
Log.Level = logrus.DebugLevel
|
||||
return
|
||||
}
|
||||
|
||||
Log.Level = logrus.InfoLevel
|
||||
}
|
||||
|
||||
// SetTags allows you to override the tags used when decoding or encoding
|
||||
// structs. The driver will check for the tags in the same order that they were
|
||||
// passed into this function. If no parameters are passed then the driver will
|
||||
// default to checking for the gorethink tag (the gorethink tag is always included)
|
||||
func SetTags(tags ...string) {
|
||||
encoding.Tags = append(tags, "gorethink")
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Host name and port of server
|
||||
type Host struct {
|
||||
Name string
|
||||
Port int
|
||||
}
|
||||
|
||||
// NewHost create a new Host
|
||||
func NewHost(name string, port int) Host {
|
||||
return Host{
|
||||
Name: name,
|
||||
Port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// Returns host address (name:port)
|
||||
func (h Host) String() string {
|
||||
return fmt.Sprintf("%s:%d", h.Name, h.Port)
|
||||
}
|
|
@ -0,0 +1,394 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// Mocking is based on the amazing package github.com/stretchr/testify
|
||||
|
||||
// testingT is an interface wrapper around *testing.T
|
||||
type testingT interface {
|
||||
Logf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
FailNow()
|
||||
}
|
||||
|
||||
// MockAnything can be used in place of any term, this is useful when you want
|
||||
// mock similar queries or queries that you don't quite know the exact structure
|
||||
// of.
|
||||
func MockAnything() Term {
|
||||
t := constructRootTerm("MockAnything", p.Term_DATUM, nil, nil)
|
||||
t.isMockAnything = true
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t Term) MockAnything() Term {
|
||||
t = constructMethodTerm(t, "MockAnything", p.Term_DATUM, nil, nil)
|
||||
t.isMockAnything = true
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// MockQuery represents a mocked query and is used for setting expectations,
|
||||
// as well as recording activity.
|
||||
type MockQuery struct {
|
||||
parent *Mock
|
||||
|
||||
// Holds the query and term
|
||||
Query Query
|
||||
|
||||
// Holds the JSON representation of query
|
||||
BuiltQuery []byte
|
||||
|
||||
// Holds the response that should be returned when this method is executed.
|
||||
Response interface{}
|
||||
|
||||
// Holds the error that should be returned when this method is executed.
|
||||
Error error
|
||||
|
||||
// The number of times to return the return arguments when setting
|
||||
// expectations. 0 means to always return the value.
|
||||
Repeatability int
|
||||
|
||||
// Holds a channel that will be used to block the Return until it either
|
||||
// recieves a message or is closed. nil means it returns immediately.
|
||||
WaitFor <-chan time.Time
|
||||
|
||||
// Amount of times this query has been executed
|
||||
executed int
|
||||
}
|
||||
|
||||
func newMockQuery(parent *Mock, q Query) *MockQuery {
|
||||
// Build and marshal term
|
||||
builtQuery, err := json.Marshal(q.Build())
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to build query: %s", err))
|
||||
}
|
||||
|
||||
return &MockQuery{
|
||||
parent: parent,
|
||||
Query: q,
|
||||
BuiltQuery: builtQuery,
|
||||
Response: make([]interface{}, 0),
|
||||
Repeatability: 0,
|
||||
WaitFor: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func newMockQueryFromTerm(parent *Mock, t Term, opts map[string]interface{}) *MockQuery {
|
||||
q, err := parent.newQuery(t, opts)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to build query: %s", err))
|
||||
}
|
||||
|
||||
return newMockQuery(parent, q)
|
||||
}
|
||||
|
||||
func (mq *MockQuery) lock() {
|
||||
mq.parent.mu.Lock()
|
||||
}
|
||||
|
||||
func (mq *MockQuery) unlock() {
|
||||
mq.parent.mu.Unlock()
|
||||
}
|
||||
|
||||
// Return specifies the return arguments for the expectation.
|
||||
//
|
||||
// mock.On(r.Table("test")).Return(nil, errors.New("failed"))
|
||||
func (mq *MockQuery) Return(response interface{}, err error) *MockQuery {
|
||||
mq.lock()
|
||||
defer mq.unlock()
|
||||
|
||||
mq.Response = response
|
||||
mq.Error = err
|
||||
|
||||
return mq
|
||||
}
|
||||
|
||||
// Once indicates that that the mock should only return the value once.
|
||||
//
|
||||
// mock.On(r.Table("test")).Return(result, nil).Once()
|
||||
func (mq *MockQuery) Once() *MockQuery {
|
||||
return mq.Times(1)
|
||||
}
|
||||
|
||||
// Twice indicates that that the mock should only return the value twice.
|
||||
//
|
||||
// mock.On(r.Table("test")).Return(result, nil).Twice()
|
||||
func (mq *MockQuery) Twice() *MockQuery {
|
||||
return mq.Times(2)
|
||||
}
|
||||
|
||||
// Times indicates that that the mock should only return the indicated number
|
||||
// of times.
|
||||
//
|
||||
// mock.On(r.Table("test")).Return(result, nil).Times(5)
|
||||
func (mq *MockQuery) Times(i int) *MockQuery {
|
||||
mq.lock()
|
||||
defer mq.unlock()
|
||||
mq.Repeatability = i
|
||||
return mq
|
||||
}
|
||||
|
||||
// WaitUntil sets the channel that will block the mock's return until its closed
|
||||
// or a message is received.
|
||||
//
|
||||
// mock.On(r.Table("test")).WaitUntil(time.After(time.Second))
|
||||
func (mq *MockQuery) WaitUntil(w <-chan time.Time) *MockQuery {
|
||||
mq.lock()
|
||||
defer mq.unlock()
|
||||
mq.WaitFor = w
|
||||
return mq
|
||||
}
|
||||
|
||||
// After sets how long to block until the query returns
|
||||
//
|
||||
// mock.On(r.Table("test")).After(time.Second)
|
||||
func (mq *MockQuery) After(d time.Duration) *MockQuery {
|
||||
return mq.WaitUntil(time.After(d))
|
||||
}
|
||||
|
||||
// On chains a new expectation description onto the mocked interface. This
|
||||
// allows syntax like.
|
||||
//
|
||||
// Mock.
|
||||
// On(r.Table("test")).Return(result, nil).
|
||||
// On(r.Table("test2")).Return(nil, errors.New("Some Error"))
|
||||
func (mq *MockQuery) On(t Term) *MockQuery {
|
||||
return mq.parent.On(t)
|
||||
}
|
||||
|
||||
// Mock is used to mock query execution and verify that the expected queries are
|
||||
// being executed. Mocks are used by creating an instance using NewMock and then
|
||||
// passing this when running your queries instead of a session. For example:
|
||||
//
|
||||
// mock := r.NewMock()
|
||||
// mock.On(r.Table("test")).Return([]interface{}{data}, nil)
|
||||
//
|
||||
// cursor, err := r.Table("test").Run(mock)
|
||||
//
|
||||
// mock.AssertExpectations(t)
|
||||
type Mock struct {
|
||||
mu sync.Mutex
|
||||
opts ConnectOpts
|
||||
|
||||
ExpectedQueries []*MockQuery
|
||||
Queries []MockQuery
|
||||
}
|
||||
|
||||
// NewMock creates an instance of Mock, you can optionally pass ConnectOpts to
|
||||
// the function, if passed any mocked query will be generated using those
|
||||
// options.
|
||||
func NewMock(opts ...ConnectOpts) *Mock {
|
||||
m := &Mock{
|
||||
ExpectedQueries: make([]*MockQuery, 0),
|
||||
Queries: make([]MockQuery, 0),
|
||||
}
|
||||
|
||||
if len(opts) > 0 {
|
||||
m.opts = opts[0]
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// On starts a description of an expectation of the specified query
|
||||
// being executed.
|
||||
//
|
||||
// mock.On(r.Table("test"))
|
||||
func (m *Mock) On(t Term, opts ...map[string]interface{}) *MockQuery {
|
||||
var qopts map[string]interface{}
|
||||
if len(opts) > 0 {
|
||||
qopts = opts[0]
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
mq := newMockQueryFromTerm(m, t, qopts)
|
||||
m.ExpectedQueries = append(m.ExpectedQueries, mq)
|
||||
return mq
|
||||
}
|
||||
|
||||
// AssertExpectations asserts that everything specified with On and Return was
|
||||
// in fact executed as expected. Queries may have been executed in any order.
|
||||
func (m *Mock) AssertExpectations(t testingT) bool {
|
||||
var somethingMissing bool
|
||||
var failedExpectations int
|
||||
|
||||
// iterate through each expectation
|
||||
expectedQueries := m.expectedQueries()
|
||||
for _, expectedQuery := range expectedQueries {
|
||||
if !m.queryWasExecuted(expectedQuery) && expectedQuery.executed == 0 {
|
||||
somethingMissing = true
|
||||
failedExpectations++
|
||||
t.Logf("❌\t%s", expectedQuery.Query.Term.String())
|
||||
} else {
|
||||
m.mu.Lock()
|
||||
if expectedQuery.Repeatability > 0 {
|
||||
somethingMissing = true
|
||||
failedExpectations++
|
||||
} else {
|
||||
t.Logf("✅\t%s", expectedQuery.Query.Term.String())
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if somethingMissing {
|
||||
t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe query you are testing needs to be executed %d more times(s).", len(expectedQueries)-failedExpectations, len(expectedQueries), failedExpectations)
|
||||
}
|
||||
|
||||
return !somethingMissing
|
||||
}
|
||||
|
||||
// AssertNumberOfExecutions asserts that the query was executed expectedExecutions times.
|
||||
func (m *Mock) AssertNumberOfExecutions(t testingT, expectedQuery *MockQuery, expectedExecutions int) bool {
|
||||
var actualExecutions int
|
||||
for _, query := range m.queries() {
|
||||
if query.Query.Term.compare(*expectedQuery.Query.Term, map[int64]int64{}) && query.Repeatability > -1 {
|
||||
// if bytes.Equal(query.BuiltQuery, expectedQuery.BuiltQuery) {
|
||||
actualExecutions++
|
||||
}
|
||||
}
|
||||
|
||||
if expectedExecutions != actualExecutions {
|
||||
t.Errorf("Expected number of executions (%d) does not match the actual number of executions (%d).", expectedExecutions, actualExecutions)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AssertExecuted asserts that the method was executed.
|
||||
// It can produce a false result when an argument is a pointer type and the underlying value changed after executing the mocked method.
|
||||
func (m *Mock) AssertExecuted(t testingT, expectedQuery *MockQuery) bool {
|
||||
if !m.queryWasExecuted(expectedQuery) {
|
||||
t.Errorf("The query \"%s\" should have been executed, but was not.", expectedQuery.Query.Term.String())
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// AssertNotExecuted asserts that the method was not executed.
|
||||
// It can produce a false result when an argument is a pointer type and the underlying value changed after executing the mocked method.
|
||||
func (m *Mock) AssertNotExecuted(t testingT, expectedQuery *MockQuery) bool {
|
||||
if m.queryWasExecuted(expectedQuery) {
|
||||
t.Errorf("The query \"%s\" was executed, but should NOT have been.", expectedQuery.Query.Term.String())
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Mock) IsConnected() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Mock) Query(q Query) (*Cursor, error) {
|
||||
found, query := m.findExpectedQuery(q)
|
||||
|
||||
if found < 0 {
|
||||
panic(fmt.Sprintf("gorethink: mock: This query was unexpected:\n\t\t%s", q.Term.String()))
|
||||
} else {
|
||||
m.mu.Lock()
|
||||
switch {
|
||||
case query.Repeatability == 1:
|
||||
query.Repeatability = -1
|
||||
query.executed++
|
||||
|
||||
case query.Repeatability > 1:
|
||||
query.Repeatability--
|
||||
query.executed++
|
||||
|
||||
case query.Repeatability == 0:
|
||||
query.executed++
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// add the query
|
||||
m.mu.Lock()
|
||||
m.Queries = append(m.Queries, *newMockQuery(m, q))
|
||||
m.mu.Unlock()
|
||||
|
||||
// block if specified
|
||||
if query.WaitFor != nil {
|
||||
<-query.WaitFor
|
||||
}
|
||||
|
||||
// Return error without building cursor if non-nil
|
||||
if query.Error != nil {
|
||||
return nil, query.Error
|
||||
}
|
||||
|
||||
// Build cursor and return
|
||||
c := newCursor(nil, "", query.Query.Token, query.Query.Term, query.Query.Opts)
|
||||
c.finished = true
|
||||
c.fetching = false
|
||||
c.isAtom = true
|
||||
|
||||
responseVal := reflect.ValueOf(query.Response)
|
||||
if responseVal.Kind() == reflect.Slice || responseVal.Kind() == reflect.Array {
|
||||
for i := 0; i < responseVal.Len(); i++ {
|
||||
c.buffer = append(c.buffer, responseVal.Index(i).Interface())
|
||||
}
|
||||
} else {
|
||||
c.buffer = append(c.buffer, query.Response)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (m *Mock) Exec(q Query) error {
|
||||
_, err := m.Query(q)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Mock) newQuery(t Term, opts map[string]interface{}) (Query, error) {
|
||||
return newQuery(t, opts, &m.opts)
|
||||
}
|
||||
|
||||
func (m *Mock) findExpectedQuery(q Query) (int, *MockQuery) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for i, query := range m.ExpectedQueries {
|
||||
// if bytes.Equal(query.BuiltQuery, builtQuery) && query.Repeatability > -1 {
|
||||
if query.Query.Term.compare(*q.Term, map[int64]int64{}) && query.Repeatability > -1 {
|
||||
return i, query
|
||||
}
|
||||
}
|
||||
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
func (m *Mock) queryWasExecuted(expectedQuery *MockQuery) bool {
|
||||
for _, query := range m.queries() {
|
||||
if query.Query.Term.compare(*expectedQuery.Query.Term, map[int64]int64{}) {
|
||||
// if bytes.Equal(query.BuiltQuery, expectedQuery.BuiltQuery) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// we didn't find the expected query
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Mock) expectedQueries() []*MockQuery {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return append([]*MockQuery{}, m.ExpectedQueries...)
|
||||
}
|
||||
|
||||
func (m *Mock) queries() []MockQuery {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return append([]MockQuery{}, m.Queries...)
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// Node represents a database server in the cluster
|
||||
type Node struct {
|
||||
ID string
|
||||
Host Host
|
||||
aliases []Host
|
||||
|
||||
cluster *Cluster
|
||||
pool *Pool
|
||||
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newNode(id string, aliases []Host, cluster *Cluster, pool *Pool) *Node {
|
||||
node := &Node{
|
||||
ID: id,
|
||||
Host: aliases[0],
|
||||
aliases: aliases,
|
||||
cluster: cluster,
|
||||
pool: pool,
|
||||
}
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
// Closed returns true if the node is closed
|
||||
func (n *Node) Closed() bool {
|
||||
n.mu.RLock()
|
||||
defer n.mu.RUnlock()
|
||||
|
||||
return n.closed
|
||||
}
|
||||
|
||||
// Close closes the session
|
||||
func (n *Node) Close(optArgs ...CloseOpts) error {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
|
||||
if n.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(optArgs) >= 1 {
|
||||
if optArgs[0].NoReplyWait {
|
||||
n.NoReplyWait()
|
||||
}
|
||||
}
|
||||
|
||||
if n.pool != nil {
|
||||
n.pool.Close()
|
||||
}
|
||||
n.pool = nil
|
||||
n.closed = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetInitialPoolCap sets the initial capacity of the connection pool.
|
||||
func (n *Node) SetInitialPoolCap(idleConns int) {
|
||||
n.pool.SetInitialPoolCap(idleConns)
|
||||
}
|
||||
|
||||
// SetMaxIdleConns sets the maximum number of connections in the idle
|
||||
// connection pool.
|
||||
func (n *Node) SetMaxIdleConns(idleConns int) {
|
||||
n.pool.SetMaxIdleConns(idleConns)
|
||||
}
|
||||
|
||||
// SetMaxOpenConns sets the maximum number of open connections to the database.
|
||||
func (n *Node) SetMaxOpenConns(openConns int) {
|
||||
n.pool.SetMaxOpenConns(openConns)
|
||||
}
|
||||
|
||||
// NoReplyWait ensures that previous queries with the noreply flag have been
|
||||
// processed by the server. Note that this guarantee only applies to queries
|
||||
// run on the given connection
|
||||
func (n *Node) NoReplyWait() error {
|
||||
return n.pool.Exec(Query{
|
||||
Type: p.Query_NOREPLY_WAIT,
|
||||
})
|
||||
}
|
||||
|
||||
// Query executes a ReQL query using this nodes connection pool.
|
||||
func (n *Node) Query(q Query) (cursor *Cursor, err error) {
|
||||
if n.Closed() {
|
||||
return nil, ErrInvalidNode
|
||||
}
|
||||
|
||||
return n.pool.Query(q)
|
||||
}
|
||||
|
||||
// Exec executes a ReQL query using this nodes connection pool.
|
||||
func (n *Node) Exec(q Query) (err error) {
|
||||
if n.Closed() {
|
||||
return ErrInvalidNode
|
||||
}
|
||||
|
||||
return n.pool.Exec(q)
|
||||
}
|
||||
|
||||
// Server returns the server name and server UUID being used by a connection.
|
||||
func (n *Node) Server() (ServerResponse, error) {
|
||||
var response ServerResponse
|
||||
|
||||
if n.Closed() {
|
||||
return response, ErrInvalidNode
|
||||
}
|
||||
|
||||
return n.pool.Server()
|
||||
}
|
||||
|
||||
type nodeStatus struct {
|
||||
ID string `gorethink:"id"`
|
||||
Name string `gorethink:"name"`
|
||||
Status string `gorethink:"status"`
|
||||
Network struct {
|
||||
Hostname string `gorethink:"hostname"`
|
||||
ClusterPort int64 `gorethink:"cluster_port"`
|
||||
ReqlPort int64 `gorethink:"reql_port"`
|
||||
CanonicalAddresses []struct {
|
||||
Host string `gorethink:"host"`
|
||||
Port int64 `gorethink:"port"`
|
||||
} `gorethink:"canonical_addresses"`
|
||||
} `gorethink:"network"`
|
||||
}
|
|
@ -0,0 +1,200 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/fatih/pool.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
errPoolClosed = errors.New("gorethink: pool is closed")
|
||||
)
|
||||
|
||||
// A Pool is used to store a pool of connections to a single RethinkDB server
|
||||
type Pool struct {
|
||||
host Host
|
||||
opts *ConnectOpts
|
||||
|
||||
pool pool.Pool
|
||||
|
||||
mu sync.RWMutex // protects following fields
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewPool creates a new connection pool for the given host
|
||||
func NewPool(host Host, opts *ConnectOpts) (*Pool, error) {
|
||||
initialCap := opts.InitialCap
|
||||
if initialCap <= 0 {
|
||||
// Fallback to MaxIdle if InitialCap is zero, this should be removed
|
||||
// when MaxIdle is removed
|
||||
initialCap = opts.MaxIdle
|
||||
}
|
||||
|
||||
maxOpen := opts.MaxOpen
|
||||
if maxOpen <= 0 {
|
||||
maxOpen = 2
|
||||
}
|
||||
|
||||
p, err := pool.NewChannelPool(initialCap, maxOpen, func() (net.Conn, error) {
|
||||
conn, err := NewConnection(host.String(), opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Pool{
|
||||
pool: p,
|
||||
host: host,
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ping verifies a connection to the database is still alive,
|
||||
// establishing a connection if necessary.
|
||||
func (p *Pool) Ping() error {
|
||||
_, pc, err := p.conn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return pc.Close()
|
||||
}
|
||||
|
||||
// Close closes the database, releasing any open resources.
|
||||
//
|
||||
// It is rare to Close a Pool, as the Pool handle is meant to be
|
||||
// long-lived and shared between many goroutines.
|
||||
func (p *Pool) Close() error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.pool.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pool) conn() (*Connection, *pool.PoolConn, error) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.closed {
|
||||
return nil, nil, errPoolClosed
|
||||
}
|
||||
|
||||
nc, err := p.pool.Get()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pc, ok := nc.(*pool.PoolConn)
|
||||
if !ok {
|
||||
// This should never happen!
|
||||
return nil, nil, fmt.Errorf("Invalid connection in pool")
|
||||
}
|
||||
|
||||
conn, ok := pc.Conn.(*Connection)
|
||||
if !ok {
|
||||
// This should never happen!
|
||||
return nil, nil, fmt.Errorf("Invalid connection in pool")
|
||||
}
|
||||
|
||||
return conn, pc, nil
|
||||
}
|
||||
|
||||
// SetInitialPoolCap sets the initial capacity of the connection pool.
|
||||
//
|
||||
// Deprecated: This value should only be set when connecting
|
||||
func (p *Pool) SetInitialPoolCap(n int) {
|
||||
return
|
||||
}
|
||||
|
||||
// SetMaxIdleConns sets the maximum number of connections in the idle
|
||||
// connection pool.
|
||||
//
|
||||
// Deprecated: This value should only be set when connecting
|
||||
func (p *Pool) SetMaxIdleConns(n int) {
|
||||
return
|
||||
}
|
||||
|
||||
// SetMaxOpenConns sets the maximum number of open connections to the database.
|
||||
//
|
||||
// Deprecated: This value should only be set when connecting
|
||||
func (p *Pool) SetMaxOpenConns(n int) {
|
||||
return
|
||||
}
|
||||
|
||||
// Query execution functions
|
||||
|
||||
// Exec executes a query without waiting for any response.
|
||||
func (p *Pool) Exec(q Query) error {
|
||||
c, pc, err := p.conn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
_, _, err = c.Query(q)
|
||||
|
||||
if c.isBad() {
|
||||
pc.MarkUnusable()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Query executes a query and waits for the response
|
||||
func (p *Pool) Query(q Query) (*Cursor, error) {
|
||||
c, pc, err := p.conn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, cursor, err := c.Query(q)
|
||||
|
||||
if err == nil {
|
||||
cursor.releaseConn = releaseConn(c, pc)
|
||||
} else if c.isBad() {
|
||||
pc.MarkUnusable()
|
||||
}
|
||||
|
||||
return cursor, err
|
||||
}
|
||||
|
||||
// Server returns the server name and server UUID being used by a connection.
|
||||
func (p *Pool) Server() (ServerResponse, error) {
|
||||
var response ServerResponse
|
||||
|
||||
c, pc, err := p.conn()
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
response, err = c.Server()
|
||||
|
||||
if c.isBad() {
|
||||
pc.MarkUnusable()
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
||||
|
||||
func releaseConn(c *Connection, pc *pool.PoolConn) func() error {
|
||||
return func() error {
|
||||
if c.isBad() {
|
||||
pc.MarkUnusable()
|
||||
}
|
||||
|
||||
return pc.Close()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,235 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gopkg.in/gorethink/gorethink.v2/types"
|
||||
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func convertPseudotype(obj map[string]interface{}, opts map[string]interface{}) (interface{}, error) {
|
||||
if reqlType, ok := obj["$reql_type$"]; ok {
|
||||
if reqlType == "TIME" {
|
||||
// load timeFormat, set to native if the option was not set
|
||||
timeFormat := "native"
|
||||
if opt, ok := opts["time_format"]; ok {
|
||||
if sopt, ok := opt.(string); ok {
|
||||
timeFormat = sopt
|
||||
} else {
|
||||
return nil, fmt.Errorf("Invalid time_format run option \"%s\".", opt)
|
||||
}
|
||||
}
|
||||
|
||||
if timeFormat == "native" {
|
||||
return reqlTimeToNativeTime(obj["epoch_time"].(float64), obj["timezone"].(string))
|
||||
} else if timeFormat == "raw" {
|
||||
return obj, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("Unknown time_format run option \"%s\".", reqlType)
|
||||
}
|
||||
} else if reqlType == "GROUPED_DATA" {
|
||||
// load groupFormat, set to native if the option was not set
|
||||
groupFormat := "native"
|
||||
if opt, ok := opts["group_format"]; ok {
|
||||
if sopt, ok := opt.(string); ok {
|
||||
groupFormat = sopt
|
||||
} else {
|
||||
return nil, fmt.Errorf("Invalid group_format run option \"%s\".", opt)
|
||||
}
|
||||
}
|
||||
|
||||
if groupFormat == "native" || groupFormat == "slice" {
|
||||
return reqlGroupedDataToSlice(obj)
|
||||
} else if groupFormat == "map" {
|
||||
return reqlGroupedDataToMap(obj)
|
||||
} else if groupFormat == "raw" {
|
||||
return obj, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("Unknown group_format run option \"%s\".", reqlType)
|
||||
}
|
||||
} else if reqlType == "BINARY" {
|
||||
binaryFormat := "native"
|
||||
if opt, ok := opts["binary_format"]; ok {
|
||||
if sopt, ok := opt.(string); ok {
|
||||
binaryFormat = sopt
|
||||
} else {
|
||||
return nil, fmt.Errorf("Invalid binary_format run option \"%s\".", opt)
|
||||
}
|
||||
}
|
||||
|
||||
if binaryFormat == "native" {
|
||||
return reqlBinaryToNativeBytes(obj)
|
||||
} else if binaryFormat == "raw" {
|
||||
return obj, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("Unknown binary_format run option \"%s\".", reqlType)
|
||||
}
|
||||
} else if reqlType == "GEOMETRY" {
|
||||
geometryFormat := "native"
|
||||
if opt, ok := opts["geometry_format"]; ok {
|
||||
if sopt, ok := opt.(string); ok {
|
||||
geometryFormat = sopt
|
||||
} else {
|
||||
return nil, fmt.Errorf("Invalid geometry_format run option \"%s\".", opt)
|
||||
}
|
||||
}
|
||||
|
||||
if geometryFormat == "native" {
|
||||
return reqlGeometryToNativeGeometry(obj)
|
||||
} else if geometryFormat == "raw" {
|
||||
return obj, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("Unknown geometry_format run option \"%s\".", reqlType)
|
||||
}
|
||||
} else {
|
||||
return obj, nil
|
||||
}
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func recursivelyConvertPseudotype(obj interface{}, opts map[string]interface{}) (interface{}, error) {
|
||||
var err error
|
||||
|
||||
switch obj := obj.(type) {
|
||||
case []interface{}:
|
||||
for key, val := range obj {
|
||||
obj[key], err = recursivelyConvertPseudotype(val, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
case map[string]interface{}:
|
||||
for key, val := range obj {
|
||||
obj[key], err = recursivelyConvertPseudotype(val, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pobj, err := convertPseudotype(obj, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pobj, nil
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
// Pseudo-type helper functions
|
||||
|
||||
func reqlTimeToNativeTime(timestamp float64, timezone string) (time.Time, error) {
|
||||
sec, ms := math.Modf(timestamp)
|
||||
|
||||
// Convert to native time rounding to milliseconds
|
||||
t := time.Unix(int64(sec), int64(math.Floor(ms*1000+0.5))*1000*1000)
|
||||
|
||||
// Caclulate the timezone
|
||||
if timezone != "" {
|
||||
hours, err := strconv.Atoi(timezone[1:3])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
minutes, err := strconv.Atoi(timezone[4:6])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
tzOffset := ((hours * 60) + minutes) * 60
|
||||
if timezone[:1] == "-" {
|
||||
tzOffset = 0 - tzOffset
|
||||
}
|
||||
|
||||
t = t.In(time.FixedZone(timezone, tzOffset))
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func reqlGroupedDataToSlice(obj map[string]interface{}) (interface{}, error) {
|
||||
if data, ok := obj["data"]; ok {
|
||||
ret := []interface{}{}
|
||||
for _, v := range data.([]interface{}) {
|
||||
v := v.([]interface{})
|
||||
ret = append(ret, map[string]interface{}{
|
||||
"group": v[0],
|
||||
"reduction": v[1],
|
||||
})
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
return nil, fmt.Errorf("pseudo-type GROUPED_DATA object %v does not have the expected field \"data\"", obj)
|
||||
}
|
||||
|
||||
func reqlGroupedDataToMap(obj map[string]interface{}) (interface{}, error) {
|
||||
if data, ok := obj["data"]; ok {
|
||||
ret := map[interface{}]interface{}{}
|
||||
for _, v := range data.([]interface{}) {
|
||||
v := v.([]interface{})
|
||||
ret[v[0]] = v[1]
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
return nil, fmt.Errorf("pseudo-type GROUPED_DATA object %v does not have the expected field \"data\"", obj)
|
||||
}
|
||||
|
||||
func reqlBinaryToNativeBytes(obj map[string]interface{}) (interface{}, error) {
|
||||
if data, ok := obj["data"]; ok {
|
||||
if data, ok := data.(string); ok {
|
||||
b, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decoding pseudo-type BINARY object %v", obj)
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
return nil, fmt.Errorf("pseudo-type BINARY object %v field \"data\" is not valid", obj)
|
||||
}
|
||||
return nil, fmt.Errorf("pseudo-type BINARY object %v does not have the expected field \"data\"", obj)
|
||||
}
|
||||
|
||||
func reqlGeometryToNativeGeometry(obj map[string]interface{}) (interface{}, error) {
|
||||
if typ, ok := obj["type"]; !ok {
|
||||
return nil, fmt.Errorf("pseudo-type GEOMETRY object %v does not have the expected field \"type\"", obj)
|
||||
} else if typ, ok := typ.(string); !ok {
|
||||
return nil, fmt.Errorf("pseudo-type GEOMETRY object %v field \"type\" is not valid", obj)
|
||||
} else if coords, ok := obj["coordinates"]; !ok {
|
||||
return nil, fmt.Errorf("pseudo-type GEOMETRY object %v does not have the expected field \"coordinates\"", obj)
|
||||
} else if typ == "Point" {
|
||||
point, err := types.UnmarshalPoint(coords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return types.Geometry{
|
||||
Type: "Point",
|
||||
Point: point,
|
||||
}, nil
|
||||
} else if typ == "LineString" {
|
||||
line, err := types.UnmarshalLineString(coords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return types.Geometry{
|
||||
Type: "LineString",
|
||||
Line: line,
|
||||
}, nil
|
||||
} else if typ == "Polygon" {
|
||||
lines, err := types.UnmarshalPolygon(coords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return types.Geometry{
|
||||
Type: "Polygon",
|
||||
Lines: lines,
|
||||
}, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("pseudo-type GEOMETRY object %v field has unknown type %s", obj, typ)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,455 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// A Query represents a query ready to be sent to the database, A Query differs
|
||||
// from a Term as it contains both a query type and token. These values are used
|
||||
// by the database to determine if the query is continuing a previous request
|
||||
// and also allows the driver to identify the response as they can come out of
|
||||
// order.
|
||||
type Query struct {
|
||||
Type p.Query_QueryType
|
||||
Token int64
|
||||
Term *Term
|
||||
Opts map[string]interface{}
|
||||
builtTerm interface{}
|
||||
}
|
||||
|
||||
func (q *Query) Build() []interface{} {
|
||||
res := []interface{}{int(q.Type)}
|
||||
if q.Term != nil {
|
||||
res = append(res, q.builtTerm)
|
||||
}
|
||||
|
||||
if len(q.Opts) > 0 {
|
||||
// Clone opts and remove custom gorethink options
|
||||
opts := map[string]interface{}{}
|
||||
for k, v := range q.Opts {
|
||||
switch k {
|
||||
case "geometry_format":
|
||||
default:
|
||||
opts[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
res = append(res, opts)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type termsList []Term
|
||||
type termsObj map[string]Term
|
||||
|
||||
// A Term represents a query that is being built. Terms consist of a an array of
|
||||
// "sub-terms" and a term type. When a Term is a sub-term the first element of
|
||||
// the terms data is its parent Term.
|
||||
//
|
||||
// When built the term becomes a JSON array, for more information on the format
|
||||
// see http://rethinkdb.com/docs/writing-drivers/.
|
||||
type Term struct {
|
||||
name string
|
||||
rawQuery bool
|
||||
rootTerm bool
|
||||
termType p.Term_TermType
|
||||
data interface{}
|
||||
args []Term
|
||||
optArgs map[string]Term
|
||||
lastErr error
|
||||
isMockAnything bool
|
||||
}
|
||||
|
||||
func (t Term) compare(t2 Term, varMap map[int64]int64) bool {
|
||||
if t.isMockAnything || t2.isMockAnything {
|
||||
return true
|
||||
}
|
||||
|
||||
if t.name != t2.name ||
|
||||
t.rawQuery != t2.rawQuery ||
|
||||
t.rootTerm != t2.rootTerm ||
|
||||
t.termType != t2.termType ||
|
||||
!reflect.DeepEqual(t.data, t2.data) ||
|
||||
len(t.args) != len(t2.args) ||
|
||||
len(t.optArgs) != len(t2.optArgs) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, v := range t.args {
|
||||
if t.termType == p.Term_FUNC && t2.termType == p.Term_FUNC && i == 0 {
|
||||
// Functions need to be compared differently as each variable
|
||||
// will have a different var ID so first try to create a mapping
|
||||
// between the two sets of IDs
|
||||
argsArr := t.args[0].args
|
||||
argsArr2 := t2.args[0].args
|
||||
|
||||
if len(argsArr) != len(argsArr2) {
|
||||
return false
|
||||
}
|
||||
|
||||
for j := 0; j < len(argsArr); j++ {
|
||||
varMap[argsArr[j].data.(int64)] = argsArr2[j].data.(int64)
|
||||
}
|
||||
} else if t.termType == p.Term_VAR && t2.termType == p.Term_VAR && i == 0 {
|
||||
// When comparing vars use our var map
|
||||
v1 := t.args[i].data.(int64)
|
||||
v2 := t2.args[i].data.(int64)
|
||||
|
||||
if varMap[v1] != v2 {
|
||||
return false
|
||||
}
|
||||
} else if !v.compare(t2.args[i], varMap) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range t.optArgs {
|
||||
if _, ok := t2.optArgs[k]; !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if !v.compare(t2.optArgs[k], varMap) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// build takes the query tree and prepares it to be sent as a JSON
|
||||
// expression
|
||||
func (t Term) Build() (interface{}, error) {
|
||||
var err error
|
||||
|
||||
if t.lastErr != nil {
|
||||
return nil, t.lastErr
|
||||
}
|
||||
|
||||
if t.rawQuery {
|
||||
return t.data, nil
|
||||
}
|
||||
|
||||
switch t.termType {
|
||||
case p.Term_DATUM:
|
||||
return t.data, nil
|
||||
case p.Term_MAKE_OBJ:
|
||||
res := map[string]interface{}{}
|
||||
for k, v := range t.optArgs {
|
||||
res[k], err = v.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
case p.Term_BINARY:
|
||||
if len(t.args) == 0 {
|
||||
return map[string]interface{}{
|
||||
"$reql_type$": "BINARY",
|
||||
"data": t.data,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(t.args))
|
||||
optArgs := make(map[string]interface{}, len(t.optArgs))
|
||||
|
||||
for i, v := range t.args {
|
||||
arg, err := v.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args[i] = arg
|
||||
}
|
||||
|
||||
for k, v := range t.optArgs {
|
||||
optArgs[k], err = v.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ret := []interface{}{int(t.termType)}
|
||||
|
||||
if len(args) > 0 {
|
||||
ret = append(ret, args)
|
||||
}
|
||||
if len(optArgs) > 0 {
|
||||
ret = append(ret, optArgs)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// String returns a string representation of the query tree
|
||||
func (t Term) String() string {
|
||||
if t.isMockAnything {
|
||||
return "r.MockAnything()"
|
||||
}
|
||||
|
||||
switch t.termType {
|
||||
case p.Term_MAKE_ARRAY:
|
||||
return fmt.Sprintf("[%s]", strings.Join(argsToStringSlice(t.args), ", "))
|
||||
case p.Term_MAKE_OBJ:
|
||||
return fmt.Sprintf("{%s}", strings.Join(optArgsToStringSlice(t.optArgs), ", "))
|
||||
case p.Term_FUNC:
|
||||
// Get string representation of each argument
|
||||
args := []string{}
|
||||
for _, v := range t.args[0].args {
|
||||
args = append(args, fmt.Sprintf("var_%d", v.data))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("func(%s r.Term) r.Term { return %s }",
|
||||
strings.Join(args, ", "),
|
||||
t.args[1].String(),
|
||||
)
|
||||
case p.Term_VAR:
|
||||
return fmt.Sprintf("var_%s", t.args[0])
|
||||
case p.Term_IMPLICIT_VAR:
|
||||
return "r.Row"
|
||||
case p.Term_DATUM:
|
||||
switch v := t.data.(type) {
|
||||
case string:
|
||||
return strconv.Quote(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
case p.Term_BINARY:
|
||||
if len(t.args) == 0 {
|
||||
return fmt.Sprintf("r.binary(<data>)")
|
||||
}
|
||||
}
|
||||
|
||||
if t.rootTerm {
|
||||
return fmt.Sprintf("r.%s(%s)", t.name, strings.Join(allArgsToStringSlice(t.args, t.optArgs), ", "))
|
||||
}
|
||||
|
||||
if t.args == nil {
|
||||
return "r"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s.%s(%s)", t.args[0].String(), t.name, strings.Join(allArgsToStringSlice(t.args[1:], t.optArgs), ", "))
|
||||
}
|
||||
|
||||
// OptArgs is an interface used to represent a terms optional arguments. All
|
||||
// optional argument types have a toMap function, the returned map can be encoded
|
||||
// and sent as part of the query.
|
||||
type OptArgs interface {
|
||||
toMap() map[string]interface{}
|
||||
}
|
||||
|
||||
func (t Term) OptArgs(args interface{}) Term {
|
||||
switch args := args.(type) {
|
||||
case OptArgs:
|
||||
t.optArgs = convertTermObj(args.toMap())
|
||||
case map[string]interface{}:
|
||||
t.optArgs = convertTermObj(args)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
type QueryExecutor interface {
|
||||
IsConnected() bool
|
||||
Query(Query) (*Cursor, error)
|
||||
Exec(Query) error
|
||||
|
||||
newQuery(t Term, opts map[string]interface{}) (Query, error)
|
||||
}
|
||||
|
||||
// WriteResponse is a helper type used when dealing with the response of a
|
||||
// write query. It is also returned by the RunWrite function.
|
||||
type WriteResponse struct {
|
||||
Errors int `gorethink:"errors"`
|
||||
Inserted int `gorethink:"inserted"`
|
||||
Updated int `gorethink:"updated"`
|
||||
Unchanged int `gorethink:"unchanged"`
|
||||
Replaced int `gorethink:"replaced"`
|
||||
Renamed int `gorethink:"renamed"`
|
||||
Skipped int `gorethink:"skipped"`
|
||||
Deleted int `gorethink:"deleted"`
|
||||
Created int `gorethink:"created"`
|
||||
DBsCreated int `gorethink:"dbs_created"`
|
||||
TablesCreated int `gorethink:"tables_created"`
|
||||
Dropped int `gorethink:"dropped"`
|
||||
DBsDropped int `gorethink:"dbs_dropped"`
|
||||
TablesDropped int `gorethink:"tables_dropped"`
|
||||
GeneratedKeys []string `gorethink:"generated_keys"`
|
||||
FirstError string `gorethink:"first_error"` // populated if Errors > 0
|
||||
ConfigChanges []ChangeResponse `gorethink:"config_changes"`
|
||||
Changes []ChangeResponse
|
||||
}
|
||||
|
||||
// ChangeResponse is a helper type used when dealing with changefeeds. The type
|
||||
// contains both the value before the query and the new value.
|
||||
type ChangeResponse struct {
|
||||
NewValue interface{} `gorethink:"new_val,omitempty"`
|
||||
OldValue interface{} `gorethink:"old_val,omitempty"`
|
||||
State string `gorethink:"state,omitempty"`
|
||||
Error string `gorethink:"error,omitempty"`
|
||||
}
|
||||
|
||||
// RunOpts contains the optional arguments for the Run function.
|
||||
type RunOpts struct {
|
||||
DB interface{} `gorethink:"db,omitempty"`
|
||||
Db interface{} `gorethink:"db,omitempty"` // Deprecated
|
||||
Profile interface{} `gorethink:"profile,omitempty"`
|
||||
Durability interface{} `gorethink:"durability,omitempty"`
|
||||
UseOutdated interface{} `gorethink:"use_outdated,omitempty"` // Deprecated
|
||||
ArrayLimit interface{} `gorethink:"array_limit,omitempty"`
|
||||
TimeFormat interface{} `gorethink:"time_format,omitempty"`
|
||||
GroupFormat interface{} `gorethink:"group_format,omitempty"`
|
||||
BinaryFormat interface{} `gorethink:"binary_format,omitempty"`
|
||||
GeometryFormat interface{} `gorethink:"geometry_format,omitempty"`
|
||||
ReadMode interface{} `gorethink:"read_mode,omitempty"`
|
||||
|
||||
MinBatchRows interface{} `gorethink:"min_batch_rows,omitempty"`
|
||||
MaxBatchRows interface{} `gorethink:"max_batch_rows,omitempty"`
|
||||
MaxBatchBytes interface{} `gorethink:"max_batch_bytes,omitempty"`
|
||||
MaxBatchSeconds interface{} `gorethink:"max_batch_seconds,omitempty"`
|
||||
FirstBatchScaledownFactor interface{} `gorethink:"first_batch_scaledown_factor,omitempty"`
|
||||
}
|
||||
|
||||
func (o RunOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Run runs a query using the given connection.
|
||||
//
|
||||
// rows, err := query.Run(sess)
|
||||
// if err != nil {
|
||||
// // error
|
||||
// }
|
||||
//
|
||||
// var doc MyDocumentType
|
||||
// for rows.Next(&doc) {
|
||||
// // Do something with document
|
||||
// }
|
||||
func (t Term) Run(s QueryExecutor, optArgs ...RunOpts) (*Cursor, error) {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
|
||||
if s == nil || !s.IsConnected() {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
|
||||
q, err := s.newQuery(t, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Query(q)
|
||||
}
|
||||
|
||||
// RunWrite runs a query using the given connection but unlike Run automatically
|
||||
// scans the result into a variable of type WriteResponse. This function should be used
|
||||
// if you are running a write query (such as Insert, Update, TableCreate, etc...).
|
||||
//
|
||||
// If an error occurs when running the write query the first error is returned.
|
||||
//
|
||||
// res, err := r.DB("database").Table("table").Insert(doc).RunWrite(sess)
|
||||
func (t Term) RunWrite(s QueryExecutor, optArgs ...RunOpts) (WriteResponse, error) {
|
||||
var response WriteResponse
|
||||
|
||||
res, err := t.Run(s, optArgs...)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
defer res.Close()
|
||||
|
||||
if err = res.One(&response); err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
if response.Errors > 0 {
|
||||
return response, fmt.Errorf("%s", response.FirstError)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ReadOne is a shortcut method that runs the query on the given connection
|
||||
// and reads one response from the cursor before closing it.
|
||||
//
|
||||
// It returns any errors encountered from running the query or reading the response
|
||||
func (t Term) ReadOne(dest interface{}, s QueryExecutor, optArgs ...RunOpts) error {
|
||||
res, err := t.Run(s, optArgs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return res.One(dest)
|
||||
}
|
||||
|
||||
// ReadAll is a shortcut method that runs the query on the given connection
|
||||
// and reads all of the responses from the cursor before closing it.
|
||||
//
|
||||
// It returns any errors encountered from running the query or reading the responses
|
||||
func (t Term) ReadAll(dest interface{}, s QueryExecutor, optArgs ...RunOpts) error {
|
||||
res, err := t.Run(s, optArgs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return res.All(dest)
|
||||
}
|
||||
|
||||
// ExecOpts contains the optional arguments for the Exec function and inherits
|
||||
// its options from RunOpts, the only difference is the addition of the NoReply
|
||||
// field.
|
||||
//
|
||||
// When NoReply is true it causes the driver not to wait to receive the result
|
||||
// and return immediately.
|
||||
type ExecOpts struct {
|
||||
DB interface{} `gorethink:"db,omitempty"`
|
||||
Db interface{} `gorethink:"db,omitempty"` // Deprecated
|
||||
Profile interface{} `gorethink:"profile,omitempty"`
|
||||
Durability interface{} `gorethink:"durability,omitempty"`
|
||||
UseOutdated interface{} `gorethink:"use_outdated,omitempty"` // Deprecated
|
||||
ArrayLimit interface{} `gorethink:"array_limit,omitempty"`
|
||||
TimeFormat interface{} `gorethink:"time_format,omitempty"`
|
||||
GroupFormat interface{} `gorethink:"group_format,omitempty"`
|
||||
BinaryFormat interface{} `gorethink:"binary_format,omitempty"`
|
||||
GeometryFormat interface{} `gorethink:"geometry_format,omitempty"`
|
||||
|
||||
MinBatchRows interface{} `gorethink:"min_batch_rows,omitempty"`
|
||||
MaxBatchRows interface{} `gorethink:"max_batch_rows,omitempty"`
|
||||
MaxBatchBytes interface{} `gorethink:"max_batch_bytes,omitempty"`
|
||||
MaxBatchSeconds interface{} `gorethink:"max_batch_seconds,omitempty"`
|
||||
FirstBatchScaledownFactor interface{} `gorethink:"first_batch_scaledown_factor,omitempty"`
|
||||
|
||||
NoReply interface{} `gorethink:"noreply,omitempty"`
|
||||
}
|
||||
|
||||
func (o ExecOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Exec runs the query but does not return the result. Exec will still wait for
|
||||
// the response to be received unless the NoReply field is true.
|
||||
//
|
||||
// err := r.DB("database").Table("table").Insert(doc).Exec(sess, r.ExecOpts{
|
||||
// NoReply: true,
|
||||
// })
|
||||
func (t Term) Exec(s QueryExecutor, optArgs ...ExecOpts) error {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
|
||||
if s == nil || !s.IsConnected() {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
q, err := s.newQuery(t, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.Exec(q)
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// Config can be used to read and/or update the configurations for individual
|
||||
// tables or databases.
|
||||
func (t Term) Config() Term {
|
||||
return constructMethodTerm(t, "Config", p.Term_CONFIG, []interface{}{}, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Rebalance rebalances the shards of a table. When called on a database, all
|
||||
// the tables in that database will be rebalanced.
|
||||
func (t Term) Rebalance() Term {
|
||||
return constructMethodTerm(t, "Rebalance", p.Term_REBALANCE, []interface{}{}, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ReconfigureOpts contains the optional arguments for the Reconfigure term.
|
||||
type ReconfigureOpts struct {
|
||||
Shards interface{} `gorethink:"shards,omitempty"`
|
||||
Replicas interface{} `gorethink:"replicas,omitempty"`
|
||||
DryRun interface{} `gorethink:"dry_run,omitempty"`
|
||||
EmergencyRepair interface{} `gorethink:"emergency_repair,omitempty"`
|
||||
NonVotingReplicaTags interface{} `gorethink:"nonvoting_replica_tags,omitempty"`
|
||||
PrimaryReplicaTag interface{} `gorethink:"primary_replica_tag,omitempty"`
|
||||
}
|
||||
|
||||
func (o ReconfigureOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Reconfigure a table's sharding and replication.
|
||||
func (t Term) Reconfigure(optArgs ...ReconfigureOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Reconfigure", p.Term_RECONFIGURE, []interface{}{}, opts)
|
||||
}
|
||||
|
||||
// Status return the status of a table
|
||||
func (t Term) Status() Term {
|
||||
return constructMethodTerm(t, "Status", p.Term_STATUS, []interface{}{}, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// WaitOpts contains the optional arguments for the Wait term.
|
||||
type WaitOpts struct {
|
||||
WaitFor interface{} `gorethink:"wait_for,omitempty"`
|
||||
Timeout interface{} `gorethink:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
func (o WaitOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Wait for a table or all the tables in a database to be ready. A table may be
|
||||
// temporarily unavailable after creation, rebalancing or reconfiguring. The
|
||||
// wait command blocks until the given table (or database) is fully up to date.
|
||||
//
|
||||
// Deprecated: This function is not supported by RethinkDB 2.3 and above.
|
||||
func Wait(optArgs ...WaitOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructRootTerm("Wait", p.Term_WAIT, []interface{}{}, opts)
|
||||
}
|
||||
|
||||
// Wait for a table or all the tables in a database to be ready. A table may be
|
||||
// temporarily unavailable after creation, rebalancing or reconfiguring. The
|
||||
// wait command blocks until the given table (or database) is fully up to date.
|
||||
func (t Term) Wait(optArgs ...WaitOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Wait", p.Term_WAIT, []interface{}{}, opts)
|
||||
}
|
||||
|
||||
// Grant modifies access permissions for a user account, globally or on a
|
||||
// per-database or per-table basis.
|
||||
func (t Term) Grant(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Grant", p.Term_GRANT, args, map[string]interface{}{})
|
||||
}
|
|
@ -0,0 +1,362 @@
|
|||
package gorethink
|
||||
|
||||
import p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
|
||||
// Aggregation
|
||||
// These commands are used to compute smaller values from large sequences.
|
||||
|
||||
// Reduce produces a single value from a sequence through repeated application
|
||||
// of a reduction function
|
||||
//
|
||||
// It takes one argument of type `func (r.Term, r.Term) interface{}`, for
|
||||
// example this query sums all elements in an array:
|
||||
//
|
||||
// r.Expr([]int{1,3,6}).Reduce(func (left, right r.Term) interface{} {
|
||||
// return left.Add(right)
|
||||
// })
|
||||
func (t Term) Reduce(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Reduce", p.Term_REDUCE, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// DistinctOpts contains the optional arguments for the Distinct term
|
||||
type DistinctOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
}
|
||||
|
||||
func (o DistinctOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Distinct removes duplicate elements from the sequence.
|
||||
func Distinct(arg interface{}, optArgs ...DistinctOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructRootTerm("Distinct", p.Term_DISTINCT, []interface{}{arg}, opts)
|
||||
}
|
||||
|
||||
// Distinct removes duplicate elements from the sequence.
|
||||
func (t Term) Distinct(optArgs ...DistinctOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Distinct", p.Term_DISTINCT, []interface{}{}, opts)
|
||||
}
|
||||
|
||||
// GroupOpts contains the optional arguments for the Group term
|
||||
type GroupOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
Multi interface{} `gorethink:"multi,omitempty"`
|
||||
}
|
||||
|
||||
func (o GroupOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Group takes a stream and partitions it into multiple groups based on the
|
||||
// fields or functions provided. Commands chained after group will be
|
||||
// called on each of these grouped sub-streams, producing grouped data.
|
||||
func Group(fieldOrFunctions ...interface{}) Term {
|
||||
return constructRootTerm("Group", p.Term_GROUP, funcWrapArgs(fieldOrFunctions), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// MultiGroup takes a stream and partitions it into multiple groups based on the
|
||||
// fields or functions provided. Commands chained after group will be
|
||||
// called on each of these grouped sub-streams, producing grouped data.
|
||||
//
|
||||
// Unlike Group single documents can be assigned to multiple groups, similar
|
||||
// to the behavior of multi-indexes. When the grouping value is an array, documents
|
||||
// will be placed in each group that corresponds to the elements of the array. If
|
||||
// the array is empty the row will be ignored.
|
||||
func MultiGroup(fieldOrFunctions ...interface{}) Term {
|
||||
return constructRootTerm("Group", p.Term_GROUP, funcWrapArgs(fieldOrFunctions), map[string]interface{}{
|
||||
"multi": true,
|
||||
})
|
||||
}
|
||||
|
||||
// GroupByIndex takes a stream and partitions it into multiple groups based on the
|
||||
// fields or functions provided. Commands chained after group will be
|
||||
// called on each of these grouped sub-streams, producing grouped data.
|
||||
func GroupByIndex(index interface{}, fieldOrFunctions ...interface{}) Term {
|
||||
return constructRootTerm("Group", p.Term_GROUP, funcWrapArgs(fieldOrFunctions), map[string]interface{}{
|
||||
"index": index,
|
||||
})
|
||||
}
|
||||
|
||||
// MultiGroupByIndex takes a stream and partitions it into multiple groups based on the
|
||||
// fields or functions provided. Commands chained after group will be
|
||||
// called on each of these grouped sub-streams, producing grouped data.
|
||||
//
|
||||
// Unlike Group single documents can be assigned to multiple groups, similar
|
||||
// to the behavior of multi-indexes. When the grouping value is an array, documents
|
||||
// will be placed in each group that corresponds to the elements of the array. If
|
||||
// the array is empty the row will be ignored.
|
||||
func MultiGroupByIndex(index interface{}, fieldOrFunctions ...interface{}) Term {
|
||||
return constructRootTerm("Group", p.Term_GROUP, funcWrapArgs(fieldOrFunctions), map[string]interface{}{
|
||||
"index": index,
|
||||
"mutli": true,
|
||||
})
|
||||
}
|
||||
|
||||
// Group takes a stream and partitions it into multiple groups based on the
|
||||
// fields or functions provided. Commands chained after group will be
|
||||
// called on each of these grouped sub-streams, producing grouped data.
|
||||
func (t Term) Group(fieldOrFunctions ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Group", p.Term_GROUP, funcWrapArgs(fieldOrFunctions), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// MultiGroup takes a stream and partitions it into multiple groups based on the
|
||||
// fields or functions provided. Commands chained after group will be
|
||||
// called on each of these grouped sub-streams, producing grouped data.
|
||||
//
|
||||
// Unlike Group single documents can be assigned to multiple groups, similar
|
||||
// to the behavior of multi-indexes. When the grouping value is an array, documents
|
||||
// will be placed in each group that corresponds to the elements of the array. If
|
||||
// the array is empty the row will be ignored.
|
||||
func (t Term) MultiGroup(fieldOrFunctions ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Group", p.Term_GROUP, funcWrapArgs(fieldOrFunctions), map[string]interface{}{
|
||||
"multi": true,
|
||||
})
|
||||
}
|
||||
|
||||
// GroupByIndex takes a stream and partitions it into multiple groups based on the
|
||||
// fields or functions provided. Commands chained after group will be
|
||||
// called on each of these grouped sub-streams, producing grouped data.
|
||||
func (t Term) GroupByIndex(index interface{}, fieldOrFunctions ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Group", p.Term_GROUP, funcWrapArgs(fieldOrFunctions), map[string]interface{}{
|
||||
"index": index,
|
||||
})
|
||||
}
|
||||
|
||||
// MultiGroupByIndex takes a stream and partitions it into multiple groups based on the
|
||||
// fields or functions provided. Commands chained after group will be
|
||||
// called on each of these grouped sub-streams, producing grouped data.
|
||||
//
|
||||
// Unlike Group single documents can be assigned to multiple groups, similar
|
||||
// to the behavior of multi-indexes. When the grouping value is an array, documents
|
||||
// will be placed in each group that corresponds to the elements of the array. If
|
||||
// the array is empty the row will be ignored.
|
||||
func (t Term) MultiGroupByIndex(index interface{}, fieldOrFunctions ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Group", p.Term_GROUP, funcWrapArgs(fieldOrFunctions), map[string]interface{}{
|
||||
"index": index,
|
||||
"mutli": true,
|
||||
})
|
||||
}
|
||||
|
||||
// Ungroup takes a grouped stream or grouped data and turns it into an array of
|
||||
// objects representing the groups. Any commands chained after Ungroup will
|
||||
// operate on this array, rather than operating on each group individually.
|
||||
// This is useful if you want to e.g. order the groups by the value of their
|
||||
// reduction.
|
||||
func (t Term) Ungroup(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Ungroup", p.Term_UNGROUP, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Contains returns whether or not a sequence contains all the specified values,
|
||||
// or if functions are provided instead, returns whether or not a sequence
|
||||
// contains values matching all the specified functions.
|
||||
func Contains(args ...interface{}) Term {
|
||||
return constructRootTerm("Contains", p.Term_CONTAINS, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Contains returns whether or not a sequence contains all the specified values,
|
||||
// or if functions are provided instead, returns whether or not a sequence
|
||||
// contains values matching all the specified functions.
|
||||
func (t Term) Contains(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Contains", p.Term_CONTAINS, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Aggregators
|
||||
// These standard aggregator objects are to be used in conjunction with Group.
|
||||
|
||||
// Count the number of elements in the sequence. With a single argument,
|
||||
// count the number of elements equal to it. If the argument is a function,
|
||||
// it is equivalent to calling filter before count.
|
||||
func Count(args ...interface{}) Term {
|
||||
return constructRootTerm("Count", p.Term_COUNT, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Count the number of elements in the sequence. With a single argument,
|
||||
// count the number of elements equal to it. If the argument is a function,
|
||||
// it is equivalent to calling filter before count.
|
||||
func (t Term) Count(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Count", p.Term_COUNT, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Sum returns the sum of all the elements of a sequence. If called with a field
|
||||
// name, sums all the values of that field in the sequence, skipping elements of
|
||||
// the sequence that lack that field. If called with a function, calls that
|
||||
// function on every element of the sequence and sums the results, skipping
|
||||
// elements of the sequence where that function returns null or a non-existence
|
||||
// error.
|
||||
func Sum(args ...interface{}) Term {
|
||||
return constructRootTerm("Sum", p.Term_SUM, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Sum returns the sum of all the elements of a sequence. If called with a field
|
||||
// name, sums all the values of that field in the sequence, skipping elements of
|
||||
// the sequence that lack that field. If called with a function, calls that
|
||||
// function on every element of the sequence and sums the results, skipping
|
||||
// elements of the sequence where that function returns null or a non-existence
|
||||
// error.
|
||||
func (t Term) Sum(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Sum", p.Term_SUM, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Avg returns the average of all the elements of a sequence. If called with a field
|
||||
// name, averages all the values of that field in the sequence, skipping elements of
|
||||
// the sequence that lack that field. If called with a function, calls that function
|
||||
// on every element of the sequence and averages the results, skipping elements of the
|
||||
// sequence where that function returns null or a non-existence error.
|
||||
func Avg(args ...interface{}) Term {
|
||||
return constructRootTerm("Avg", p.Term_AVG, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Avg returns the average of all the elements of a sequence. If called with a field
|
||||
// name, averages all the values of that field in the sequence, skipping elements of
|
||||
// the sequence that lack that field. If called with a function, calls that function
|
||||
// on every element of the sequence and averages the results, skipping elements of the
|
||||
// sequence where that function returns null or a non-existence error.
|
||||
func (t Term) Avg(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Avg", p.Term_AVG, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// MinOpts contains the optional arguments for the Min term
|
||||
type MinOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
}
|
||||
|
||||
func (o MinOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Min finds the minimum of a sequence. If called with a field name, finds the element
|
||||
// of that sequence with the smallest value in that field. If called with a function,
|
||||
// calls that function on every element of the sequence and returns the element
|
||||
// which produced the smallest value, ignoring any elements where the function
|
||||
// returns null or produces a non-existence error.
|
||||
func Min(args ...interface{}) Term {
|
||||
return constructRootTerm("Min", p.Term_MIN, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Min finds the minimum of a sequence. If called with a field name, finds the element
|
||||
// of that sequence with the smallest value in that field. If called with a function,
|
||||
// calls that function on every element of the sequence and returns the element
|
||||
// which produced the smallest value, ignoring any elements where the function
|
||||
// returns null or produces a non-existence error.
|
||||
func (t Term) Min(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Min", p.Term_MIN, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// MinIndex finds the minimum of a sequence. If called with a field name, finds the element
|
||||
// of that sequence with the smallest value in that field. If called with a function,
|
||||
// calls that function on every element of the sequence and returns the element
|
||||
// which produced the smallest value, ignoring any elements where the function
|
||||
// returns null or produces a non-existence error.
|
||||
func MinIndex(index interface{}, args ...interface{}) Term {
|
||||
return constructRootTerm("Min", p.Term_MIN, funcWrapArgs(args), map[string]interface{}{
|
||||
"index": index,
|
||||
})
|
||||
}
|
||||
|
||||
// MinIndex finds the minimum of a sequence. If called with a field name, finds the element
|
||||
// of that sequence with the smallest value in that field. If called with a function,
|
||||
// calls that function on every element of the sequence and returns the element
|
||||
// which produced the smallest value, ignoring any elements where the function
|
||||
// returns null or produces a non-existence error.
|
||||
func (t Term) MinIndex(index interface{}, args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Min", p.Term_MIN, funcWrapArgs(args), map[string]interface{}{
|
||||
"index": index,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxOpts contains the optional arguments for the Max term
|
||||
type MaxOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
}
|
||||
|
||||
func (o MaxOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Max finds the maximum of a sequence. If called with a field name, finds the element
|
||||
// of that sequence with the largest value in that field. If called with a function,
|
||||
// calls that function on every element of the sequence and returns the element
|
||||
// which produced the largest value, ignoring any elements where the function
|
||||
// returns null or produces a non-existence error.
|
||||
func Max(args ...interface{}) Term {
|
||||
return constructRootTerm("Max", p.Term_MAX, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Max finds the maximum of a sequence. If called with a field name, finds the element
|
||||
// of that sequence with the largest value in that field. If called with a function,
|
||||
// calls that function on every element of the sequence and returns the element
|
||||
// which produced the largest value, ignoring any elements where the function
|
||||
// returns null or produces a non-existence error.
|
||||
func (t Term) Max(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Max", p.Term_MAX, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// MaxIndex finds the maximum of a sequence. If called with a field name, finds the element
|
||||
// of that sequence with the largest value in that field. If called with a function,
|
||||
// calls that function on every element of the sequence and returns the element
|
||||
// which produced the largest value, ignoring any elements where the function
|
||||
// returns null or produces a non-existence error.
|
||||
func MaxIndex(index interface{}, args ...interface{}) Term {
|
||||
return constructRootTerm("Max", p.Term_MAX, funcWrapArgs(args), map[string]interface{}{
|
||||
"index": index,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxIndex finds the maximum of a sequence. If called with a field name, finds the element
|
||||
// of that sequence with the largest value in that field. If called with a function,
|
||||
// calls that function on every element of the sequence and returns the element
|
||||
// which produced the largest value, ignoring any elements where the function
|
||||
// returns null or produces a non-existence error.
|
||||
func (t Term) MaxIndex(index interface{}, args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Max", p.Term_MAX, funcWrapArgs(args), map[string]interface{}{
|
||||
"index": index,
|
||||
})
|
||||
}
|
||||
|
||||
// FoldOpts contains the optional arguments for the Fold term
|
||||
type FoldOpts struct {
|
||||
Emit interface{} `gorethink:"emit,omitempty"`
|
||||
FinalEmit interface{} `gorethink:"final_emit,omitempty"`
|
||||
}
|
||||
|
||||
func (o FoldOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Fold applies a function to a sequence in order, maintaining state via an
|
||||
// accumulator. The Fold command returns either a single value or a new sequence.
|
||||
//
|
||||
// In its first form, Fold operates like Reduce, returning a value by applying a
|
||||
// combining function to each element in a sequence, passing the current element
|
||||
// and the previous reduction result to the function. However, Fold has the
|
||||
// following differences from Reduce:
|
||||
// - it is guaranteed to proceed through the sequence from first element to last.
|
||||
// - it passes an initial base value to the function with the first element in
|
||||
// place of the previous reduction result.
|
||||
//
|
||||
// In its second form, Fold operates like ConcatMap, returning a new sequence
|
||||
// rather than a single value. When an emit function is provided, Fold will:
|
||||
// - proceed through the sequence in order and take an initial base value, as above.
|
||||
// - for each element in the sequence, call both the combining function and a
|
||||
// separate emitting function with the current element and previous reduction result.
|
||||
// - optionally pass the result of the combining function to the emitting function.
|
||||
//
|
||||
// If provided, the emitting function must return a list.
|
||||
func (t Term) Fold(base, fn interface{}, optArgs ...FoldOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
|
||||
args := []interface{}{base, funcWrap(fn)}
|
||||
|
||||
return constructMethodTerm(t, "Fold", p.Term_FOLD, args, opts)
|
||||
}
|
|
@ -0,0 +1,395 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
||||
"reflect"
|
||||
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// Expr converts any value to an expression and is also used by many other terms
|
||||
// such as Insert and Update. This function can convert the following basic Go
|
||||
// types (bool, int, uint, string, float) and even pointers, maps and structs.
|
||||
//
|
||||
// When evaluating structs they are encoded into a map before being sent to the
|
||||
// server. Each exported field is added to the map unless
|
||||
//
|
||||
// - the field's tag is "-", or
|
||||
// - the field is empty and its tag specifies the "omitempty" option.
|
||||
//
|
||||
// Each fields default name in the map is the field name but can be specified
|
||||
// in the struct field's tag value. The "gorethink" key in the struct field's
|
||||
// tag value is the key name, followed by an optional comma and options. Examples:
|
||||
//
|
||||
// // Field is ignored by this package.
|
||||
// Field int `gorethink:"-"`
|
||||
// // Field appears as key "myName".
|
||||
// Field int `gorethink:"myName"`
|
||||
// // Field appears as key "myName" and
|
||||
// // the field is omitted from the object if its value is empty,
|
||||
// // as defined above.
|
||||
// Field int `gorethink:"myName,omitempty"`
|
||||
// // Field appears as key "Field" (the default), but
|
||||
// // the field is skipped if empty.
|
||||
// // Note the leading comma.
|
||||
// Field int `gorethink:",omitempty"`
|
||||
func Expr(val interface{}) Term {
|
||||
if val == nil {
|
||||
return Term{
|
||||
termType: p.Term_DATUM,
|
||||
data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
switch val := val.(type) {
|
||||
case Term:
|
||||
return val
|
||||
case []interface{}:
|
||||
vals := make([]Term, len(val))
|
||||
for i, v := range val {
|
||||
vals[i] = Expr(v)
|
||||
}
|
||||
|
||||
return makeArray(vals)
|
||||
case map[string]interface{}:
|
||||
vals := make(map[string]Term, len(val))
|
||||
for k, v := range val {
|
||||
vals[k] = Expr(v)
|
||||
}
|
||||
|
||||
return makeObject(vals)
|
||||
case
|
||||
bool,
|
||||
int,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
uint,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
float32,
|
||||
float64,
|
||||
uintptr,
|
||||
string,
|
||||
*bool,
|
||||
*int,
|
||||
*int8,
|
||||
*int16,
|
||||
*int32,
|
||||
*int64,
|
||||
*uint,
|
||||
*uint8,
|
||||
*uint16,
|
||||
*uint32,
|
||||
*uint64,
|
||||
*float32,
|
||||
*float64,
|
||||
*uintptr,
|
||||
*string:
|
||||
return Term{
|
||||
termType: p.Term_DATUM,
|
||||
data: val,
|
||||
}
|
||||
default:
|
||||
// Use reflection to check for other types
|
||||
valType := reflect.TypeOf(val)
|
||||
valValue := reflect.ValueOf(val)
|
||||
|
||||
switch valType.Kind() {
|
||||
case reflect.Func:
|
||||
return makeFunc(val)
|
||||
case reflect.Struct, reflect.Map, reflect.Ptr:
|
||||
data, err := encode(val)
|
||||
|
||||
if err != nil || data == nil {
|
||||
return Term{
|
||||
termType: p.Term_DATUM,
|
||||
data: nil,
|
||||
lastErr: err,
|
||||
}
|
||||
}
|
||||
|
||||
return Expr(data)
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
// Check if slice is a byte slice
|
||||
if valType.Elem().Kind() == reflect.Uint8 {
|
||||
data, err := encode(val)
|
||||
|
||||
if err != nil || data == nil {
|
||||
return Term{
|
||||
termType: p.Term_DATUM,
|
||||
data: nil,
|
||||
lastErr: err,
|
||||
}
|
||||
}
|
||||
|
||||
return Expr(data)
|
||||
}
|
||||
|
||||
vals := make([]Term, valValue.Len())
|
||||
for i := 0; i < valValue.Len(); i++ {
|
||||
vals[i] = Expr(valValue.Index(i).Interface())
|
||||
}
|
||||
|
||||
return makeArray(vals)
|
||||
default:
|
||||
data, err := encode(val)
|
||||
|
||||
if err != nil || data == nil {
|
||||
return Term{
|
||||
termType: p.Term_DATUM,
|
||||
data: nil,
|
||||
lastErr: err,
|
||||
}
|
||||
}
|
||||
|
||||
return Term{
|
||||
termType: p.Term_DATUM,
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JSOpts contains the optional arguments for the JS term
|
||||
type JSOpts struct {
|
||||
Timeout interface{} `gorethink:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
func (o JSOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// JS creates a JavaScript expression which is evaluated by the database when
|
||||
// running the query.
|
||||
func JS(jssrc interface{}, optArgs ...JSOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructRootTerm("Js", p.Term_JAVASCRIPT, []interface{}{jssrc}, opts)
|
||||
}
|
||||
|
||||
// HTTPOpts contains the optional arguments for the HTTP term
|
||||
type HTTPOpts struct {
|
||||
// General Options
|
||||
Timeout interface{} `gorethink:"timeout,omitempty"`
|
||||
Reattempts interface{} `gorethink:"reattempts,omitempty"`
|
||||
Redirects interface{} `gorethink:"redirect,omitempty"`
|
||||
Verify interface{} `gorethink:"verify,omitempty"`
|
||||
ResultFormat interface{} `gorethink:"resul_format,omitempty"`
|
||||
|
||||
// Request Options
|
||||
Method interface{} `gorethink:"method,omitempty"`
|
||||
Auth interface{} `gorethink:"auth,omitempty"`
|
||||
Params interface{} `gorethink:"params,omitempty"`
|
||||
Header interface{} `gorethink:"header,omitempty"`
|
||||
Data interface{} `gorethink:"data,omitempty"`
|
||||
|
||||
// Pagination
|
||||
Page interface{} `gorethink:"page,omitempty"`
|
||||
PageLimit interface{} `gorethink:"page_limit,omitempty"`
|
||||
}
|
||||
|
||||
func (o HTTPOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// HTTP retrieves data from the specified URL over HTTP. The return type depends
|
||||
// on the resultFormat option, which checks the Content-Type of the response by
|
||||
// default.
|
||||
func HTTP(url interface{}, optArgs ...HTTPOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructRootTerm("Http", p.Term_HTTP, []interface{}{url}, opts)
|
||||
}
|
||||
|
||||
// JSON parses a JSON string on the server.
|
||||
func JSON(args ...interface{}) Term {
|
||||
return constructRootTerm("Json", p.Term_JSON, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Error throws a runtime error. If called with no arguments inside the second argument
|
||||
// to `default`, re-throw the current error.
|
||||
func Error(args ...interface{}) Term {
|
||||
return constructRootTerm("Error", p.Term_ERROR, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Args is a special term usd to splice an array of arguments into another term.
|
||||
// This is useful when you want to call a varadic term such as GetAll with a set
|
||||
// of arguments provided at runtime.
|
||||
func Args(args ...interface{}) Term {
|
||||
return constructRootTerm("Args", p.Term_ARGS, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Binary encapsulates binary data within a query.
|
||||
//
|
||||
// The type of data binary accepts depends on the client language. In Go, it
|
||||
// expects either a byte array/slice or a bytes.Buffer.
|
||||
//
|
||||
// Only a limited subset of ReQL commands may be chained after binary:
|
||||
// - coerceTo can coerce binary objects to string types
|
||||
// - count will return the number of bytes in the object
|
||||
// - slice will treat bytes like array indexes (i.e., slice(10,20) will return bytes 10–19)
|
||||
// - typeOf returns PTYPE<BINARY>
|
||||
// - info will return information on a binary object.
|
||||
func Binary(data interface{}) Term {
|
||||
var b []byte
|
||||
|
||||
switch data := data.(type) {
|
||||
case Term:
|
||||
return constructRootTerm("Binary", p.Term_BINARY, []interface{}{data}, map[string]interface{}{})
|
||||
case []byte:
|
||||
b = data
|
||||
default:
|
||||
typ := reflect.TypeOf(data)
|
||||
if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 {
|
||||
return Binary(reflect.ValueOf(data).Bytes())
|
||||
} else if typ.Kind() == reflect.Array && typ.Elem().Kind() == reflect.Uint8 {
|
||||
v := reflect.ValueOf(data)
|
||||
b = make([]byte, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
b[i] = v.Index(i).Interface().(byte)
|
||||
}
|
||||
return Binary(b)
|
||||
}
|
||||
panic("Unsupported binary type")
|
||||
}
|
||||
|
||||
return binaryTerm(base64.StdEncoding.EncodeToString(b))
|
||||
}
|
||||
|
||||
func binaryTerm(data string) Term {
|
||||
t := constructRootTerm("Binary", p.Term_BINARY, []interface{}{}, map[string]interface{}{})
|
||||
t.data = data
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// Do evaluates the expr in the context of one or more value bindings. The type of
|
||||
// the result is the type of the value returned from expr.
|
||||
func (t Term) Do(args ...interface{}) Term {
|
||||
newArgs := []interface{}{}
|
||||
newArgs = append(newArgs, funcWrap(args[len(args)-1]))
|
||||
newArgs = append(newArgs, t)
|
||||
newArgs = append(newArgs, args[:len(args)-1]...)
|
||||
|
||||
return constructRootTerm("Do", p.Term_FUNCALL, newArgs, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Do evaluates the expr in the context of one or more value bindings. The type of
|
||||
// the result is the type of the value returned from expr.
|
||||
func Do(args ...interface{}) Term {
|
||||
newArgs := []interface{}{}
|
||||
newArgs = append(newArgs, funcWrap(args[len(args)-1]))
|
||||
newArgs = append(newArgs, args[:len(args)-1]...)
|
||||
|
||||
return constructRootTerm("Do", p.Term_FUNCALL, newArgs, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Branch evaluates one of two control paths based on the value of an expression.
|
||||
// branch is effectively an if renamed due to language constraints.
|
||||
//
|
||||
// The type of the result is determined by the type of the branch that gets executed.
|
||||
func Branch(args ...interface{}) Term {
|
||||
return constructRootTerm("Branch", p.Term_BRANCH, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Branch evaluates one of two control paths based on the value of an expression.
|
||||
// branch is effectively an if renamed due to language constraints.
|
||||
//
|
||||
// The type of the result is determined by the type of the branch that gets executed.
|
||||
func (t Term) Branch(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Branch", p.Term_BRANCH, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ForEach loops over a sequence, evaluating the given write query for each element.
|
||||
//
|
||||
// It takes one argument of type `func (r.Term) interface{}`, for
|
||||
// example clones a table:
|
||||
//
|
||||
// r.Table("table").ForEach(func (row r.Term) interface{} {
|
||||
// return r.Table("new_table").Insert(row)
|
||||
// })
|
||||
func (t Term) ForEach(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Foreach", p.Term_FOR_EACH, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Range generates a stream of sequential integers in a specified range. It
|
||||
// accepts 0, 1, or 2 arguments, all of which should be numbers.
|
||||
func Range(args ...interface{}) Term {
|
||||
return constructRootTerm("Range", p.Term_RANGE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Default handles non-existence errors. Tries to evaluate and return its first argument.
|
||||
// If an error related to the absence of a value is thrown in the process, or if
|
||||
// its first argument returns null, returns its second argument. (Alternatively,
|
||||
// the second argument may be a function which will be called with either the
|
||||
// text of the non-existence error or null.)
|
||||
func (t Term) Default(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Default", p.Term_DEFAULT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// CoerceTo converts a value of one type into another.
|
||||
//
|
||||
// You can convert: a selection, sequence, or object into an ARRAY, an array of
|
||||
// pairs into an OBJECT, and any DATUM into a STRING.
|
||||
func (t Term) CoerceTo(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "CoerceTo", p.Term_COERCE_TO, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// TypeOf gets the type of a value.
|
||||
func TypeOf(args ...interface{}) Term {
|
||||
return constructRootTerm("TypeOf", p.Term_TYPE_OF, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// TypeOf gets the type of a value.
|
||||
func (t Term) TypeOf(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "TypeOf", p.Term_TYPE_OF, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ToJSON converts a ReQL value or object to a JSON string.
|
||||
func (t Term) ToJSON() Term {
|
||||
return constructMethodTerm(t, "ToJSON", p.Term_TO_JSON_STRING, []interface{}{}, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Info gets information about a RQL value.
|
||||
func (t Term) Info(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Info", p.Term_INFO, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// UUID returns a UUID (universally unique identifier), a string that can be used
|
||||
// as a unique ID. If a string is passed to uuid as an argument, the UUID will be
|
||||
// deterministic, derived from the string’s SHA-1 hash.
|
||||
func UUID(args ...interface{}) Term {
|
||||
return constructRootTerm("UUID", p.Term_UUID, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// RawQuery creates a new query from a JSON string, this bypasses any encoding
|
||||
// done by GoRethink. The query should not contain the query type or any options
|
||||
// as this should be handled using the normal driver API.
|
||||
//
|
||||
// THis query will only work if this is the only term in the query.
|
||||
func RawQuery(q []byte) Term {
|
||||
data := json.RawMessage(q)
|
||||
return Term{
|
||||
name: "RawQuery",
|
||||
rootTerm: true,
|
||||
rawQuery: true,
|
||||
data: &data,
|
||||
args: []Term{
|
||||
Term{
|
||||
termType: p.Term_DATUM,
|
||||
data: string(q),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// DBCreate creates a database. A RethinkDB database is a collection of tables,
|
||||
// similar to relational databases.
|
||||
//
|
||||
// Note: that you can only use alphanumeric characters and underscores for the
|
||||
// database name.
|
||||
func DBCreate(args ...interface{}) Term {
|
||||
return constructRootTerm("DBCreate", p.Term_DB_CREATE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// DBDrop drops a database. The database, all its tables, and corresponding data
|
||||
// will be deleted.
|
||||
func DBDrop(args ...interface{}) Term {
|
||||
return constructRootTerm("DBDrop", p.Term_DB_DROP, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// DBList lists all database names in the system.
|
||||
func DBList(args ...interface{}) Term {
|
||||
return constructRootTerm("DBList", p.Term_DB_LIST, args, map[string]interface{}{})
|
||||
}
|
|
@ -0,0 +1,170 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// CircleOpts contains the optional arguments for the Circle term.
|
||||
type CircleOpts struct {
|
||||
NumVertices interface{} `gorethink:"num_vertices,omitempty"`
|
||||
GeoSystem interface{} `gorethink:"geo_system,omitempty"`
|
||||
Unit interface{} `gorethink:"unit,omitempty"`
|
||||
Fill interface{} `gorethink:"fill,omitempty"`
|
||||
}
|
||||
|
||||
func (o CircleOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Circle constructs a circular line or polygon. A circle in RethinkDB is
|
||||
// a polygon or line approximating a circle of a given radius around a given
|
||||
// center, consisting of a specified number of vertices (default 32).
|
||||
func Circle(point, radius interface{}, optArgs ...CircleOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
|
||||
return constructRootTerm("Circle", p.Term_CIRCLE, []interface{}{point, radius}, opts)
|
||||
}
|
||||
|
||||
// DistanceOpts contains the optional arguments for the Distance term.
|
||||
type DistanceOpts struct {
|
||||
GeoSystem interface{} `gorethink:"geo_system,omitempty"`
|
||||
Unit interface{} `gorethink:"unit,omitempty"`
|
||||
}
|
||||
|
||||
func (o DistanceOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Distance calculates the Haversine distance between two points. At least one
|
||||
// of the geometry objects specified must be a point.
|
||||
func (t Term) Distance(point interface{}, optArgs ...DistanceOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
|
||||
return constructMethodTerm(t, "Distance", p.Term_DISTANCE, []interface{}{point}, opts)
|
||||
}
|
||||
|
||||
// Distance calculates the Haversine distance between two points. At least one
|
||||
// of the geometry objects specified must be a point.
|
||||
func Distance(point1, point2 interface{}, optArgs ...DistanceOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
|
||||
return constructRootTerm("Distance", p.Term_DISTANCE, []interface{}{point1, point2}, opts)
|
||||
}
|
||||
|
||||
// Fill converts a Line object into a Polygon object. If the last point does not
|
||||
// specify the same coordinates as the first point, polygon will close the
|
||||
// polygon by connecting them
|
||||
func (t Term) Fill() Term {
|
||||
return constructMethodTerm(t, "Fill", p.Term_FILL, []interface{}{}, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// GeoJSON converts a GeoJSON object to a ReQL geometry object.
|
||||
func GeoJSON(args ...interface{}) Term {
|
||||
return constructRootTerm("GeoJSON", p.Term_GEOJSON, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ToGeoJSON converts a ReQL geometry object to a GeoJSON object.
|
||||
func (t Term) ToGeoJSON(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "ToGeoJSON", p.Term_TO_GEOJSON, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// GetIntersectingOpts contains the optional arguments for the GetIntersecting term.
|
||||
type GetIntersectingOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
}
|
||||
|
||||
func (o GetIntersectingOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// GetIntersecting gets all documents where the given geometry object intersects
|
||||
// the geometry object of the requested geospatial index.
|
||||
func (t Term) GetIntersecting(args interface{}, optArgs ...GetIntersectingOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
|
||||
return constructMethodTerm(t, "GetIntersecting", p.Term_GET_INTERSECTING, []interface{}{args}, opts)
|
||||
}
|
||||
|
||||
// GetNearestOpts contains the optional arguments for the GetNearest term.
|
||||
type GetNearestOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
MaxResults interface{} `gorethink:"max_results,omitempty"`
|
||||
MaxDist interface{} `gorethink:"max_dist,omitempty"`
|
||||
Unit interface{} `gorethink:"unit,omitempty"`
|
||||
GeoSystem interface{} `gorethink:"geo_system,omitempty"`
|
||||
}
|
||||
|
||||
func (o GetNearestOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// GetNearest gets all documents where the specified geospatial index is within a
|
||||
// certain distance of the specified point (default 100 kilometers).
|
||||
func (t Term) GetNearest(point interface{}, optArgs ...GetNearestOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
|
||||
return constructMethodTerm(t, "GetNearest", p.Term_GET_NEAREST, []interface{}{point}, opts)
|
||||
}
|
||||
|
||||
// Includes tests whether a geometry object is completely contained within another.
|
||||
// When applied to a sequence of geometry objects, includes acts as a filter,
|
||||
// returning a sequence of objects from the sequence that include the argument.
|
||||
func (t Term) Includes(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Includes", p.Term_INCLUDES, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Intersects tests whether two geometry objects intersect with one another.
|
||||
// When applied to a sequence of geometry objects, intersects acts as a filter,
|
||||
// returning a sequence of objects from the sequence that intersect with the
|
||||
// argument.
|
||||
func (t Term) Intersects(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Intersects", p.Term_INTERSECTS, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Line constructs a geometry object of type Line. The line can be specified in
|
||||
// one of two ways:
|
||||
// - Two or more two-item arrays, specifying longitude and latitude numbers of
|
||||
// the line's vertices;
|
||||
// - Two or more Point objects specifying the line's vertices.
|
||||
func Line(args ...interface{}) Term {
|
||||
return constructRootTerm("Line", p.Term_LINE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Point constructs a geometry object of type Point. The point is specified by
|
||||
// two floating point numbers, the longitude (−180 to 180) and latitude
|
||||
// (−90 to 90) of the point on a perfect sphere.
|
||||
func Point(lon, lat interface{}) Term {
|
||||
return constructRootTerm("Point", p.Term_POINT, []interface{}{lon, lat}, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Polygon constructs a geometry object of type Polygon. The Polygon can be
|
||||
// specified in one of two ways:
|
||||
// - Three or more two-item arrays, specifying longitude and latitude numbers of the polygon's vertices;
|
||||
// - Three or more Point objects specifying the polygon's vertices.
|
||||
func Polygon(args ...interface{}) Term {
|
||||
return constructRootTerm("Polygon", p.Term_POLYGON, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// PolygonSub "punches a hole" out of the parent polygon using the polygon passed
|
||||
// to the function.
|
||||
// polygon1.PolygonSub(polygon2) -> polygon
|
||||
// In the example above polygon2 must be completely contained within polygon1
|
||||
// and must have no holes itself (it must not be the output of polygon_sub itself).
|
||||
func (t Term) PolygonSub(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "PolygonSub", p.Term_POLYGON_SUB, args, map[string]interface{}{})
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// InnerJoin returns the inner product of two sequences (e.g. a table, a filter result)
|
||||
// filtered by the predicate. The query compares each row of the left sequence
|
||||
// with each row of the right sequence to find all pairs of rows which satisfy
|
||||
// the predicate. When the predicate is satisfied, each matched pair of rows
|
||||
// of both sequences are combined into a result row.
|
||||
func (t Term) InnerJoin(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "InnerJoin", p.Term_INNER_JOIN, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// OuterJoin computes a left outer join by retaining each row in the left table even
|
||||
// if no match was found in the right table.
|
||||
func (t Term) OuterJoin(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "OuterJoin", p.Term_OUTER_JOIN, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// EqJoinOpts contains the optional arguments for the EqJoin term.
|
||||
type EqJoinOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
Ordered interface{} `gorethink:"ordered,omitempty"`
|
||||
}
|
||||
|
||||
func (o EqJoinOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// EqJoin is an efficient join that looks up elements in the right table by primary key.
|
||||
//
|
||||
// Optional arguments: "index" (string - name of the index to use in right table instead of the primary key)
|
||||
func (t Term) EqJoin(left, right interface{}, optArgs ...EqJoinOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "EqJoin", p.Term_EQ_JOIN, []interface{}{funcWrap(left), right}, opts)
|
||||
}
|
||||
|
||||
// Zip is used to 'zip' up the result of a join by merging the 'right' fields into 'left'
|
||||
// fields of each member of the sequence.
|
||||
func (t Term) Zip(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Zip", p.Term_ZIP, args, map[string]interface{}{})
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// Row returns the currently visited document. Note that Row does not work within
|
||||
// subqueries to access nested documents; you should use anonymous functions to
|
||||
// access those documents instead. Also note that unlike in other drivers to
|
||||
// access a rows fields you should call Field. For example:
|
||||
// r.row("fieldname") should instead be r.Row.Field("fieldname")
|
||||
var Row = constructRootTerm("Doc", p.Term_IMPLICIT_VAR, []interface{}{}, map[string]interface{}{})
|
||||
|
||||
// Literal replaces an object in a field instead of merging it with an existing
|
||||
// object in a merge or update operation.
|
||||
func Literal(args ...interface{}) Term {
|
||||
return constructRootTerm("Literal", p.Term_LITERAL, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Field gets a single field from an object. If called on a sequence, gets that field
|
||||
// from every object in the sequence, skipping objects that lack it.
|
||||
func (t Term) Field(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Field", p.Term_GET_FIELD, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// HasFields tests if an object has all of the specified fields. An object has a field if
|
||||
// it has the specified key and that key maps to a non-null value. For instance,
|
||||
// the object `{'a':1,'b':2,'c':null}` has the fields `a` and `b`.
|
||||
func (t Term) HasFields(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "HasFields", p.Term_HAS_FIELDS, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Pluck plucks out one or more attributes from either an object or a sequence of
|
||||
// objects (projection).
|
||||
func (t Term) Pluck(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Pluck", p.Term_PLUCK, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Without is the opposite of pluck; takes an object or a sequence of objects, and returns
|
||||
// them with the specified paths removed.
|
||||
func (t Term) Without(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Without", p.Term_WITHOUT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Merge merges two objects together to construct a new object with properties from both.
|
||||
// Gives preference to attributes from other when there is a conflict.
|
||||
func (t Term) Merge(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Merge", p.Term_MERGE, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Append appends a value to an array.
|
||||
func (t Term) Append(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Append", p.Term_APPEND, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Prepend prepends a value to an array.
|
||||
func (t Term) Prepend(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Prepend", p.Term_PREPEND, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Difference removes the elements of one array from another array.
|
||||
func (t Term) Difference(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Difference", p.Term_DIFFERENCE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// SetInsert adds a value to an array and return it as a set (an array with distinct values).
|
||||
func (t Term) SetInsert(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "SetInsert", p.Term_SET_INSERT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// SetUnion adds several values to an array and return it as a set (an array with
|
||||
// distinct values).
|
||||
func (t Term) SetUnion(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "SetUnion", p.Term_SET_UNION, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// SetIntersection calculates the intersection of two arrays returning values that
|
||||
// occur in both of them as a set (an array with distinct values).
|
||||
func (t Term) SetIntersection(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "SetIntersection", p.Term_SET_INTERSECTION, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// SetDifference removes the elements of one array from another and return them as a set (an
|
||||
// array with distinct values).
|
||||
func (t Term) SetDifference(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "SetDifference", p.Term_SET_DIFFERENCE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// InsertAt inserts a value in to an array at a given index. Returns the modified array.
|
||||
func (t Term) InsertAt(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "InsertAt", p.Term_INSERT_AT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// SpliceAt inserts several values in to an array at a given index. Returns the modified array.
|
||||
func (t Term) SpliceAt(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "SpliceAt", p.Term_SPLICE_AT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// DeleteAt removes an element from an array at a given index. Returns the modified array.
|
||||
func (t Term) DeleteAt(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "DeleteAt", p.Term_DELETE_AT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ChangeAt changes a value in an array at a given index. Returns the modified array.
|
||||
func (t Term) ChangeAt(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "ChangeAt", p.Term_CHANGE_AT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Keys returns an array containing all of the object's keys.
|
||||
func (t Term) Keys(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Keys", p.Term_KEYS, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func (t Term) Values(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Values", p.Term_VALUES, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Object creates an object from a list of key-value pairs, where the keys must be strings.
|
||||
func Object(args ...interface{}) Term {
|
||||
return constructRootTerm("Object", p.Term_OBJECT, args, map[string]interface{}{})
|
||||
}
|
|
@ -0,0 +1,229 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
var (
|
||||
// MinVal represents the smallest possible value RethinkDB can store
|
||||
MinVal = constructRootTerm("MinVal", p.Term_MINVAL, []interface{}{}, map[string]interface{}{})
|
||||
// MaxVal represents the largest possible value RethinkDB can store
|
||||
MaxVal = constructRootTerm("MaxVal", p.Term_MAXVAL, []interface{}{}, map[string]interface{}{})
|
||||
)
|
||||
|
||||
// Add sums two numbers or concatenates two arrays.
|
||||
func (t Term) Add(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Add", p.Term_ADD, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Add sums two numbers or concatenates two arrays.
|
||||
func Add(args ...interface{}) Term {
|
||||
return constructRootTerm("Add", p.Term_ADD, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Sub subtracts two numbers.
|
||||
func (t Term) Sub(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Sub", p.Term_SUB, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Sub subtracts two numbers.
|
||||
func Sub(args ...interface{}) Term {
|
||||
return constructRootTerm("Sub", p.Term_SUB, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Mul multiplies two numbers.
|
||||
func (t Term) Mul(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Mul", p.Term_MUL, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Mul multiplies two numbers.
|
||||
func Mul(args ...interface{}) Term {
|
||||
return constructRootTerm("Mul", p.Term_MUL, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Div divides two numbers.
|
||||
func (t Term) Div(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Div", p.Term_DIV, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Div divides two numbers.
|
||||
func Div(args ...interface{}) Term {
|
||||
return constructRootTerm("Div", p.Term_DIV, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Mod divides two numbers and returns the remainder.
|
||||
func (t Term) Mod(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Mod", p.Term_MOD, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Mod divides two numbers and returns the remainder.
|
||||
func Mod(args ...interface{}) Term {
|
||||
return constructRootTerm("Mod", p.Term_MOD, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// And performs a logical and on two values.
|
||||
func (t Term) And(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "And", p.Term_AND, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// And performs a logical and on two values.
|
||||
func And(args ...interface{}) Term {
|
||||
return constructRootTerm("And", p.Term_AND, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Or performs a logical or on two values.
|
||||
func (t Term) Or(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Or", p.Term_OR, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Or performs a logical or on two values.
|
||||
func Or(args ...interface{}) Term {
|
||||
return constructRootTerm("Or", p.Term_OR, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Eq returns true if two values are equal.
|
||||
func (t Term) Eq(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Eq", p.Term_EQ, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Eq returns true if two values are equal.
|
||||
func Eq(args ...interface{}) Term {
|
||||
return constructRootTerm("Eq", p.Term_EQ, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Ne returns true if two values are not equal.
|
||||
func (t Term) Ne(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Ne", p.Term_NE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Ne returns true if two values are not equal.
|
||||
func Ne(args ...interface{}) Term {
|
||||
return constructRootTerm("Ne", p.Term_NE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Gt returns true if the first value is greater than the second.
|
||||
func (t Term) Gt(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Gt", p.Term_GT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Gt returns true if the first value is greater than the second.
|
||||
func Gt(args ...interface{}) Term {
|
||||
return constructRootTerm("Gt", p.Term_GT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Ge returns true if the first value is greater than or equal to the second.
|
||||
func (t Term) Ge(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Ge", p.Term_GE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Ge returns true if the first value is greater than or equal to the second.
|
||||
func Ge(args ...interface{}) Term {
|
||||
return constructRootTerm("Ge", p.Term_GE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Lt returns true if the first value is less than the second.
|
||||
func (t Term) Lt(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Lt", p.Term_LT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Lt returns true if the first value is less than the second.
|
||||
func Lt(args ...interface{}) Term {
|
||||
return constructRootTerm("Lt", p.Term_LT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Le returns true if the first value is less than or equal to the second.
|
||||
func (t Term) Le(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Le", p.Term_LE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Le returns true if the first value is less than or equal to the second.
|
||||
func Le(args ...interface{}) Term {
|
||||
return constructRootTerm("Le", p.Term_LE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Not performs a logical not on a value.
|
||||
func (t Term) Not(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Not", p.Term_NOT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Not performs a logical not on a value.
|
||||
func Not(args ...interface{}) Term {
|
||||
return constructRootTerm("Not", p.Term_NOT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// RandomOpts contains the optional arguments for the Random term.
|
||||
type RandomOpts struct {
|
||||
Float interface{} `gorethink:"float,omitempty"`
|
||||
}
|
||||
|
||||
func (o RandomOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Random generates a random number between given (or implied) bounds. Random
|
||||
// takes zero, one or two arguments.
|
||||
//
|
||||
// With zero arguments, the result will be a floating-point number in the range
|
||||
// [0,1).
|
||||
//
|
||||
// With one argument x, the result will be in the range [0,x), and will be an
|
||||
// integer unless the Float option is set to true. Specifying a floating point
|
||||
// number without the Float option will raise an error.
|
||||
//
|
||||
// With two arguments x and y, the result will be in the range [x,y), and will
|
||||
// be an integer unless the Float option is set to true. If x and y are equal an
|
||||
// error will occur, unless the floating-point option has been specified, in
|
||||
// which case x will be returned. Specifying a floating point number without the
|
||||
// float option will raise an error.
|
||||
//
|
||||
// Note: Any integer responses can be be coerced to floating-points, when
|
||||
// unmarshaling to a Go floating-point type. The last argument given will always
|
||||
// be the ‘open’ side of the range, but when generating a floating-point
|
||||
// number, the ‘open’ side may be less than the ‘closed’ side.
|
||||
func Random(args ...interface{}) Term {
|
||||
var opts = map[string]interface{}{}
|
||||
|
||||
// Look for options map
|
||||
if len(args) > 0 {
|
||||
if possibleOpts, ok := args[len(args)-1].(RandomOpts); ok {
|
||||
opts = possibleOpts.toMap()
|
||||
args = args[:len(args)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return constructRootTerm("Random", p.Term_RANDOM, args, opts)
|
||||
}
|
||||
|
||||
// Round causes the input number to be rounded the given value to the nearest whole integer.
|
||||
func (t Term) Round(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Round", p.Term_ROUND, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Round causes the input number to be rounded the given value to the nearest whole integer.
|
||||
func Round(args ...interface{}) Term {
|
||||
return constructRootTerm("Round", p.Term_ROUND, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Ceil rounds the given value up, returning the smallest integer value greater
|
||||
// than or equal to the given value (the value’s ceiling).
|
||||
func (t Term) Ceil(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Ceil", p.Term_CEIL, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Ceil rounds the given value up, returning the smallest integer value greater
|
||||
// than or equal to the given value (the value’s ceiling).
|
||||
func Ceil(args ...interface{}) Term {
|
||||
return constructRootTerm("Ceil", p.Term_CEIL, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Floor rounds the given value down, returning the largest integer value less
|
||||
// than or equal to the given value (the value’s floor).
|
||||
func (t Term) Floor(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Floor", p.Term_FLOOR, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Floor rounds the given value down, returning the largest integer value less
|
||||
// than or equal to the given value (the value’s floor).
|
||||
func Floor(args ...interface{}) Term {
|
||||
return constructRootTerm("Floor", p.Term_FLOOR, args, map[string]interface{}{})
|
||||
}
|
|
@ -0,0 +1,141 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// DB references a database.
|
||||
func DB(args ...interface{}) Term {
|
||||
return constructRootTerm("DB", p.Term_DB, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// TableOpts contains the optional arguments for the Table term
|
||||
type TableOpts struct {
|
||||
ReadMode interface{} `gorethink:"read_mode,omitempty"`
|
||||
UseOutdated interface{} `gorethink:"use_outdated,omitempty"` // Deprecated
|
||||
IdentifierFormat interface{} `gorethink:"identifier_format,omitempty"`
|
||||
}
|
||||
|
||||
func (o TableOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Table selects all documents in a table. This command can be chained with
|
||||
// other commands to do further processing on the data.
|
||||
//
|
||||
// There are two optional arguments.
|
||||
// - useOutdated: if true, this allows potentially out-of-date data to be
|
||||
// returned, with potentially faster reads. It also allows you to perform reads
|
||||
// from a secondary replica if a primary has failed. Default false.
|
||||
// - identifierFormat: possible values are name and uuid, with a default of name.
|
||||
// If set to uuid, then system tables will refer to servers, databases and tables
|
||||
// by UUID rather than name. (This only has an effect when used with system tables.)
|
||||
func Table(name interface{}, optArgs ...TableOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructRootTerm("Table", p.Term_TABLE, []interface{}{name}, opts)
|
||||
}
|
||||
|
||||
// Table selects all documents in a table. This command can be chained with
|
||||
// other commands to do further processing on the data.
|
||||
//
|
||||
// There are two optional arguments.
|
||||
// - useOutdated: if true, this allows potentially out-of-date data to be
|
||||
// returned, with potentially faster reads. It also allows you to perform reads
|
||||
// from a secondary replica if a primary has failed. Default false.
|
||||
// - identifierFormat: possible values are name and uuid, with a default of name.
|
||||
// If set to uuid, then system tables will refer to servers, databases and tables
|
||||
// by UUID rather than name. (This only has an effect when used with system tables.)
|
||||
func (t Term) Table(name interface{}, optArgs ...TableOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Table", p.Term_TABLE, []interface{}{name}, opts)
|
||||
}
|
||||
|
||||
// Get gets a document by primary key. If nothing was found, RethinkDB will return a nil value.
|
||||
func (t Term) Get(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Get", p.Term_GET, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// GetAllOpts contains the optional arguments for the GetAll term
|
||||
type GetAllOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
}
|
||||
|
||||
func (o GetAllOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// GetAll gets all documents where the given value matches the value of the primary
|
||||
// index. Multiple values can be passed this function if you want to select multiple
|
||||
// documents. If the documents you are fetching have composite keys then each
|
||||
// argument should be a slice. For more information see the examples.
|
||||
func (t Term) GetAll(keys ...interface{}) Term {
|
||||
return constructMethodTerm(t, "GetAll", p.Term_GET_ALL, keys, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// GetAllByIndex gets all documents where the given value matches the value of
|
||||
// the requested index.
|
||||
func (t Term) GetAllByIndex(index interface{}, keys ...interface{}) Term {
|
||||
return constructMethodTerm(t, "GetAll", p.Term_GET_ALL, keys, map[string]interface{}{"index": index})
|
||||
}
|
||||
|
||||
// BetweenOpts contains the optional arguments for the Between term
|
||||
type BetweenOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
LeftBound interface{} `gorethink:"left_bound,omitempty"`
|
||||
RightBound interface{} `gorethink:"right_bound,omitempty"`
|
||||
}
|
||||
|
||||
func (o BetweenOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Between gets all documents between two keys. Accepts three optional arguments:
|
||||
// index, leftBound, and rightBound. If index is set to the name of a secondary
|
||||
// index, between will return all documents where that index’s value is in the
|
||||
// specified range (it uses the primary key by default). leftBound or rightBound
|
||||
// may be set to open or closed to indicate whether or not to include that endpoint
|
||||
// of the range (by default, leftBound is closed and rightBound is open).
|
||||
//
|
||||
// You may also use the special constants r.minval and r.maxval for boundaries,
|
||||
// which represent “less than any index key” and “more than any index key”
|
||||
// respectively. For instance, if you use r.minval as the lower key, then between
|
||||
// will return all documents whose primary keys (or indexes) are less than the
|
||||
// specified upper key.
|
||||
func (t Term) Between(lowerKey, upperKey interface{}, optArgs ...BetweenOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Between", p.Term_BETWEEN, []interface{}{lowerKey, upperKey}, opts)
|
||||
}
|
||||
|
||||
// FilterOpts contains the optional arguments for the Filter term
|
||||
type FilterOpts struct {
|
||||
Default interface{} `gorethink:"default,omitempty"`
|
||||
}
|
||||
|
||||
func (o FilterOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Filter gets all the documents for which the given predicate is true.
|
||||
//
|
||||
// Filter can be called on a sequence, selection, or a field containing an array
|
||||
// of elements. The return type is the same as the type on which the function was
|
||||
// called on. The body of every filter is wrapped in an implicit `.default(false)`,
|
||||
// and the default value can be changed by passing the optional argument `default`.
|
||||
// Setting this optional argument to `r.error()` will cause any non-existence
|
||||
// errors to abort the filter.
|
||||
func (t Term) Filter(f interface{}, optArgs ...FilterOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Filter", p.Term_FILTER, []interface{}{funcWrap(f)}, opts)
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// Match matches against a regular expression. If no match is found, returns
|
||||
// null. If there is a match then an object with the following fields is
|
||||
// returned:
|
||||
// str: The matched string
|
||||
// start: The matched string’s start
|
||||
// end: The matched string’s end
|
||||
// groups: The capture groups defined with parentheses
|
||||
//
|
||||
// Accepts RE2 syntax (https://code.google.com/p/re2/wiki/Syntax). You can
|
||||
// enable case-insensitive matching by prefixing the regular expression with
|
||||
// (?i). See the linked RE2 documentation for more flags.
|
||||
//
|
||||
// The match command does not support backreferences.
|
||||
func (t Term) Match(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Match", p.Term_MATCH, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Split splits a string into substrings. Splits on whitespace when called with no arguments.
|
||||
// When called with a separator, splits on that separator. When called with a separator
|
||||
// and a maximum number of splits, splits on that separator at most max_splits times.
|
||||
// (Can be called with null as the separator if you want to split on whitespace while still
|
||||
// specifying max_splits.)
|
||||
//
|
||||
// Mimics the behavior of Python's string.split in edge cases, except for splitting on the
|
||||
// empty string, which instead produces an array of single-character strings.
|
||||
func (t Term) Split(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Split", p.Term_SPLIT, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Upcase upper-cases a string.
|
||||
func (t Term) Upcase(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Upcase", p.Term_UPCASE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Downcase lower-cases a string.
|
||||
func (t Term) Downcase(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Downcase", p.Term_DOWNCASE, args, map[string]interface{}{})
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// TableCreateOpts contains the optional arguments for the TableCreate term
|
||||
type TableCreateOpts struct {
|
||||
PrimaryKey interface{} `gorethink:"primary_key,omitempty"`
|
||||
Durability interface{} `gorethink:"durability,omitempty"`
|
||||
Shards interface{} `gorethink:"shards,omitempty"`
|
||||
Replicas interface{} `gorethink:"replicas,omitempty"`
|
||||
PrimaryReplicaTag interface{} `gorethink:"primary_replica_tag,omitempty"`
|
||||
NonVotingReplicaTags interface{} `gorethink:"nonvoting_replica_tags,omitempty"`
|
||||
}
|
||||
|
||||
func (o TableCreateOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// TableCreate creates a table. A RethinkDB table is a collection of JSON
|
||||
// documents.
|
||||
//
|
||||
// Note: Only alphanumeric characters and underscores are valid for the table name.
|
||||
func TableCreate(name interface{}, optArgs ...TableCreateOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructRootTerm("TableCreate", p.Term_TABLE_CREATE, []interface{}{name}, opts)
|
||||
}
|
||||
|
||||
// TableCreate creates a table. A RethinkDB table is a collection of JSON
|
||||
// documents.
|
||||
//
|
||||
// Note: Only alphanumeric characters and underscores are valid for the table name.
|
||||
func (t Term) TableCreate(name interface{}, optArgs ...TableCreateOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "TableCreate", p.Term_TABLE_CREATE, []interface{}{name}, opts)
|
||||
}
|
||||
|
||||
// TableDrop deletes a table. The table and all its data will be deleted.
|
||||
func TableDrop(args ...interface{}) Term {
|
||||
return constructRootTerm("TableDrop", p.Term_TABLE_DROP, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// TableDrop deletes a table. The table and all its data will be deleted.
|
||||
func (t Term) TableDrop(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "TableDrop", p.Term_TABLE_DROP, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// TableList lists all table names in a database.
|
||||
func TableList(args ...interface{}) Term {
|
||||
return constructRootTerm("TableList", p.Term_TABLE_LIST, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// TableList lists all table names in a database.
|
||||
func (t Term) TableList(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "TableList", p.Term_TABLE_LIST, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// IndexCreateOpts contains the optional arguments for the IndexCreate term
|
||||
type IndexCreateOpts struct {
|
||||
Multi interface{} `gorethink:"multi,omitempty"`
|
||||
Geo interface{} `gorethink:"geo,omitempty"`
|
||||
}
|
||||
|
||||
func (o IndexCreateOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// IndexCreate creates a new secondary index on a table. Secondary indexes
|
||||
// improve the speed of many read queries at the slight cost of increased
|
||||
// storage space and decreased write performance.
|
||||
//
|
||||
// IndexCreate supports the creation of the following types of indexes, to create
|
||||
// indexes using arbitrary expressions use IndexCreateFunc.
|
||||
// - Simple indexes based on the value of a single field.
|
||||
// - Geospatial indexes based on indexes of geometry objects, created when the
|
||||
// geo optional argument is true.
|
||||
func (t Term) IndexCreate(name interface{}, optArgs ...IndexCreateOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "IndexCreate", p.Term_INDEX_CREATE, []interface{}{name}, opts)
|
||||
}
|
||||
|
||||
// IndexCreateFunc creates a new secondary index on a table. Secondary indexes
|
||||
// improve the speed of many read queries at the slight cost of increased
|
||||
// storage space and decreased write performance. The function takes a index
|
||||
// name and RQL term as the index value , the term can be an anonymous function
|
||||
// or a binary representation obtained from the function field of indexStatus.
|
||||
//
|
||||
// It supports the creation of the following types of indexes.
|
||||
// - Simple indexes based on the value of a single field where the index has a
|
||||
// different name to the field.
|
||||
// - Compound indexes based on multiple fields.
|
||||
// - Multi indexes based on arrays of values, created when the multi optional argument is true.
|
||||
func (t Term) IndexCreateFunc(name, indexFunction interface{}, optArgs ...IndexCreateOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "IndexCreate", p.Term_INDEX_CREATE, []interface{}{name, funcWrap(indexFunction)}, opts)
|
||||
}
|
||||
|
||||
// IndexDrop deletes a previously created secondary index of a table.
|
||||
func (t Term) IndexDrop(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "IndexDrop", p.Term_INDEX_DROP, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// IndexList lists all the secondary indexes of a table.
|
||||
func (t Term) IndexList(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "IndexList", p.Term_INDEX_LIST, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// IndexRenameOpts contains the optional arguments for the IndexRename term
|
||||
type IndexRenameOpts struct {
|
||||
Overwrite interface{} `gorethink:"overwrite,omitempty"`
|
||||
}
|
||||
|
||||
func (o IndexRenameOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// IndexRename renames an existing secondary index on a table.
|
||||
func (t Term) IndexRename(oldName, newName interface{}, optArgs ...IndexRenameOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "IndexRename", p.Term_INDEX_RENAME, []interface{}{oldName, newName}, opts)
|
||||
}
|
||||
|
||||
// IndexStatus gets the status of the specified indexes on this table, or the
|
||||
// status of all indexes on this table if no indexes are specified.
|
||||
func (t Term) IndexStatus(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "IndexStatus", p.Term_INDEX_STATUS, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// IndexWait waits for the specified indexes on this table to be ready, or for
|
||||
// all indexes on this table to be ready if no indexes are specified.
|
||||
func (t Term) IndexWait(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "IndexWait", p.Term_INDEX_WAIT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ChangesOpts contains the optional arguments for the Changes term
|
||||
type ChangesOpts struct {
|
||||
Squash interface{} `gorethink:"squash,omitempty"`
|
||||
IncludeInitial interface{} `gorethink:"include_initial,omitempty"`
|
||||
IncludeStates interface{} `gorethink:"include_states,omitempty"`
|
||||
IncludeOffsets interface{} `gorethink:"include_offsets,omitempty"`
|
||||
IncludeTypes interface{} `gorethink:"include_types,omitempty"`
|
||||
ChangefeedQueueSize interface{} `gorethink:"changefeed_queue_size,omitempty"`
|
||||
}
|
||||
|
||||
// ChangesOpts contains the optional arguments for the Changes term
|
||||
func (o ChangesOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Changes returns an infinite stream of objects representing changes to a query.
|
||||
func (t Term) Changes(optArgs ...ChangesOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Changes", p.Term_CHANGES, []interface{}{}, opts)
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// Now returns a time object representing the current time in UTC
|
||||
func Now(args ...interface{}) Term {
|
||||
return constructRootTerm("Now", p.Term_NOW, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Time creates a time object for a specific time
|
||||
func Time(args ...interface{}) Term {
|
||||
return constructRootTerm("Time", p.Term_TIME, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// EpochTime returns a time object based on seconds since epoch
|
||||
func EpochTime(args ...interface{}) Term {
|
||||
return constructRootTerm("EpochTime", p.Term_EPOCH_TIME, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ISO8601Opts contains the optional arguments for the ISO8601 term
|
||||
type ISO8601Opts struct {
|
||||
DefaultTimezone interface{} `gorethink:"default_timezone,omitempty"`
|
||||
}
|
||||
|
||||
func (o ISO8601Opts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// ISO8601 returns a time object based on an ISO8601 formatted date-time string
|
||||
func ISO8601(date interface{}, optArgs ...ISO8601Opts) Term {
|
||||
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructRootTerm("ISO8601", p.Term_ISO8601, []interface{}{date}, opts)
|
||||
}
|
||||
|
||||
// InTimezone returns a new time object with a different time zone. While the
|
||||
// time stays the same, the results returned by methods such as hours() will
|
||||
// change since they take the timezone into account. The timezone argument
|
||||
// has to be of the ISO 8601 format.
|
||||
func (t Term) InTimezone(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "InTimezone", p.Term_IN_TIMEZONE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Timezone returns the timezone of the time object
|
||||
func (t Term) Timezone(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Timezone", p.Term_TIMEZONE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// DuringOpts contains the optional arguments for the During term
|
||||
type DuringOpts struct {
|
||||
LeftBound interface{} `gorethink:"left_bound,omitempty"`
|
||||
RightBound interface{} `gorethink:"right_bound,omitempty"`
|
||||
}
|
||||
|
||||
func (o DuringOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// During returns true if a time is between two other times
|
||||
// (by default, inclusive for the start, exclusive for the end).
|
||||
func (t Term) During(startTime, endTime interface{}, optArgs ...DuringOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "During", p.Term_DURING, []interface{}{startTime, endTime}, opts)
|
||||
}
|
||||
|
||||
// Date returns a new time object only based on the day, month and year
|
||||
// (ie. the same day at 00:00).
|
||||
func (t Term) Date(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Date", p.Term_DATE, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// TimeOfDay returns the number of seconds elapsed since the beginning of the
|
||||
// day stored in the time object.
|
||||
func (t Term) TimeOfDay(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "TimeOfDay", p.Term_TIME_OF_DAY, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Year returns the year of a time object.
|
||||
func (t Term) Year(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Year", p.Term_YEAR, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Month returns the month of a time object as a number between 1 and 12.
|
||||
// For your convenience, the terms r.January(), r.February() etc. are
|
||||
// defined and map to the appropriate integer.
|
||||
func (t Term) Month(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Month", p.Term_MONTH, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Day return the day of a time object as a number between 1 and 31.
|
||||
func (t Term) Day(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Day", p.Term_DAY, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// DayOfWeek returns the day of week of a time object as a number between
|
||||
// 1 and 7 (following ISO 8601 standard). For your convenience,
|
||||
// the terms r.Monday(), r.Tuesday() etc. are defined and map to
|
||||
// the appropriate integer.
|
||||
func (t Term) DayOfWeek(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "DayOfWeek", p.Term_DAY_OF_WEEK, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// DayOfYear returns the day of the year of a time object as a number between
|
||||
// 1 and 366 (following ISO 8601 standard).
|
||||
func (t Term) DayOfYear(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "DayOfYear", p.Term_DAY_OF_YEAR, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Hours returns the hour in a time object as a number between 0 and 23.
|
||||
func (t Term) Hours(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Hours", p.Term_HOURS, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Minutes returns the minute in a time object as a number between 0 and 59.
|
||||
func (t Term) Minutes(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Minutes", p.Term_MINUTES, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Seconds returns the seconds in a time object as a number between 0 and
|
||||
// 59.999 (double precision).
|
||||
func (t Term) Seconds(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Seconds", p.Term_SECONDS, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ToISO8601 converts a time object to its iso 8601 format.
|
||||
func (t Term) ToISO8601(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "ToISO8601", p.Term_TO_ISO8601, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ToEpochTime converts a time object to its epoch time.
|
||||
func (t Term) ToEpochTime(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "ToEpochTime", p.Term_TO_EPOCH_TIME, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
var (
|
||||
// Days
|
||||
|
||||
// Monday is a constant representing the day of the week Monday
|
||||
Monday = constructRootTerm("Monday", p.Term_MONDAY, []interface{}{}, map[string]interface{}{})
|
||||
// Tuesday is a constant representing the day of the week Tuesday
|
||||
Tuesday = constructRootTerm("Tuesday", p.Term_TUESDAY, []interface{}{}, map[string]interface{}{})
|
||||
// Wednesday is a constant representing the day of the week Wednesday
|
||||
Wednesday = constructRootTerm("Wednesday", p.Term_WEDNESDAY, []interface{}{}, map[string]interface{}{})
|
||||
// Thursday is a constant representing the day of the week Thursday
|
||||
Thursday = constructRootTerm("Thursday", p.Term_THURSDAY, []interface{}{}, map[string]interface{}{})
|
||||
// Friday is a constant representing the day of the week Friday
|
||||
Friday = constructRootTerm("Friday", p.Term_FRIDAY, []interface{}{}, map[string]interface{}{})
|
||||
// Saturday is a constant representing the day of the week Saturday
|
||||
Saturday = constructRootTerm("Saturday", p.Term_SATURDAY, []interface{}{}, map[string]interface{}{})
|
||||
// Sunday is a constant representing the day of the week Sunday
|
||||
Sunday = constructRootTerm("Sunday", p.Term_SUNDAY, []interface{}{}, map[string]interface{}{})
|
||||
|
||||
// Months
|
||||
|
||||
// January is a constant representing the month January
|
||||
January = constructRootTerm("January", p.Term_JANUARY, []interface{}{}, map[string]interface{}{})
|
||||
// February is a constant representing the month February
|
||||
February = constructRootTerm("February", p.Term_FEBRUARY, []interface{}{}, map[string]interface{}{})
|
||||
// March is a constant representing the month March
|
||||
March = constructRootTerm("March", p.Term_MARCH, []interface{}{}, map[string]interface{}{})
|
||||
// April is a constant representing the month April
|
||||
April = constructRootTerm("April", p.Term_APRIL, []interface{}{}, map[string]interface{}{})
|
||||
// May is a constant representing the month May
|
||||
May = constructRootTerm("May", p.Term_MAY, []interface{}{}, map[string]interface{}{})
|
||||
// June is a constant representing the month June
|
||||
June = constructRootTerm("June", p.Term_JUNE, []interface{}{}, map[string]interface{}{})
|
||||
// July is a constant representing the month July
|
||||
July = constructRootTerm("July", p.Term_JULY, []interface{}{}, map[string]interface{}{})
|
||||
// August is a constant representing the month August
|
||||
August = constructRootTerm("August", p.Term_AUGUST, []interface{}{}, map[string]interface{}{})
|
||||
// September is a constant representing the month September
|
||||
September = constructRootTerm("September", p.Term_SEPTEMBER, []interface{}{}, map[string]interface{}{})
|
||||
// October is a constant representing the month October
|
||||
October = constructRootTerm("October", p.Term_OCTOBER, []interface{}{}, map[string]interface{}{})
|
||||
// November is a constant representing the month November
|
||||
November = constructRootTerm("November", p.Term_NOVEMBER, []interface{}{}, map[string]interface{}{})
|
||||
// December is a constant representing the month December
|
||||
December = constructRootTerm("December", p.Term_DECEMBER, []interface{}{}, map[string]interface{}{})
|
||||
)
|
|
@ -0,0 +1,193 @@
|
|||
package gorethink
|
||||
|
||||
import p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
|
||||
// Map transform each element of the sequence by applying the given mapping
|
||||
// function. It takes two arguments, a sequence and a function of type
|
||||
// `func (r.Term) interface{}`.
|
||||
//
|
||||
// For example this query doubles each element in an array:
|
||||
//
|
||||
// r.Map([]int{1,3,6}, func (row r.Term) interface{} {
|
||||
// return row.Mul(2)
|
||||
// })
|
||||
func Map(args ...interface{}) Term {
|
||||
if len(args) > 0 {
|
||||
args = append(args[:len(args)-1], funcWrap(args[len(args)-1]))
|
||||
}
|
||||
|
||||
return constructRootTerm("Map", p.Term_MAP, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Map transforms each element of the sequence by applying the given mapping
|
||||
// function. It takes one argument of type `func (r.Term) interface{}`.
|
||||
//
|
||||
// For example this query doubles each element in an array:
|
||||
//
|
||||
// r.Expr([]int{1,3,6}).Map(func (row r.Term) interface{} {
|
||||
// return row.Mul(2)
|
||||
// })
|
||||
func (t Term) Map(args ...interface{}) Term {
|
||||
if len(args) > 0 {
|
||||
args = append(args[:len(args)-1], funcWrap(args[len(args)-1]))
|
||||
}
|
||||
|
||||
return constructMethodTerm(t, "Map", p.Term_MAP, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// WithFields takes a sequence of objects and a list of fields. If any objects in the
|
||||
// sequence don't have all of the specified fields, they're dropped from the
|
||||
// sequence. The remaining objects have the specified fields plucked out.
|
||||
// (This is identical to `HasFields` followed by `Pluck` on a sequence.)
|
||||
func (t Term) WithFields(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "WithFields", p.Term_WITH_FIELDS, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// ConcatMap concatenates one or more elements into a single sequence using a
|
||||
// mapping function. ConcatMap works in a similar fashion to Map, applying the
|
||||
// given function to each element in a sequence, but it will always return a
|
||||
// single sequence.
|
||||
func (t Term) ConcatMap(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "ConcatMap", p.Term_CONCAT_MAP, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// OrderByOpts contains the optional arguments for the OrderBy term
|
||||
type OrderByOpts struct {
|
||||
Index interface{} `gorethink:"index,omitempty"`
|
||||
}
|
||||
|
||||
func (o OrderByOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// OrderBy sorts the sequence by document values of the given key(s). To specify
|
||||
// the ordering, wrap the attribute with either r.Asc or r.Desc (defaults to
|
||||
// ascending).
|
||||
//
|
||||
// Sorting without an index requires the server to hold the sequence in memory,
|
||||
// and is limited to 100,000 documents (or the setting of the ArrayLimit option
|
||||
// for run). Sorting with an index can be done on arbitrarily large tables, or
|
||||
// after a between command using the same index.
|
||||
func (t Term) OrderBy(args ...interface{}) Term {
|
||||
var opts = map[string]interface{}{}
|
||||
|
||||
// Look for options map
|
||||
if len(args) > 0 {
|
||||
if possibleOpts, ok := args[len(args)-1].(OrderByOpts); ok {
|
||||
opts = possibleOpts.toMap()
|
||||
args = args[:len(args)-1]
|
||||
}
|
||||
}
|
||||
|
||||
for k, arg := range args {
|
||||
if t, ok := arg.(Term); !(ok && (t.termType == p.Term_DESC || t.termType == p.Term_ASC)) {
|
||||
args[k] = funcWrap(arg)
|
||||
}
|
||||
}
|
||||
|
||||
return constructMethodTerm(t, "OrderBy", p.Term_ORDER_BY, args, opts)
|
||||
}
|
||||
|
||||
// Desc is used by the OrderBy term to specify the ordering to be descending.
|
||||
func Desc(args ...interface{}) Term {
|
||||
return constructRootTerm("Desc", p.Term_DESC, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Asc is used by the OrderBy term to specify that the ordering be ascending (the
|
||||
// default).
|
||||
func Asc(args ...interface{}) Term {
|
||||
return constructRootTerm("Asc", p.Term_ASC, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Skip skips a number of elements from the head of the sequence.
|
||||
func (t Term) Skip(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Skip", p.Term_SKIP, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Limit ends the sequence after the given number of elements.
|
||||
func (t Term) Limit(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Limit", p.Term_LIMIT, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// SliceOpts contains the optional arguments for the Slice term
|
||||
type SliceOpts struct {
|
||||
LeftBound interface{} `gorethink:"left_bound,omitempty"`
|
||||
RightBound interface{} `gorethink:"right_bound,omitempty"`
|
||||
}
|
||||
|
||||
func (o SliceOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Slice trims the sequence to within the bounds provided.
|
||||
func (t Term) Slice(args ...interface{}) Term {
|
||||
var opts = map[string]interface{}{}
|
||||
|
||||
// Look for options map
|
||||
if len(args) > 0 {
|
||||
if possibleOpts, ok := args[len(args)-1].(SliceOpts); ok {
|
||||
opts = possibleOpts.toMap()
|
||||
args = args[:len(args)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return constructMethodTerm(t, "Slice", p.Term_SLICE, args, opts)
|
||||
}
|
||||
|
||||
// AtIndex gets a single field from an object or the nth element from a sequence.
|
||||
func (t Term) AtIndex(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "AtIndex", p.Term_BRACKET, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Nth gets the nth element from a sequence.
|
||||
func (t Term) Nth(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Nth", p.Term_NTH, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// OffsetsOf gets the indexes of an element in a sequence. If the argument is a
|
||||
// predicate, get the indexes of all elements matching it.
|
||||
func (t Term) OffsetsOf(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "OffsetsOf", p.Term_OFFSETS_OF, funcWrapArgs(args), map[string]interface{}{})
|
||||
}
|
||||
|
||||
// IsEmpty tests if a sequence is empty.
|
||||
func (t Term) IsEmpty(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "IsEmpty", p.Term_IS_EMPTY, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// UnionOpts contains the optional arguments for the Slice term
|
||||
type UnionOpts struct {
|
||||
Interleave interface{} `gorethink:"interleave,omitempty"`
|
||||
}
|
||||
|
||||
func (o UnionOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Union concatenates two sequences.
|
||||
func Union(args ...interface{}) Term {
|
||||
return constructRootTerm("Union", p.Term_UNION, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Union concatenates two sequences.
|
||||
func (t Term) Union(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Union", p.Term_UNION, args, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// UnionWithOpts like Union concatenates two sequences however allows for optional
|
||||
// arguments to be passed.
|
||||
func UnionWithOpts(optArgs UnionOpts, args ...interface{}) Term {
|
||||
return constructRootTerm("Union", p.Term_UNION, args, optArgs.toMap())
|
||||
}
|
||||
|
||||
// UnionWithOpts like Union concatenates two sequences however allows for optional
|
||||
// arguments to be passed.
|
||||
func (t Term) UnionWithOpts(optArgs UnionOpts, args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Union", p.Term_UNION, args, optArgs.toMap())
|
||||
}
|
||||
|
||||
// Sample selects a given number of elements from a sequence with uniform random
|
||||
// distribution. Selection is done without replacement.
|
||||
func (t Term) Sample(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Sample", p.Term_SAMPLE, args, map[string]interface{}{})
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// InsertOpts contains the optional arguments for the Insert term
|
||||
type InsertOpts struct {
|
||||
Durability interface{} `gorethink:"durability,omitempty"`
|
||||
ReturnChanges interface{} `gorethink:"return_changes,omitempty"`
|
||||
Conflict interface{} `gorethink:"conflict,omitempty"`
|
||||
}
|
||||
|
||||
func (o InsertOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Insert documents into a table. Accepts a single document or an array
|
||||
// of documents.
|
||||
func (t Term) Insert(arg interface{}, optArgs ...InsertOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Insert", p.Term_INSERT, []interface{}{Expr(arg)}, opts)
|
||||
}
|
||||
|
||||
// UpdateOpts contains the optional arguments for the Update term
|
||||
type UpdateOpts struct {
|
||||
Durability interface{} `gorethink:"durability,omitempty"`
|
||||
ReturnChanges interface{} `gorethink:"return_changes,omitempty"`
|
||||
NonAtomic interface{} `gorethink:"non_atomic,omitempty"`
|
||||
Conflict interface{} `gorethink:"conflict,omitempty"`
|
||||
}
|
||||
|
||||
func (o UpdateOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Update JSON documents in a table. Accepts a JSON document, a ReQL expression,
|
||||
// or a combination of the two. You can pass options like returnChanges that will
|
||||
// return the old and new values of the row you have modified.
|
||||
func (t Term) Update(arg interface{}, optArgs ...UpdateOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Update", p.Term_UPDATE, []interface{}{funcWrap(arg)}, opts)
|
||||
}
|
||||
|
||||
// ReplaceOpts contains the optional arguments for the Replace term
|
||||
type ReplaceOpts struct {
|
||||
Durability interface{} `gorethink:"durability,omitempty"`
|
||||
ReturnChanges interface{} `gorethink:"return_changes,omitempty"`
|
||||
NonAtomic interface{} `gorethink:"non_atomic,omitempty"`
|
||||
}
|
||||
|
||||
func (o ReplaceOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Replace documents in a table. Accepts a JSON document or a ReQL expression,
|
||||
// and replaces the original document with the new one. The new document must
|
||||
// have the same primary key as the original document.
|
||||
func (t Term) Replace(arg interface{}, optArgs ...ReplaceOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Replace", p.Term_REPLACE, []interface{}{funcWrap(arg)}, opts)
|
||||
}
|
||||
|
||||
// DeleteOpts contains the optional arguments for the Delete term
|
||||
type DeleteOpts struct {
|
||||
Durability interface{} `gorethink:"durability,omitempty"`
|
||||
ReturnChanges interface{} `gorethink:"return_changes,omitempty"`
|
||||
}
|
||||
|
||||
func (o DeleteOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Delete one or more documents from a table.
|
||||
func (t Term) Delete(optArgs ...DeleteOpts) Term {
|
||||
opts := map[string]interface{}{}
|
||||
if len(optArgs) >= 1 {
|
||||
opts = optArgs[0].toMap()
|
||||
}
|
||||
return constructMethodTerm(t, "Delete", p.Term_DELETE, []interface{}{}, opts)
|
||||
}
|
||||
|
||||
// Sync ensures that writes on a given table are written to permanent storage.
|
||||
// Queries that specify soft durability do not give such guarantees, so Sync
|
||||
// can be used to ensure the state of these queries. A call to Sync does not
|
||||
// return until all previous writes to the table are persisted.
|
||||
func (t Term) Sync(args ...interface{}) Term {
|
||||
return constructMethodTerm(t, "Sync", p.Term_SYNC, args, map[string]interface{}{})
|
||||
}
|
|
@ -0,0 +1,328 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// A Session represents a connection to a RethinkDB cluster and should be used
|
||||
// when executing queries.
|
||||
type Session struct {
|
||||
hosts []Host
|
||||
opts *ConnectOpts
|
||||
|
||||
mu sync.RWMutex
|
||||
cluster *Cluster
|
||||
closed bool
|
||||
}
|
||||
|
||||
// ConnectOpts is used to specify optional arguments when connecting to a cluster.
|
||||
type ConnectOpts struct {
|
||||
// Address holds the address of the server initially used when creating the
|
||||
// session. Only used if Addresses is empty
|
||||
Address string `gorethink:"address,omitempty"`
|
||||
// Addresses holds the addresses of the servers initially used when creating
|
||||
// the session.
|
||||
Addresses []string `gorethink:"addresses,omitempty"`
|
||||
// Database is the default database name used when executing queries, this
|
||||
// value is only used if the query does not contain any DB term
|
||||
Database string `gorethink:"database,omitempty"`
|
||||
// Username holds the username used for authentication, if blank (and the v1
|
||||
// handshake protocol is being used) then the admin user is used
|
||||
Username string `gorethink:"username,omitempty"`
|
||||
// Password holds the password used for authentication (only used when using
|
||||
// the v1 handshake protocol)
|
||||
Password string `gorethink:"password,omitempty"`
|
||||
// AuthKey is used for authentication when using the v0.4 handshake protocol
|
||||
// This field is no deprecated
|
||||
AuthKey string `gorethink:"authkey,omitempty"`
|
||||
// Timeout is the time the driver waits when creating new connections, to
|
||||
// configure the timeout used when executing queries use WriteTimeout and
|
||||
// ReadTimeout
|
||||
Timeout time.Duration `gorethink:"timeout,omitempty"`
|
||||
// WriteTimeout is the amount of time the driver will wait when sending the
|
||||
// query to the server
|
||||
WriteTimeout time.Duration `gorethink:"write_timeout,omitempty"`
|
||||
// ReadTimeout is the amount of time the driver will wait for a response from
|
||||
// the server when executing queries.
|
||||
ReadTimeout time.Duration `gorethink:"read_timeout,omitempty"`
|
||||
// KeepAlivePeriod is the keep alive period used by the connection, by default
|
||||
// this is 30s. It is not possible to disable keep alive messages
|
||||
KeepAlivePeriod time.Duration `gorethink:"keep_alive_timeout,omitempty"`
|
||||
// TLSConfig holds the TLS configuration and can be used when connecting
|
||||
// to a RethinkDB server protected by SSL
|
||||
TLSConfig *tls.Config `gorethink:"tlsconfig,omitempty"`
|
||||
// HandshakeVersion is used to specify which handshake version should be
|
||||
// used, this currently defaults to v1 which is used by RethinkDB 2.3 and
|
||||
// later. If you are using an older version then you can set the handshake
|
||||
// version to 0.4
|
||||
HandshakeVersion HandshakeVersion `gorethink:"handshake_version,omitempty"`
|
||||
// UseJSONNumber indicates whether the cursors running in this session should
|
||||
// use json.Number instead of float64 while unmarshaling documents with
|
||||
// interface{}. The default is `false`.
|
||||
UseJSONNumber bool
|
||||
// NumRetries is the number of times a query is retried if a connection
|
||||
// error is detected, queries are not retried if RethinkDB returns a
|
||||
// runtime error.
|
||||
NumRetries int
|
||||
|
||||
// InitialCap is used by the internal connection pool and is used to
|
||||
// configure how many connections are created for each host when the
|
||||
// session is created. If zero then no connections are created until
|
||||
// the first query is executed.
|
||||
InitialCap int `gorethink:"initial_cap,omitempty"`
|
||||
// MaxOpen is used by the internal connection pool and is used to configure
|
||||
// the maximum number of connections held in the pool. If all available
|
||||
// connections are being used then the driver will open new connections as
|
||||
// needed however they will not be returned to the pool. By default the
|
||||
// maximum number of connections is 2
|
||||
MaxOpen int `gorethink:"max_open,omitempty"`
|
||||
|
||||
// Below options are for cluster discovery, please note there is a high
|
||||
// probability of these changing as the API is still being worked on.
|
||||
|
||||
// DiscoverHosts is used to enable host discovery, when true the driver
|
||||
// will attempt to discover any new nodes added to the cluster and then
|
||||
// start sending queries to these new nodes.
|
||||
DiscoverHosts bool `gorethink:"discover_hosts,omitempty"`
|
||||
// HostDecayDuration is used by the go-hostpool package to calculate a weighted
|
||||
// score when selecting a host. By default a value of 5 minutes is used.
|
||||
HostDecayDuration time.Duration
|
||||
|
||||
// Deprecated: This function is no longer used due to changes in the
|
||||
// way hosts are selected.
|
||||
NodeRefreshInterval time.Duration `gorethink:"node_refresh_interval,omitempty"`
|
||||
// Deprecated: Use InitialCap instead
|
||||
MaxIdle int `gorethink:"max_idle,omitempty"`
|
||||
}
|
||||
|
||||
func (o ConnectOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// Connect creates a new database session. To view the available connection
|
||||
// options see ConnectOpts.
|
||||
//
|
||||
// By default maxIdle and maxOpen are set to 1: passing values greater
|
||||
// than the default (e.g. MaxIdle: "10", MaxOpen: "20") will provide a
|
||||
// pool of re-usable connections.
|
||||
//
|
||||
// Basic connection example:
|
||||
//
|
||||
// session, err := r.Connect(r.ConnectOpts{
|
||||
// Host: "localhost:28015",
|
||||
// Database: "test",
|
||||
// AuthKey: "14daak1cad13dj",
|
||||
// })
|
||||
//
|
||||
// Cluster connection example:
|
||||
//
|
||||
// session, err := r.Connect(r.ConnectOpts{
|
||||
// Hosts: []string{"localhost:28015", "localhost:28016"},
|
||||
// Database: "test",
|
||||
// AuthKey: "14daak1cad13dj",
|
||||
// })
|
||||
func Connect(opts ConnectOpts) (*Session, error) {
|
||||
var addresses = opts.Addresses
|
||||
if len(addresses) == 0 {
|
||||
addresses = []string{opts.Address}
|
||||
}
|
||||
|
||||
hosts := make([]Host, len(addresses))
|
||||
for i, address := range addresses {
|
||||
hostname, port := splitAddress(address)
|
||||
hosts[i] = NewHost(hostname, port)
|
||||
}
|
||||
if len(hosts) <= 0 {
|
||||
return nil, ErrNoHosts
|
||||
}
|
||||
|
||||
// Connect
|
||||
s := &Session{
|
||||
hosts: hosts,
|
||||
opts: &opts,
|
||||
}
|
||||
|
||||
err := s.Reconnect()
|
||||
if err != nil {
|
||||
// note: s.Reconnect() will initialize cluster information which
|
||||
// will cause the .IsConnected() method to be caught in a loop
|
||||
return &Session{
|
||||
hosts: hosts,
|
||||
opts: &opts,
|
||||
}, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// CloseOpts allows calls to the Close function to be configured.
|
||||
type CloseOpts struct {
|
||||
NoReplyWait bool `gorethink:"noreplyWait,omitempty"`
|
||||
}
|
||||
|
||||
func (o CloseOpts) toMap() map[string]interface{} {
|
||||
return optArgsToMap(o)
|
||||
}
|
||||
|
||||
// IsConnected returns true if session has a valid connection.
|
||||
func (s *Session) IsConnected() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cluster == nil || s.closed {
|
||||
return false
|
||||
}
|
||||
return s.cluster.IsConnected()
|
||||
}
|
||||
|
||||
// Reconnect closes and re-opens a session.
|
||||
func (s *Session) Reconnect(optArgs ...CloseOpts) error {
|
||||
var err error
|
||||
|
||||
if err = s.Close(optArgs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.cluster, err = NewCluster(s.hosts, s.opts)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
s.closed = false
|
||||
s.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the session
|
||||
func (s *Session) Close(optArgs ...CloseOpts) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(optArgs) >= 1 {
|
||||
if optArgs[0].NoReplyWait {
|
||||
s.mu.Unlock()
|
||||
s.NoReplyWait()
|
||||
s.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
if s.cluster != nil {
|
||||
s.cluster.Close()
|
||||
}
|
||||
s.cluster = nil
|
||||
s.closed = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetInitialPoolCap sets the initial capacity of the connection pool.
|
||||
func (s *Session) SetInitialPoolCap(n int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.opts.InitialCap = n
|
||||
s.cluster.SetInitialPoolCap(n)
|
||||
}
|
||||
|
||||
// SetMaxIdleConns sets the maximum number of connections in the idle
|
||||
// connection pool.
|
||||
func (s *Session) SetMaxIdleConns(n int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.opts.MaxIdle = n
|
||||
s.cluster.SetMaxIdleConns(n)
|
||||
}
|
||||
|
||||
// SetMaxOpenConns sets the maximum number of open connections to the database.
|
||||
func (s *Session) SetMaxOpenConns(n int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.opts.MaxOpen = n
|
||||
s.cluster.SetMaxOpenConns(n)
|
||||
}
|
||||
|
||||
// NoReplyWait ensures that previous queries with the noreply flag have been
|
||||
// processed by the server. Note that this guarantee only applies to queries
|
||||
// run on the given connection
|
||||
func (s *Session) NoReplyWait() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.closed {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
return s.cluster.Exec(Query{
|
||||
Type: p.Query_NOREPLY_WAIT,
|
||||
})
|
||||
}
|
||||
|
||||
// Use changes the default database used
|
||||
func (s *Session) Use(database string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.opts.Database = database
|
||||
}
|
||||
|
||||
// Database returns the selected database set by Use
|
||||
func (s *Session) Database() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.opts.Database
|
||||
}
|
||||
|
||||
// Query executes a ReQL query using the session to connect to the database
|
||||
func (s *Session) Query(q Query) (*Cursor, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.closed {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
|
||||
return s.cluster.Query(q)
|
||||
}
|
||||
|
||||
// Exec executes a ReQL query using the session to connect to the database
|
||||
func (s *Session) Exec(q Query) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.closed {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
return s.cluster.Exec(q)
|
||||
}
|
||||
|
||||
// Server returns the server name and server UUID being used by a connection.
|
||||
func (s *Session) Server() (ServerResponse, error) {
|
||||
return s.cluster.Server()
|
||||
}
|
||||
|
||||
// SetHosts resets the hosts used when connecting to the RethinkDB cluster
|
||||
func (s *Session) SetHosts(hosts []Host) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.hosts = hosts
|
||||
}
|
||||
|
||||
func (s *Session) newQuery(t Term, opts map[string]interface{}) (Query, error) {
|
||||
return newQuery(t, opts, s.opts)
|
||||
}
|
|
@ -0,0 +1,283 @@
|
|||
package gorethink
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"gopkg.in/gorethink/gorethink.v2/encoding"
|
||||
p "gopkg.in/gorethink/gorethink.v2/ql2"
|
||||
)
|
||||
|
||||
// Helper functions for constructing terms
|
||||
|
||||
// constructRootTerm is an alias for creating a new term.
|
||||
func constructRootTerm(name string, termType p.Term_TermType, args []interface{}, optArgs map[string]interface{}) Term {
|
||||
return Term{
|
||||
name: name,
|
||||
rootTerm: true,
|
||||
termType: termType,
|
||||
args: convertTermList(args),
|
||||
optArgs: convertTermObj(optArgs),
|
||||
}
|
||||
}
|
||||
|
||||
// constructMethodTerm is an alias for creating a new term. Unlike constructRootTerm
|
||||
// this function adds the previous expression in the tree to the argument list to
|
||||
// create a method term.
|
||||
func constructMethodTerm(prevVal Term, name string, termType p.Term_TermType, args []interface{}, optArgs map[string]interface{}) Term {
|
||||
args = append([]interface{}{prevVal}, args...)
|
||||
|
||||
return Term{
|
||||
name: name,
|
||||
rootTerm: false,
|
||||
termType: termType,
|
||||
args: convertTermList(args),
|
||||
optArgs: convertTermObj(optArgs),
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for creating internal RQL types
|
||||
|
||||
func newQuery(t Term, qopts map[string]interface{}, copts *ConnectOpts) (q Query, err error) {
|
||||
queryOpts := map[string]interface{}{}
|
||||
for k, v := range qopts {
|
||||
queryOpts[k], err = Expr(v).Build()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if copts.Database != "" {
|
||||
queryOpts["db"], err = DB(copts.Database).Build()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
builtTerm, err := t.Build()
|
||||
if err != nil {
|
||||
return q, err
|
||||
}
|
||||
|
||||
// Construct query
|
||||
return Query{
|
||||
Type: p.Query_START,
|
||||
Term: &t,
|
||||
Opts: queryOpts,
|
||||
builtTerm: builtTerm,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// makeArray takes a slice of terms and produces a single MAKE_ARRAY term
|
||||
func makeArray(args termsList) Term {
|
||||
return Term{
|
||||
name: "[...]",
|
||||
termType: p.Term_MAKE_ARRAY,
|
||||
args: args,
|
||||
}
|
||||
}
|
||||
|
||||
// makeObject takes a map of terms and produces a single MAKE_OBJECT term
|
||||
func makeObject(args termsObj) Term {
|
||||
return Term{
|
||||
name: "{...}",
|
||||
termType: p.Term_MAKE_OBJ,
|
||||
optArgs: args,
|
||||
}
|
||||
}
|
||||
|
||||
var nextVarID int64
|
||||
|
||||
func makeFunc(f interface{}) Term {
|
||||
value := reflect.ValueOf(f)
|
||||
valueType := value.Type()
|
||||
|
||||
var argNums = make([]interface{}, valueType.NumIn())
|
||||
var args = make([]reflect.Value, valueType.NumIn())
|
||||
for i := 0; i < valueType.NumIn(); i++ {
|
||||
// Get a slice of the VARs to use as the function arguments
|
||||
varID := atomic.AddInt64(&nextVarID, 1)
|
||||
args[i] = reflect.ValueOf(constructRootTerm("var", p.Term_VAR, []interface{}{varID}, map[string]interface{}{}))
|
||||
argNums[i] = varID
|
||||
|
||||
// make sure all input arguments are of type Term
|
||||
argValueTypeName := valueType.In(i).String()
|
||||
if argValueTypeName != "gorethink.Term" && argValueTypeName != "interface {}" {
|
||||
panic("Function argument is not of type Term or interface {}")
|
||||
}
|
||||
}
|
||||
|
||||
if valueType.NumOut() != 1 {
|
||||
panic("Function does not have a single return value")
|
||||
}
|
||||
|
||||
body := value.Call(args)[0].Interface()
|
||||
argsArr := makeArray(convertTermList(argNums))
|
||||
|
||||
return constructRootTerm("func", p.Term_FUNC, []interface{}{argsArr, body}, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func funcWrap(value interface{}) Term {
|
||||
val := Expr(value)
|
||||
|
||||
if implVarScan(val) && val.termType != p.Term_ARGS {
|
||||
return makeFunc(func(x Term) Term {
|
||||
return val
|
||||
})
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func funcWrapArgs(args []interface{}) []interface{} {
|
||||
for i, arg := range args {
|
||||
args[i] = funcWrap(arg)
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// implVarScan recursivly checks a value to see if it contains an
|
||||
// IMPLICIT_VAR term. If it does it returns true
|
||||
func implVarScan(value Term) bool {
|
||||
if value.termType == p.Term_IMPLICIT_VAR {
|
||||
return true
|
||||
}
|
||||
for _, v := range value.args {
|
||||
if implVarScan(v) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range value.optArgs {
|
||||
if implVarScan(v) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Convert an opt args struct to a map.
|
||||
func optArgsToMap(optArgs OptArgs) map[string]interface{} {
|
||||
data, err := encode(optArgs)
|
||||
|
||||
if err == nil && data != nil {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
return m
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
// Convert a list into a slice of terms
|
||||
func convertTermList(l []interface{}) termsList {
|
||||
if len(l) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
terms := make(termsList, len(l))
|
||||
for i, v := range l {
|
||||
terms[i] = Expr(v)
|
||||
}
|
||||
|
||||
return terms
|
||||
}
|
||||
|
||||
// Convert a map into a map of terms
|
||||
func convertTermObj(o map[string]interface{}) termsObj {
|
||||
if len(o) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
terms := make(termsObj, len(o))
|
||||
for k, v := range o {
|
||||
terms[k] = Expr(v)
|
||||
}
|
||||
|
||||
return terms
|
||||
}
|
||||
|
||||
// Helper functions for debugging
|
||||
|
||||
func allArgsToStringSlice(args termsList, optArgs termsObj) []string {
|
||||
allArgs := make([]string, len(args)+len(optArgs))
|
||||
i := 0
|
||||
|
||||
for _, v := range args {
|
||||
allArgs[i] = v.String()
|
||||
i++
|
||||
}
|
||||
for k, v := range optArgs {
|
||||
allArgs[i] = k + "=" + v.String()
|
||||
i++
|
||||
}
|
||||
|
||||
return allArgs
|
||||
}
|
||||
|
||||
func argsToStringSlice(args termsList) []string {
|
||||
allArgs := make([]string, len(args))
|
||||
|
||||
for i, v := range args {
|
||||
allArgs[i] = v.String()
|
||||
}
|
||||
|
||||
return allArgs
|
||||
}
|
||||
|
||||
func optArgsToStringSlice(optArgs termsObj) []string {
|
||||
allArgs := make([]string, len(optArgs))
|
||||
i := 0
|
||||
|
||||
for k, v := range optArgs {
|
||||
allArgs[i] = k + "=" + v.String()
|
||||
i++
|
||||
}
|
||||
|
||||
return allArgs
|
||||
}
|
||||
|
||||
func splitAddress(address string) (hostname string, port int) {
|
||||
hostname = "localhost"
|
||||
port = 28015
|
||||
|
||||
addrParts := strings.Split(address, ":")
|
||||
|
||||
if len(addrParts) >= 1 {
|
||||
hostname = addrParts[0]
|
||||
}
|
||||
if len(addrParts) >= 2 {
|
||||
port, _ = strconv.Atoi(addrParts[1])
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func encode(data interface{}) (interface{}, error) {
|
||||
if _, ok := data.(Term); ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
v, err := encoding.Encode(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// shouldRetryQuery checks the result of a query and returns true if the query
|
||||
// should be retried
|
||||
func shouldRetryQuery(q Query, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, ok := err.(RQLConnectionError); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
return err == ErrConnectionClosed
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package logrus
|
||||
|
||||
// The following code was sourced and modified from the
|
||||
// https://bitbucket.org/tebeka/atexit package governed by the following license:
|
||||
//
|
||||
// Copyright (c) 2012 Miki Tebeka <miki.tebeka@gmail.com>.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
var handlers = []func(){}
|
||||
|
||||
func runHandler(handler func()) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error: Logrus exit handler error:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
handler()
|
||||
}
|
||||
|
||||
func runHandlers() {
|
||||
for _, handler := range handlers {
|
||||
runHandler(handler)
|
||||
}
|
||||
}
|
||||
|
||||
// Exit runs all the Logrus atexit handlers and then terminates the program using os.Exit(code)
|
||||
func Exit(code int) {
|
||||
runHandlers()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// RegisterExitHandler adds a Logrus Exit handler, call logrus.Exit to invoke
|
||||
// all handlers. The handlers will also be invoked when any Fatal log entry is
|
||||
// made.
|
||||
//
|
||||
// This method is useful when a caller wishes to use logrus to log a fatal
|
||||
// message but also needs to gracefully shutdown. An example usecase could be
|
||||
// closing database connections, or sending a alert that the application is
|
||||
// closing.
|
||||
func RegisterExitHandler(handler func()) {
|
||||
handlers = append(handlers, handler)
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
Package logrus is a structured logger for Go, completely API compatible with the standard library logger.
|
||||
|
||||
|
||||
The simplest way to use Logrus is simply the package-level exported logger:
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
log "github.com/Sirupsen/logrus"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.WithFields(log.Fields{
|
||||
"animal": "walrus",
|
||||
"number": 1,
|
||||
"size": 10,
|
||||
}).Info("A walrus appears")
|
||||
}
|
||||
|
||||
Output:
|
||||
time="2015-09-07T08:48:33Z" level=info msg="A walrus appears" animal=walrus number=1 size=10
|
||||
|
||||
For a full guide visit https://github.com/Sirupsen/logrus
|
||||
*/
|
||||
package logrus
|
|
@ -0,0 +1,275 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var bufferPool *sync.Pool
|
||||
|
||||
func init() {
|
||||
bufferPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(bytes.Buffer)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Defines the key when adding errors using WithError.
|
||||
var ErrorKey = "error"
|
||||
|
||||
// An entry is the final or intermediate Logrus logging entry. It contains all
|
||||
// the fields passed with WithField{,s}. It's finally logged when Debug, Info,
|
||||
// Warn, Error, Fatal or Panic is called on it. These objects can be reused and
|
||||
// passed around as much as you wish to avoid field duplication.
|
||||
type Entry struct {
|
||||
Logger *Logger
|
||||
|
||||
// Contains all the fields set by the user.
|
||||
Data Fields
|
||||
|
||||
// Time at which the log entry was created
|
||||
Time time.Time
|
||||
|
||||
// Level the log entry was logged at: Debug, Info, Warn, Error, Fatal or Panic
|
||||
Level Level
|
||||
|
||||
// Message passed to Debug, Info, Warn, Error, Fatal or Panic
|
||||
Message string
|
||||
|
||||
// When formatter is called in entry.log(), an Buffer may be set to entry
|
||||
Buffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func NewEntry(logger *Logger) *Entry {
|
||||
return &Entry{
|
||||
Logger: logger,
|
||||
// Default is three fields, give a little extra room
|
||||
Data: make(Fields, 5),
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the string representation from the reader and ultimately the
|
||||
// formatter.
|
||||
func (entry *Entry) String() (string, error) {
|
||||
serialized, err := entry.Logger.Formatter.Format(entry)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
str := string(serialized)
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// Add an error as single field (using the key defined in ErrorKey) to the Entry.
|
||||
func (entry *Entry) WithError(err error) *Entry {
|
||||
return entry.WithField(ErrorKey, err)
|
||||
}
|
||||
|
||||
// Add a single field to the Entry.
|
||||
func (entry *Entry) WithField(key string, value interface{}) *Entry {
|
||||
return entry.WithFields(Fields{key: value})
|
||||
}
|
||||
|
||||
// Add a map of fields to the Entry.
|
||||
func (entry *Entry) WithFields(fields Fields) *Entry {
|
||||
data := make(Fields, len(entry.Data)+len(fields))
|
||||
for k, v := range entry.Data {
|
||||
data[k] = v
|
||||
}
|
||||
for k, v := range fields {
|
||||
data[k] = v
|
||||
}
|
||||
return &Entry{Logger: entry.Logger, Data: data}
|
||||
}
|
||||
|
||||
// This function is not declared with a pointer value because otherwise
|
||||
// race conditions will occur when using multiple goroutines
|
||||
func (entry Entry) log(level Level, msg string) {
|
||||
var buffer *bytes.Buffer
|
||||
entry.Time = time.Now()
|
||||
entry.Level = level
|
||||
entry.Message = msg
|
||||
|
||||
if err := entry.Logger.Hooks.Fire(level, &entry); err != nil {
|
||||
entry.Logger.mu.Lock()
|
||||
fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err)
|
||||
entry.Logger.mu.Unlock()
|
||||
}
|
||||
buffer = bufferPool.Get().(*bytes.Buffer)
|
||||
buffer.Reset()
|
||||
defer bufferPool.Put(buffer)
|
||||
entry.Buffer = buffer
|
||||
serialized, err := entry.Logger.Formatter.Format(&entry)
|
||||
entry.Buffer = nil
|
||||
if err != nil {
|
||||
entry.Logger.mu.Lock()
|
||||
fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err)
|
||||
entry.Logger.mu.Unlock()
|
||||
} else {
|
||||
entry.Logger.mu.Lock()
|
||||
_, err = entry.Logger.Out.Write(serialized)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
|
||||
}
|
||||
entry.Logger.mu.Unlock()
|
||||
}
|
||||
|
||||
// To avoid Entry#log() returning a value that only would make sense for
|
||||
// panic() to use in Entry#Panic(), we avoid the allocation by checking
|
||||
// directly here.
|
||||
if level <= PanicLevel {
|
||||
panic(&entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Debug(args ...interface{}) {
|
||||
if entry.Logger.Level >= DebugLevel {
|
||||
entry.log(DebugLevel, fmt.Sprint(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Print(args ...interface{}) {
|
||||
entry.Info(args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Info(args ...interface{}) {
|
||||
if entry.Logger.Level >= InfoLevel {
|
||||
entry.log(InfoLevel, fmt.Sprint(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Warn(args ...interface{}) {
|
||||
if entry.Logger.Level >= WarnLevel {
|
||||
entry.log(WarnLevel, fmt.Sprint(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Warning(args ...interface{}) {
|
||||
entry.Warn(args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Error(args ...interface{}) {
|
||||
if entry.Logger.Level >= ErrorLevel {
|
||||
entry.log(ErrorLevel, fmt.Sprint(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Fatal(args ...interface{}) {
|
||||
if entry.Logger.Level >= FatalLevel {
|
||||
entry.log(FatalLevel, fmt.Sprint(args...))
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (entry *Entry) Panic(args ...interface{}) {
|
||||
if entry.Logger.Level >= PanicLevel {
|
||||
entry.log(PanicLevel, fmt.Sprint(args...))
|
||||
}
|
||||
panic(fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
// Entry Printf family functions
|
||||
|
||||
func (entry *Entry) Debugf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= DebugLevel {
|
||||
entry.Debug(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Infof(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= InfoLevel {
|
||||
entry.Info(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Printf(format string, args ...interface{}) {
|
||||
entry.Infof(format, args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Warnf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= WarnLevel {
|
||||
entry.Warn(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Warningf(format string, args ...interface{}) {
|
||||
entry.Warnf(format, args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Errorf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= ErrorLevel {
|
||||
entry.Error(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Fatalf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= FatalLevel {
|
||||
entry.Fatal(fmt.Sprintf(format, args...))
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (entry *Entry) Panicf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= PanicLevel {
|
||||
entry.Panic(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
// Entry Println family functions
|
||||
|
||||
func (entry *Entry) Debugln(args ...interface{}) {
|
||||
if entry.Logger.Level >= DebugLevel {
|
||||
entry.Debug(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Infoln(args ...interface{}) {
|
||||
if entry.Logger.Level >= InfoLevel {
|
||||
entry.Info(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Println(args ...interface{}) {
|
||||
entry.Infoln(args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Warnln(args ...interface{}) {
|
||||
if entry.Logger.Level >= WarnLevel {
|
||||
entry.Warn(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Warningln(args ...interface{}) {
|
||||
entry.Warnln(args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Errorln(args ...interface{}) {
|
||||
if entry.Logger.Level >= ErrorLevel {
|
||||
entry.Error(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Fatalln(args ...interface{}) {
|
||||
if entry.Logger.Level >= FatalLevel {
|
||||
entry.Fatal(entry.sprintlnn(args...))
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (entry *Entry) Panicln(args ...interface{}) {
|
||||
if entry.Logger.Level >= PanicLevel {
|
||||
entry.Panic(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
// Sprintlnn => Sprint no newline. This is to get the behavior of how
|
||||
// fmt.Sprintln where spaces are always added between operands, regardless of
|
||||
// their type. Instead of vendoring the Sprintln implementation to spare a
|
||||
// string allocation, we do the simplest thing.
|
||||
func (entry *Entry) sprintlnn(args ...interface{}) string {
|
||||
msg := fmt.Sprintln(args...)
|
||||
return msg[:len(msg)-1]
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
// std is the name of the standard logger in stdlib `log`
|
||||
std = New()
|
||||
)
|
||||
|
||||
func StandardLogger() *Logger {
|
||||
return std
|
||||
}
|
||||
|
||||
// SetOutput sets the standard logger output.
|
||||
func SetOutput(out io.Writer) {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
std.Out = out
|
||||
}
|
||||
|
||||
// SetFormatter sets the standard logger formatter.
|
||||
func SetFormatter(formatter Formatter) {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
std.Formatter = formatter
|
||||
}
|
||||
|
||||
// SetLevel sets the standard logger level.
|
||||
func SetLevel(level Level) {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
std.Level = level
|
||||
}
|
||||
|
||||
// GetLevel returns the standard logger level.
|
||||
func GetLevel() Level {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
return std.Level
|
||||
}
|
||||
|
||||
// AddHook adds a hook to the standard logger hooks.
|
||||
func AddHook(hook Hook) {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
std.Hooks.Add(hook)
|
||||
}
|
||||
|
||||
// WithError creates an entry from the standard logger and adds an error to it, using the value defined in ErrorKey as key.
|
||||
func WithError(err error) *Entry {
|
||||
return std.WithField(ErrorKey, err)
|
||||
}
|
||||
|
||||
// WithField creates an entry from the standard logger and adds a field to
|
||||
// it. If you want multiple fields, use `WithFields`.
|
||||
//
|
||||
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
|
||||
// or Panic on the Entry it returns.
|
||||
func WithField(key string, value interface{}) *Entry {
|
||||
return std.WithField(key, value)
|
||||
}
|
||||
|
||||
// WithFields creates an entry from the standard logger and adds multiple
|
||||
// fields to it. This is simply a helper for `WithField`, invoking it
|
||||
// once for each field.
|
||||
//
|
||||
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
|
||||
// or Panic on the Entry it returns.
|
||||
func WithFields(fields Fields) *Entry {
|
||||
return std.WithFields(fields)
|
||||
}
|
||||
|
||||
// Debug logs a message at level Debug on the standard logger.
|
||||
func Debug(args ...interface{}) {
|
||||
std.Debug(args...)
|
||||
}
|
||||
|
||||
// Print logs a message at level Info on the standard logger.
|
||||
func Print(args ...interface{}) {
|
||||
std.Print(args...)
|
||||
}
|
||||
|
||||
// Info logs a message at level Info on the standard logger.
|
||||
func Info(args ...interface{}) {
|
||||
std.Info(args...)
|
||||
}
|
||||
|
||||
// Warn logs a message at level Warn on the standard logger.
|
||||
func Warn(args ...interface{}) {
|
||||
std.Warn(args...)
|
||||
}
|
||||
|
||||
// Warning logs a message at level Warn on the standard logger.
|
||||
func Warning(args ...interface{}) {
|
||||
std.Warning(args...)
|
||||
}
|
||||
|
||||
// Error logs a message at level Error on the standard logger.
|
||||
func Error(args ...interface{}) {
|
||||
std.Error(args...)
|
||||
}
|
||||
|
||||
// Panic logs a message at level Panic on the standard logger.
|
||||
func Panic(args ...interface{}) {
|
||||
std.Panic(args...)
|
||||
}
|
||||
|
||||
// Fatal logs a message at level Fatal on the standard logger.
|
||||
func Fatal(args ...interface{}) {
|
||||
std.Fatal(args...)
|
||||
}
|
||||
|
||||
// Debugf logs a message at level Debug on the standard logger.
|
||||
func Debugf(format string, args ...interface{}) {
|
||||
std.Debugf(format, args...)
|
||||
}
|
||||
|
||||
// Printf logs a message at level Info on the standard logger.
|
||||
func Printf(format string, args ...interface{}) {
|
||||
std.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs a message at level Info on the standard logger.
|
||||
func Infof(format string, args ...interface{}) {
|
||||
std.Infof(format, args...)
|
||||
}
|
||||
|
||||
// Warnf logs a message at level Warn on the standard logger.
|
||||
func Warnf(format string, args ...interface{}) {
|
||||
std.Warnf(format, args...)
|
||||
}
|
||||
|
||||
// Warningf logs a message at level Warn on the standard logger.
|
||||
func Warningf(format string, args ...interface{}) {
|
||||
std.Warningf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs a message at level Error on the standard logger.
|
||||
func Errorf(format string, args ...interface{}) {
|
||||
std.Errorf(format, args...)
|
||||
}
|
||||
|
||||
// Panicf logs a message at level Panic on the standard logger.
|
||||
func Panicf(format string, args ...interface{}) {
|
||||
std.Panicf(format, args...)
|
||||
}
|
||||
|
||||
// Fatalf logs a message at level Fatal on the standard logger.
|
||||
func Fatalf(format string, args ...interface{}) {
|
||||
std.Fatalf(format, args...)
|
||||
}
|
||||
|
||||
// Debugln logs a message at level Debug on the standard logger.
|
||||
func Debugln(args ...interface{}) {
|
||||
std.Debugln(args...)
|
||||
}
|
||||
|
||||
// Println logs a message at level Info on the standard logger.
|
||||
func Println(args ...interface{}) {
|
||||
std.Println(args...)
|
||||
}
|
||||
|
||||
// Infoln logs a message at level Info on the standard logger.
|
||||
func Infoln(args ...interface{}) {
|
||||
std.Infoln(args...)
|
||||
}
|
||||
|
||||
// Warnln logs a message at level Warn on the standard logger.
|
||||
func Warnln(args ...interface{}) {
|
||||
std.Warnln(args...)
|
||||
}
|
||||
|
||||
// Warningln logs a message at level Warn on the standard logger.
|
||||
func Warningln(args ...interface{}) {
|
||||
std.Warningln(args...)
|
||||
}
|
||||
|
||||
// Errorln logs a message at level Error on the standard logger.
|
||||
func Errorln(args ...interface{}) {
|
||||
std.Errorln(args...)
|
||||
}
|
||||
|
||||
// Panicln logs a message at level Panic on the standard logger.
|
||||
func Panicln(args ...interface{}) {
|
||||
std.Panicln(args...)
|
||||
}
|
||||
|
||||
// Fatalln logs a message at level Fatal on the standard logger.
|
||||
func Fatalln(args ...interface{}) {
|
||||
std.Fatalln(args...)
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package logrus
|
||||
|
||||
import "time"
|
||||
|
||||
const DefaultTimestampFormat = time.RFC3339
|
||||
|
||||
// The Formatter interface is used to implement a custom Formatter. It takes an
|
||||
// `Entry`. It exposes all the fields, including the default ones:
|
||||
//
|
||||
// * `entry.Data["msg"]`. The message passed from Info, Warn, Error ..
|
||||
// * `entry.Data["time"]`. The timestamp.
|
||||
// * `entry.Data["level"]. The level the entry was logged at.
|
||||
//
|
||||
// Any additional fields added with `WithField` or `WithFields` are also in
|
||||
// `entry.Data`. Format is expected to return an array of bytes which are then
|
||||
// logged to `logger.Out`.
|
||||
type Formatter interface {
|
||||
Format(*Entry) ([]byte, error)
|
||||
}
|
||||
|
||||
// This is to not silently overwrite `time`, `msg` and `level` fields when
|
||||
// dumping it. If this code wasn't there doing:
|
||||
//
|
||||
// logrus.WithField("level", 1).Info("hello")
|
||||
//
|
||||
// Would just silently drop the user provided level. Instead with this code
|
||||
// it'll logged as:
|
||||
//
|
||||
// {"level": "info", "fields.level": 1, "msg": "hello", "time": "..."}
|
||||
//
|
||||
// It's not exported because it's still using Data in an opinionated way. It's to
|
||||
// avoid code duplication between the two default formatters.
|
||||
func prefixFieldClashes(data Fields) {
|
||||
if t, ok := data["time"]; ok {
|
||||
data["fields.time"] = t
|
||||
}
|
||||
|
||||
if m, ok := data["msg"]; ok {
|
||||
data["fields.msg"] = m
|
||||
}
|
||||
|
||||
if l, ok := data["level"]; ok {
|
||||
data["fields.level"] = l
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package logrus
|
||||
|
||||
// A hook to be fired when logging on the logging levels returned from
|
||||
// `Levels()` on your implementation of the interface. Note that this is not
|
||||
// fired in a goroutine or a channel with workers, you should handle such
|
||||
// functionality yourself if your call is non-blocking and you don't wish for
|
||||
// the logging calls for levels returned from `Levels()` to block.
|
||||
type Hook interface {
|
||||
Levels() []Level
|
||||
Fire(*Entry) error
|
||||
}
|
||||
|
||||
// Internal type for storing the hooks on a logger instance.
|
||||
type LevelHooks map[Level][]Hook
|
||||
|
||||
// Add a hook to an instance of logger. This is called with
|
||||
// `log.Hooks.Add(new(MyHook))` where `MyHook` implements the `Hook` interface.
|
||||
func (hooks LevelHooks) Add(hook Hook) {
|
||||
for _, level := range hook.Levels() {
|
||||
hooks[level] = append(hooks[level], hook)
|
||||
}
|
||||
}
|
||||
|
||||
// Fire all the hooks for the passed level. Used by `entry.log` to fire
|
||||
// appropriate hooks for a log entry.
|
||||
func (hooks LevelHooks) Fire(level Level, entry *Entry) error {
|
||||
for _, hook := range hooks[level] {
|
||||
if err := hook.Fire(entry); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type fieldKey string
|
||||
type FieldMap map[fieldKey]string
|
||||
|
||||
const (
|
||||
FieldKeyMsg = "msg"
|
||||
FieldKeyLevel = "level"
|
||||
FieldKeyTime = "time"
|
||||
)
|
||||
|
||||
func (f FieldMap) resolve(key fieldKey) string {
|
||||
if k, ok := f[key]; ok {
|
||||
return k
|
||||
}
|
||||
|
||||
return string(key)
|
||||
}
|
||||
|
||||
type JSONFormatter struct {
|
||||
// TimestampFormat sets the format used for marshaling timestamps.
|
||||
TimestampFormat string
|
||||
|
||||
// DisableTimestamp allows disabling automatic timestamps in output
|
||||
DisableTimestamp bool
|
||||
|
||||
// FieldMap allows users to customize the names of keys for various fields.
|
||||
// As an example:
|
||||
// formatter := &JSONFormatter{
|
||||
// FieldMap: FieldMap{
|
||||
// FieldKeyTime: "@timestamp",
|
||||
// FieldKeyLevel: "@level",
|
||||
// FieldKeyLevel: "@message",
|
||||
// },
|
||||
// }
|
||||
FieldMap FieldMap
|
||||
}
|
||||
|
||||
func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) {
|
||||
data := make(Fields, len(entry.Data)+3)
|
||||
for k, v := range entry.Data {
|
||||
switch v := v.(type) {
|
||||
case error:
|
||||
// Otherwise errors are ignored by `encoding/json`
|
||||
// https://github.com/Sirupsen/logrus/issues/137
|
||||
data[k] = v.Error()
|
||||
default:
|
||||
data[k] = v
|
||||
}
|
||||
}
|
||||
prefixFieldClashes(data)
|
||||
|
||||
timestampFormat := f.TimestampFormat
|
||||
if timestampFormat == "" {
|
||||
timestampFormat = DefaultTimestampFormat
|
||||
}
|
||||
|
||||
if !f.DisableTimestamp {
|
||||
data[f.FieldMap.resolve(FieldKeyTime)] = entry.Time.Format(timestampFormat)
|
||||
}
|
||||
data[f.FieldMap.resolve(FieldKeyMsg)] = entry.Message
|
||||
data[f.FieldMap.resolve(FieldKeyLevel)] = entry.Level.String()
|
||||
|
||||
serialized, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err)
|
||||
}
|
||||
return append(serialized, '\n'), nil
|
||||
}
|
|
@ -0,0 +1,308 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
// The logs are `io.Copy`'d to this in a mutex. It's common to set this to a
|
||||
// file, or leave it default which is `os.Stderr`. You can also set this to
|
||||
// something more adventorous, such as logging to Kafka.
|
||||
Out io.Writer
|
||||
// Hooks for the logger instance. These allow firing events based on logging
|
||||
// levels and log entries. For example, to send errors to an error tracking
|
||||
// service, log to StatsD or dump the core on fatal errors.
|
||||
Hooks LevelHooks
|
||||
// All log entries pass through the formatter before logged to Out. The
|
||||
// included formatters are `TextFormatter` and `JSONFormatter` for which
|
||||
// TextFormatter is the default. In development (when a TTY is attached) it
|
||||
// logs with colors, but to a file it wouldn't. You can easily implement your
|
||||
// own that implements the `Formatter` interface, see the `README` or included
|
||||
// formatters for examples.
|
||||
Formatter Formatter
|
||||
// The logging level the logger should log at. This is typically (and defaults
|
||||
// to) `logrus.Info`, which allows Info(), Warn(), Error() and Fatal() to be
|
||||
// logged. `logrus.Debug` is useful in
|
||||
Level Level
|
||||
// Used to sync writing to the log. Locking is enabled by Default
|
||||
mu MutexWrap
|
||||
// Reusable empty entry
|
||||
entryPool sync.Pool
|
||||
}
|
||||
|
||||
type MutexWrap struct {
|
||||
lock sync.Mutex
|
||||
disabled bool
|
||||
}
|
||||
|
||||
func (mw *MutexWrap) Lock() {
|
||||
if !mw.disabled {
|
||||
mw.lock.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
func (mw *MutexWrap) Unlock() {
|
||||
if !mw.disabled {
|
||||
mw.lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (mw *MutexWrap) Disable() {
|
||||
mw.disabled = true
|
||||
}
|
||||
|
||||
// Creates a new logger. Configuration should be set by changing `Formatter`,
|
||||
// `Out` and `Hooks` directly on the default logger instance. You can also just
|
||||
// instantiate your own:
|
||||
//
|
||||
// var log = &Logger{
|
||||
// Out: os.Stderr,
|
||||
// Formatter: new(JSONFormatter),
|
||||
// Hooks: make(LevelHooks),
|
||||
// Level: logrus.DebugLevel,
|
||||
// }
|
||||
//
|
||||
// It's recommended to make this a global instance called `log`.
|
||||
func New() *Logger {
|
||||
return &Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: new(TextFormatter),
|
||||
Hooks: make(LevelHooks),
|
||||
Level: InfoLevel,
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) newEntry() *Entry {
|
||||
entry, ok := logger.entryPool.Get().(*Entry)
|
||||
if ok {
|
||||
return entry
|
||||
}
|
||||
return NewEntry(logger)
|
||||
}
|
||||
|
||||
func (logger *Logger) releaseEntry(entry *Entry) {
|
||||
logger.entryPool.Put(entry)
|
||||
}
|
||||
|
||||
// Adds a field to the log entry, note that it doesn't log until you call
|
||||
// Debug, Print, Info, Warn, Fatal or Panic. It only creates a log entry.
|
||||
// If you want multiple fields, use `WithFields`.
|
||||
func (logger *Logger) WithField(key string, value interface{}) *Entry {
|
||||
entry := logger.newEntry()
|
||||
defer logger.releaseEntry(entry)
|
||||
return entry.WithField(key, value)
|
||||
}
|
||||
|
||||
// Adds a struct of fields to the log entry. All it does is call `WithField` for
|
||||
// each `Field`.
|
||||
func (logger *Logger) WithFields(fields Fields) *Entry {
|
||||
entry := logger.newEntry()
|
||||
defer logger.releaseEntry(entry)
|
||||
return entry.WithFields(fields)
|
||||
}
|
||||
|
||||
// Add an error as single field to the log entry. All it does is call
|
||||
// `WithError` for the given `error`.
|
||||
func (logger *Logger) WithError(err error) *Entry {
|
||||
entry := logger.newEntry()
|
||||
defer logger.releaseEntry(entry)
|
||||
return entry.WithError(err)
|
||||
}
|
||||
|
||||
func (logger *Logger) Debugf(format string, args ...interface{}) {
|
||||
if logger.Level >= DebugLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Debugf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Infof(format string, args ...interface{}) {
|
||||
if logger.Level >= InfoLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Infof(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Printf(format string, args ...interface{}) {
|
||||
entry := logger.newEntry()
|
||||
entry.Printf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
|
||||
func (logger *Logger) Warnf(format string, args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warnf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Warningf(format string, args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warnf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Errorf(format string, args ...interface{}) {
|
||||
if logger.Level >= ErrorLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Errorf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Fatalf(format string, args ...interface{}) {
|
||||
if logger.Level >= FatalLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Fatalf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (logger *Logger) Panicf(format string, args ...interface{}) {
|
||||
if logger.Level >= PanicLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Panicf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Debug(args ...interface{}) {
|
||||
if logger.Level >= DebugLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Debug(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Info(args ...interface{}) {
|
||||
if logger.Level >= InfoLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Info(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Print(args ...interface{}) {
|
||||
entry := logger.newEntry()
|
||||
entry.Info(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
|
||||
func (logger *Logger) Warn(args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warn(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Warning(args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warn(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Error(args ...interface{}) {
|
||||
if logger.Level >= ErrorLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Error(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Fatal(args ...interface{}) {
|
||||
if logger.Level >= FatalLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Fatal(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (logger *Logger) Panic(args ...interface{}) {
|
||||
if logger.Level >= PanicLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Panic(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Debugln(args ...interface{}) {
|
||||
if logger.Level >= DebugLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Debugln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Infoln(args ...interface{}) {
|
||||
if logger.Level >= InfoLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Infoln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Println(args ...interface{}) {
|
||||
entry := logger.newEntry()
|
||||
entry.Println(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
|
||||
func (logger *Logger) Warnln(args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warnln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Warningln(args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warnln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Errorln(args ...interface{}) {
|
||||
if logger.Level >= ErrorLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Errorln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Fatalln(args ...interface{}) {
|
||||
if logger.Level >= FatalLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Fatalln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (logger *Logger) Panicln(args ...interface{}) {
|
||||
if logger.Level >= PanicLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Panicln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
//When file is opened with appending mode, it's safe to
|
||||
//write concurrently to a file (within 4k message on Linux).
|
||||
//In these cases user can choose to disable the lock.
|
||||
func (logger *Logger) SetNoLock() {
|
||||
logger.mu.Disable()
|
||||
}
|
|
@ -0,0 +1,143 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Fields type, used to pass to `WithFields`.
|
||||
type Fields map[string]interface{}
|
||||
|
||||
// Level type
|
||||
type Level uint8
|
||||
|
||||
// Convert the Level to a string. E.g. PanicLevel becomes "panic".
|
||||
func (level Level) String() string {
|
||||
switch level {
|
||||
case DebugLevel:
|
||||
return "debug"
|
||||
case InfoLevel:
|
||||
return "info"
|
||||
case WarnLevel:
|
||||
return "warning"
|
||||
case ErrorLevel:
|
||||
return "error"
|
||||
case FatalLevel:
|
||||
return "fatal"
|
||||
case PanicLevel:
|
||||
return "panic"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// ParseLevel takes a string level and returns the Logrus log level constant.
|
||||
func ParseLevel(lvl string) (Level, error) {
|
||||
switch strings.ToLower(lvl) {
|
||||
case "panic":
|
||||
return PanicLevel, nil
|
||||
case "fatal":
|
||||
return FatalLevel, nil
|
||||
case "error":
|
||||
return ErrorLevel, nil
|
||||
case "warn", "warning":
|
||||
return WarnLevel, nil
|
||||
case "info":
|
||||
return InfoLevel, nil
|
||||
case "debug":
|
||||
return DebugLevel, nil
|
||||
}
|
||||
|
||||
var l Level
|
||||
return l, fmt.Errorf("not a valid logrus Level: %q", lvl)
|
||||
}
|
||||
|
||||
// A constant exposing all logging levels
|
||||
var AllLevels = []Level{
|
||||
PanicLevel,
|
||||
FatalLevel,
|
||||
ErrorLevel,
|
||||
WarnLevel,
|
||||
InfoLevel,
|
||||
DebugLevel,
|
||||
}
|
||||
|
||||
// These are the different logging levels. You can set the logging level to log
|
||||
// on your instance of logger, obtained with `logrus.New()`.
|
||||
const (
|
||||
// PanicLevel level, highest level of severity. Logs and then calls panic with the
|
||||
// message passed to Debug, Info, ...
|
||||
PanicLevel Level = iota
|
||||
// FatalLevel level. Logs and then calls `os.Exit(1)`. It will exit even if the
|
||||
// logging level is set to Panic.
|
||||
FatalLevel
|
||||
// ErrorLevel level. Logs. Used for errors that should definitely be noted.
|
||||
// Commonly used for hooks to send errors to an error tracking service.
|
||||
ErrorLevel
|
||||
// WarnLevel level. Non-critical entries that deserve eyes.
|
||||
WarnLevel
|
||||
// InfoLevel level. General operational entries about what's going on inside the
|
||||
// application.
|
||||
InfoLevel
|
||||
// DebugLevel level. Usually only enabled when debugging. Very verbose logging.
|
||||
DebugLevel
|
||||
)
|
||||
|
||||
// Won't compile if StdLogger can't be realized by a log.Logger
|
||||
var (
|
||||
_ StdLogger = &log.Logger{}
|
||||
_ StdLogger = &Entry{}
|
||||
_ StdLogger = &Logger{}
|
||||
)
|
||||
|
||||
// StdLogger is what your logrus-enabled library should take, that way
|
||||
// it'll accept a stdlib logger and a logrus logger. There's no standard
|
||||
// interface, this is the closest we get, unfortunately.
|
||||
type StdLogger interface {
|
||||
Print(...interface{})
|
||||
Printf(string, ...interface{})
|
||||
Println(...interface{})
|
||||
|
||||
Fatal(...interface{})
|
||||
Fatalf(string, ...interface{})
|
||||
Fatalln(...interface{})
|
||||
|
||||
Panic(...interface{})
|
||||
Panicf(string, ...interface{})
|
||||
Panicln(...interface{})
|
||||
}
|
||||
|
||||
// The FieldLogger interface generalizes the Entry and Logger types
|
||||
type FieldLogger interface {
|
||||
WithField(key string, value interface{}) *Entry
|
||||
WithFields(fields Fields) *Entry
|
||||
WithError(err error) *Entry
|
||||
|
||||
Debugf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Printf(format string, args ...interface{})
|
||||
Warnf(format string, args ...interface{})
|
||||
Warningf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Fatalf(format string, args ...interface{})
|
||||
Panicf(format string, args ...interface{})
|
||||
|
||||
Debug(args ...interface{})
|
||||
Info(args ...interface{})
|
||||
Print(args ...interface{})
|
||||
Warn(args ...interface{})
|
||||
Warning(args ...interface{})
|
||||
Error(args ...interface{})
|
||||
Fatal(args ...interface{})
|
||||
Panic(args ...interface{})
|
||||
|
||||
Debugln(args ...interface{})
|
||||
Infoln(args ...interface{})
|
||||
Println(args ...interface{})
|
||||
Warnln(args ...interface{})
|
||||
Warningln(args ...interface{})
|
||||
Errorln(args ...interface{})
|
||||
Fatalln(args ...interface{})
|
||||
Panicln(args ...interface{})
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
// +build appengine
|
||||
|
||||
package logrus
|
||||
|
||||
// IsTerminal returns true if stderr's file descriptor is a terminal.
|
||||
func IsTerminal() bool {
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
// +build darwin freebsd openbsd netbsd dragonfly
|
||||
// +build !appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import "syscall"
|
||||
|
||||
const ioctlReadTermios = syscall.TIOCGETA
|
||||
|
||||
type Termios syscall.Termios
|
|
@ -0,0 +1,14 @@
|
|||
// Based on ssh/terminal:
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import "syscall"
|
||||
|
||||
const ioctlReadTermios = syscall.TCGETS
|
||||
|
||||
type Termios syscall.Termios
|
|
@ -0,0 +1,22 @@
|
|||
// Based on ssh/terminal:
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build linux darwin freebsd openbsd netbsd dragonfly
|
||||
// +build !appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// IsTerminal returns true if stderr's file descriptor is a terminal.
|
||||
func IsTerminal() bool {
|
||||
fd := syscall.Stderr
|
||||
var termios Termios
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0)
|
||||
return err == 0
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// +build solaris,!appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||
func IsTerminal() bool {
|
||||
_, err := unix.IoctlGetTermios(int(os.Stdout.Fd()), unix.TCGETA)
|
||||
return err == nil
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
// Based on ssh/terminal:
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows,!appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
|
||||
var (
|
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
||||
)
|
||||
|
||||
// IsTerminal returns true if stderr's file descriptor is a terminal.
|
||||
func IsTerminal() bool {
|
||||
fd := syscall.Stderr
|
||||
var st uint32
|
||||
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
|
||||
return r != 0 && e == 0
|
||||
}
|
|
@ -0,0 +1,170 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
nocolor = 0
|
||||
red = 31
|
||||
green = 32
|
||||
yellow = 33
|
||||
blue = 34
|
||||
gray = 37
|
||||
)
|
||||
|
||||
var (
|
||||
baseTimestamp time.Time
|
||||
isTerminal bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
baseTimestamp = time.Now()
|
||||
isTerminal = IsTerminal()
|
||||
}
|
||||
|
||||
func miniTS() int {
|
||||
return int(time.Since(baseTimestamp) / time.Second)
|
||||
}
|
||||
|
||||
type TextFormatter struct {
|
||||
// Set to true to bypass checking for a TTY before outputting colors.
|
||||
ForceColors bool
|
||||
|
||||
// Force disabling colors.
|
||||
DisableColors bool
|
||||
|
||||
// Disable timestamp logging. useful when output is redirected to logging
|
||||
// system that already adds timestamps.
|
||||
DisableTimestamp bool
|
||||
|
||||
// Enable logging the full timestamp when a TTY is attached instead of just
|
||||
// the time passed since beginning of execution.
|
||||
FullTimestamp bool
|
||||
|
||||
// TimestampFormat to use for display when a full timestamp is printed
|
||||
TimestampFormat string
|
||||
|
||||
// The fields are sorted by default for a consistent output. For applications
|
||||
// that log extremely frequently and don't use the JSON formatter this may not
|
||||
// be desired.
|
||||
DisableSorting bool
|
||||
}
|
||||
|
||||
func (f *TextFormatter) Format(entry *Entry) ([]byte, error) {
|
||||
var b *bytes.Buffer
|
||||
var keys []string = make([]string, 0, len(entry.Data))
|
||||
for k := range entry.Data {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
if !f.DisableSorting {
|
||||
sort.Strings(keys)
|
||||
}
|
||||
if entry.Buffer != nil {
|
||||
b = entry.Buffer
|
||||
} else {
|
||||
b = &bytes.Buffer{}
|
||||
}
|
||||
|
||||
prefixFieldClashes(entry.Data)
|
||||
|
||||
isColorTerminal := isTerminal && (runtime.GOOS != "windows")
|
||||
isColored := (f.ForceColors || isColorTerminal) && !f.DisableColors
|
||||
|
||||
timestampFormat := f.TimestampFormat
|
||||
if timestampFormat == "" {
|
||||
timestampFormat = DefaultTimestampFormat
|
||||
}
|
||||
if isColored {
|
||||
f.printColored(b, entry, keys, timestampFormat)
|
||||
} else {
|
||||
if !f.DisableTimestamp {
|
||||
f.appendKeyValue(b, "time", entry.Time.Format(timestampFormat))
|
||||
}
|
||||
f.appendKeyValue(b, "level", entry.Level.String())
|
||||
if entry.Message != "" {
|
||||
f.appendKeyValue(b, "msg", entry.Message)
|
||||
}
|
||||
for _, key := range keys {
|
||||
f.appendKeyValue(b, key, entry.Data[key])
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteByte('\n')
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []string, timestampFormat string) {
|
||||
var levelColor int
|
||||
switch entry.Level {
|
||||
case DebugLevel:
|
||||
levelColor = gray
|
||||
case WarnLevel:
|
||||
levelColor = yellow
|
||||
case ErrorLevel, FatalLevel, PanicLevel:
|
||||
levelColor = red
|
||||
default:
|
||||
levelColor = blue
|
||||
}
|
||||
|
||||
levelText := strings.ToUpper(entry.Level.String())[0:4]
|
||||
|
||||
if f.DisableTimestamp {
|
||||
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m %-44s ", levelColor, levelText, entry.Message)
|
||||
} else if !f.FullTimestamp {
|
||||
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%04d] %-44s ", levelColor, levelText, miniTS(), entry.Message)
|
||||
} else {
|
||||
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s] %-44s ", levelColor, levelText, entry.Time.Format(timestampFormat), entry.Message)
|
||||
}
|
||||
for _, k := range keys {
|
||||
v := entry.Data[k]
|
||||
fmt.Fprintf(b, " \x1b[%dm%s\x1b[0m=", levelColor, k)
|
||||
f.appendValue(b, v)
|
||||
}
|
||||
}
|
||||
|
||||
func needsQuoting(text string) bool {
|
||||
for _, ch := range text {
|
||||
if !((ch >= 'a' && ch <= 'z') ||
|
||||
(ch >= 'A' && ch <= 'Z') ||
|
||||
(ch >= '0' && ch <= '9') ||
|
||||
ch == '-' || ch == '.') {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *TextFormatter) appendKeyValue(b *bytes.Buffer, key string, value interface{}) {
|
||||
|
||||
b.WriteString(key)
|
||||
b.WriteByte('=')
|
||||
f.appendValue(b, value)
|
||||
b.WriteByte(' ')
|
||||
}
|
||||
|
||||
func (f *TextFormatter) appendValue(b *bytes.Buffer, value interface{}) {
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
if !needsQuoting(value) {
|
||||
b.WriteString(value)
|
||||
} else {
|
||||
fmt.Fprintf(b, "%q", value)
|
||||
}
|
||||
case error:
|
||||
errmsg := value.Error()
|
||||
if !needsQuoting(errmsg) {
|
||||
b.WriteString(errmsg)
|
||||
} else {
|
||||
fmt.Fprintf(b, "%q", errmsg)
|
||||
}
|
||||
default:
|
||||
fmt.Fprint(b, value)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func (logger *Logger) Writer() *io.PipeWriter {
|
||||
return logger.WriterLevel(InfoLevel)
|
||||
}
|
||||
|
||||
func (logger *Logger) WriterLevel(level Level) *io.PipeWriter {
|
||||
reader, writer := io.Pipe()
|
||||
|
||||
var printFunc func(args ...interface{})
|
||||
switch level {
|
||||
case DebugLevel:
|
||||
printFunc = logger.Debug
|
||||
case InfoLevel:
|
||||
printFunc = logger.Info
|
||||
case WarnLevel:
|
||||
printFunc = logger.Warn
|
||||
case ErrorLevel:
|
||||
printFunc = logger.Error
|
||||
case FatalLevel:
|
||||
printFunc = logger.Fatal
|
||||
case PanicLevel:
|
||||
printFunc = logger.Panic
|
||||
default:
|
||||
printFunc = logger.Print
|
||||
}
|
||||
|
||||
go logger.writerScanner(reader, printFunc)
|
||||
runtime.SetFinalizer(writer, writerFinalizer)
|
||||
|
||||
return writer
|
||||
}
|
||||
|
||||
func (logger *Logger) writerScanner(reader *io.PipeReader, printFunc func(args ...interface{})) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for scanner.Scan() {
|
||||
printFunc(scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.Errorf("Error while reading from Writer: %s", err)
|
||||
}
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
func writerFinalizer(writer *io.PipeWriter) {
|
||||
writer.Close()
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// A Domain represents a Version 2 domain
|
||||
type Domain byte
|
||||
|
||||
// Domain constants for DCE Security (Version 2) UUIDs.
|
||||
const (
|
||||
Person = Domain(0)
|
||||
Group = Domain(1)
|
||||
Org = Domain(2)
|
||||
)
|
||||
|
||||
// NewDCESecurity returns a DCE Security (Version 2) UUID.
|
||||
//
|
||||
// The domain should be one of Person, Group or Org.
|
||||
// On a POSIX system the id should be the users UID for the Person
|
||||
// domain and the users GID for the Group. The meaning of id for
|
||||
// the domain Org or on non-POSIX systems is site defined.
|
||||
//
|
||||
// For a given domain/id pair the same token may be returned for up to
|
||||
// 7 minutes and 10 seconds.
|
||||
func NewDCESecurity(domain Domain, id uint32) UUID {
|
||||
uuid := NewUUID()
|
||||
if uuid != nil {
|
||||
uuid[6] = (uuid[6] & 0x0f) | 0x20 // Version 2
|
||||
uuid[9] = byte(domain)
|
||||
binary.BigEndian.PutUint32(uuid[0:], id)
|
||||
}
|
||||
return uuid
|
||||
}
|
||||
|
||||
// NewDCEPerson returns a DCE Security (Version 2) UUID in the person
|
||||
// domain with the id returned by os.Getuid.
|
||||
//
|
||||
// NewDCEPerson(Person, uint32(os.Getuid()))
|
||||
func NewDCEPerson() UUID {
|
||||
return NewDCESecurity(Person, uint32(os.Getuid()))
|
||||
}
|
||||
|
||||
// NewDCEGroup returns a DCE Security (Version 2) UUID in the group
|
||||
// domain with the id returned by os.Getgid.
|
||||
//
|
||||
// NewDCEGroup(Group, uint32(os.Getgid()))
|
||||
func NewDCEGroup() UUID {
|
||||
return NewDCESecurity(Group, uint32(os.Getgid()))
|
||||
}
|
||||
|
||||
// Domain returns the domain for a Version 2 UUID or false.
|
||||
func (uuid UUID) Domain() (Domain, bool) {
|
||||
if v, _ := uuid.Version(); v != 2 {
|
||||
return 0, false
|
||||
}
|
||||
return Domain(uuid[9]), true
|
||||
}
|
||||
|
||||
// Id returns the id for a Version 2 UUID or false.
|
||||
func (uuid UUID) Id() (uint32, bool) {
|
||||
if v, _ := uuid.Version(); v != 2 {
|
||||
return 0, false
|
||||
}
|
||||
return binary.BigEndian.Uint32(uuid[0:4]), true
|
||||
}
|
||||
|
||||
func (d Domain) String() string {
|
||||
switch d {
|
||||
case Person:
|
||||
return "Person"
|
||||
case Group:
|
||||
return "Group"
|
||||
case Org:
|
||||
return "Org"
|
||||
}
|
||||
return fmt.Sprintf("Domain%d", int(d))
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// The uuid package generates and inspects UUIDs.
|
||||
//
|
||||
// UUIDs are based on RFC 4122 and DCE 1.1: Authentication and Security Services.
|
||||
package uuid
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"hash"
|
||||
)
|
||||
|
||||
// Well known Name Space IDs and UUIDs
|
||||
var (
|
||||
NameSpace_DNS = Parse("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
|
||||
NameSpace_URL = Parse("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
|
||||
NameSpace_OID = Parse("6ba7b812-9dad-11d1-80b4-00c04fd430c8")
|
||||
NameSpace_X500 = Parse("6ba7b814-9dad-11d1-80b4-00c04fd430c8")
|
||||
NIL = Parse("00000000-0000-0000-0000-000000000000")
|
||||
)
|
||||
|
||||
// NewHash returns a new UUID dervied from the hash of space concatenated with
|
||||
// data generated by h. The hash should be at least 16 byte in length. The
|
||||
// first 16 bytes of the hash are used to form the UUID. The version of the
|
||||
// UUID will be the lower 4 bits of version. NewHash is used to implement
|
||||
// NewMD5 and NewSHA1.
|
||||
func NewHash(h hash.Hash, space UUID, data []byte, version int) UUID {
|
||||
h.Reset()
|
||||
h.Write(space)
|
||||
h.Write([]byte(data))
|
||||
s := h.Sum(nil)
|
||||
uuid := make([]byte, 16)
|
||||
copy(uuid, s)
|
||||
uuid[6] = (uuid[6] & 0x0f) | uint8((version&0xf)<<4)
|
||||
uuid[8] = (uuid[8] & 0x3f) | 0x80 // RFC 4122 variant
|
||||
return uuid
|
||||
}
|
||||
|
||||
// NewMD5 returns a new MD5 (Version 3) UUID based on the
|
||||
// supplied name space and data.
|
||||
//
|
||||
// NewHash(md5.New(), space, data, 3)
|
||||
func NewMD5(space UUID, data []byte) UUID {
|
||||
return NewHash(md5.New(), space, data, 3)
|
||||
}
|
||||
|
||||
// NewSHA1 returns a new SHA1 (Version 5) UUID based on the
|
||||
// supplied name space and data.
|
||||
//
|
||||
// NewHash(sha1.New(), space, data, 5)
|
||||
func NewSHA1(space UUID, data []byte) UUID {
|
||||
return NewHash(sha1.New(), space, data, 5)
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import "net"
|
||||
|
||||
var (
|
||||
interfaces []net.Interface // cached list of interfaces
|
||||
ifname string // name of interface being used
|
||||
nodeID []byte // hardware for version 1 UUIDs
|
||||
)
|
||||
|
||||
// NodeInterface returns the name of the interface from which the NodeID was
|
||||
// derived. The interface "user" is returned if the NodeID was set by
|
||||
// SetNodeID.
|
||||
func NodeInterface() string {
|
||||
return ifname
|
||||
}
|
||||
|
||||
// SetNodeInterface selects the hardware address to be used for Version 1 UUIDs.
|
||||
// If name is "" then the first usable interface found will be used or a random
|
||||
// Node ID will be generated. If a named interface cannot be found then false
|
||||
// is returned.
|
||||
//
|
||||
// SetNodeInterface never fails when name is "".
|
||||
func SetNodeInterface(name string) bool {
|
||||
if interfaces == nil {
|
||||
var err error
|
||||
interfaces, err = net.Interfaces()
|
||||
if err != nil && name != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for _, ifs := range interfaces {
|
||||
if len(ifs.HardwareAddr) >= 6 && (name == "" || name == ifs.Name) {
|
||||
if setNodeID(ifs.HardwareAddr) {
|
||||
ifname = ifs.Name
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We found no interfaces with a valid hardware address. If name
|
||||
// does not specify a specific interface generate a random Node ID
|
||||
// (section 4.1.6)
|
||||
if name == "" {
|
||||
if nodeID == nil {
|
||||
nodeID = make([]byte, 6)
|
||||
}
|
||||
randomBits(nodeID)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NodeID returns a slice of a copy of the current Node ID, setting the Node ID
|
||||
// if not already set.
|
||||
func NodeID() []byte {
|
||||
if nodeID == nil {
|
||||
SetNodeInterface("")
|
||||
}
|
||||
nid := make([]byte, 6)
|
||||
copy(nid, nodeID)
|
||||
return nid
|
||||
}
|
||||
|
||||
// SetNodeID sets the Node ID to be used for Version 1 UUIDs. The first 6 bytes
|
||||
// of id are used. If id is less than 6 bytes then false is returned and the
|
||||
// Node ID is not set.
|
||||
func SetNodeID(id []byte) bool {
|
||||
if setNodeID(id) {
|
||||
ifname = "user"
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func setNodeID(id []byte) bool {
|
||||
if len(id) < 6 {
|
||||
return false
|
||||
}
|
||||
if nodeID == nil {
|
||||
nodeID = make([]byte, 6)
|
||||
}
|
||||
copy(nodeID, id)
|
||||
return true
|
||||
}
|
||||
|
||||
// NodeID returns the 6 byte node id encoded in uuid. It returns nil if uuid is
|
||||
// not valid. The NodeID is only well defined for version 1 and 2 UUIDs.
|
||||
func (uuid UUID) NodeID() []byte {
|
||||
if len(uuid) != 16 {
|
||||
return nil
|
||||
}
|
||||
node := make([]byte, 6)
|
||||
copy(node, uuid[10:])
|
||||
return node
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
// Copyright 2014 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Time represents a time as the number of 100's of nanoseconds since 15 Oct
|
||||
// 1582.
|
||||
type Time int64
|
||||
|
||||
const (
|
||||
lillian = 2299160 // Julian day of 15 Oct 1582
|
||||
unix = 2440587 // Julian day of 1 Jan 1970
|
||||
epoch = unix - lillian // Days between epochs
|
||||
g1582 = epoch * 86400 // seconds between epochs
|
||||
g1582ns100 = g1582 * 10000000 // 100s of a nanoseconds between epochs
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
lasttime uint64 // last time we returned
|
||||
clock_seq uint16 // clock sequence for this run
|
||||
|
||||
timeNow = time.Now // for testing
|
||||
)
|
||||
|
||||
// UnixTime converts t the number of seconds and nanoseconds using the Unix
|
||||
// epoch of 1 Jan 1970.
|
||||
func (t Time) UnixTime() (sec, nsec int64) {
|
||||
sec = int64(t - g1582ns100)
|
||||
nsec = (sec % 10000000) * 100
|
||||
sec /= 10000000
|
||||
return sec, nsec
|
||||
}
|
||||
|
||||
// GetTime returns the current Time (100s of nanoseconds since 15 Oct 1582) and
|
||||
// adjusts the clock sequence as needed. An error is returned if the current
|
||||
// time cannot be determined.
|
||||
func GetTime() (Time, error) {
|
||||
defer mu.Unlock()
|
||||
mu.Lock()
|
||||
return getTime()
|
||||
}
|
||||
|
||||
func getTime() (Time, error) {
|
||||
t := timeNow()
|
||||
|
||||
// If we don't have a clock sequence already, set one.
|
||||
if clock_seq == 0 {
|
||||
setClockSequence(-1)
|
||||
}
|
||||
now := uint64(t.UnixNano()/100) + g1582ns100
|
||||
|
||||
// If time has gone backwards with this clock sequence then we
|
||||
// increment the clock sequence
|
||||
if now <= lasttime {
|
||||
clock_seq = ((clock_seq + 1) & 0x3fff) | 0x8000
|
||||
}
|
||||
lasttime = now
|
||||
return Time(now), nil
|
||||
}
|
||||
|
||||
// ClockSequence returns the current clock sequence, generating one if not
|
||||
// already set. The clock sequence is only used for Version 1 UUIDs.
|
||||
//
|
||||
// The uuid package does not use global static storage for the clock sequence or
|
||||
// the last time a UUID was generated. Unless SetClockSequence a new random
|
||||
// clock sequence is generated the first time a clock sequence is requested by
|
||||
// ClockSequence, GetTime, or NewUUID. (section 4.2.1.1) sequence is generated
|
||||
// for
|
||||
func ClockSequence() int {
|
||||
defer mu.Unlock()
|
||||
mu.Lock()
|
||||
return clockSequence()
|
||||
}
|
||||
|
||||
func clockSequence() int {
|
||||
if clock_seq == 0 {
|
||||
setClockSequence(-1)
|
||||
}
|
||||
return int(clock_seq & 0x3fff)
|
||||
}
|
||||
|
||||
// SetClockSeq sets the clock sequence to the lower 14 bits of seq. Setting to
|
||||
// -1 causes a new sequence to be generated.
|
||||
func SetClockSequence(seq int) {
|
||||
defer mu.Unlock()
|
||||
mu.Lock()
|
||||
setClockSequence(seq)
|
||||
}
|
||||
|
||||
func setClockSequence(seq int) {
|
||||
if seq == -1 {
|
||||
var b [2]byte
|
||||
randomBits(b[:]) // clock sequence
|
||||
seq = int(b[0])<<8 | int(b[1])
|
||||
}
|
||||
old_seq := clock_seq
|
||||
clock_seq = uint16(seq&0x3fff) | 0x8000 // Set our variant
|
||||
if old_seq != clock_seq {
|
||||
lasttime = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Time returns the time in 100s of nanoseconds since 15 Oct 1582 encoded in
|
||||
// uuid. It returns false if uuid is not valid. The time is only well defined
|
||||
// for version 1 and 2 UUIDs.
|
||||
func (uuid UUID) Time() (Time, bool) {
|
||||
if len(uuid) != 16 {
|
||||
return 0, false
|
||||
}
|
||||
time := int64(binary.BigEndian.Uint32(uuid[0:4]))
|
||||
time |= int64(binary.BigEndian.Uint16(uuid[4:6])) << 32
|
||||
time |= int64(binary.BigEndian.Uint16(uuid[6:8])&0xfff) << 48
|
||||
return Time(time), true
|
||||
}
|
||||
|
||||
// ClockSequence returns the clock sequence encoded in uuid. It returns false
|
||||
// if uuid is not valid. The clock sequence is only well defined for version 1
|
||||
// and 2 UUIDs.
|
||||
func (uuid UUID) ClockSequence() (int, bool) {
|
||||
if len(uuid) != 16 {
|
||||
return 0, false
|
||||
}
|
||||
return int(binary.BigEndian.Uint16(uuid[8:10])) & 0x3fff, true
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// randomBits completely fills slice b with random data.
|
||||
func randomBits(b []byte) {
|
||||
if _, err := io.ReadFull(rander, b); err != nil {
|
||||
panic(err.Error()) // rand should never fail
|
||||
}
|
||||
}
|
||||
|
||||
// xvalues returns the value of a byte as a hexadecimal digit or 255.
|
||||
var xvalues = []byte{
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255,
|
||||
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
}
|
||||
|
||||
// xtob converts the the first two hex bytes of x into a byte.
|
||||
func xtob(x string) (byte, bool) {
|
||||
b1 := xvalues[x[0]]
|
||||
b2 := xvalues[x[1]]
|
||||
return (b1 << 4) | b2, b1 != 255 && b2 != 255
|
||||
}
|
|
@ -0,0 +1,163 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A UUID is a 128 bit (16 byte) Universal Unique IDentifier as defined in RFC
|
||||
// 4122.
|
||||
type UUID []byte
|
||||
|
||||
// A Version represents a UUIDs version.
|
||||
type Version byte
|
||||
|
||||
// A Variant represents a UUIDs variant.
|
||||
type Variant byte
|
||||
|
||||
// Constants returned by Variant.
|
||||
const (
|
||||
Invalid = Variant(iota) // Invalid UUID
|
||||
RFC4122 // The variant specified in RFC4122
|
||||
Reserved // Reserved, NCS backward compatibility.
|
||||
Microsoft // Reserved, Microsoft Corporation backward compatibility.
|
||||
Future // Reserved for future definition.
|
||||
)
|
||||
|
||||
var rander = rand.Reader // random function
|
||||
|
||||
// New returns a new random (version 4) UUID as a string. It is a convenience
|
||||
// function for NewRandom().String().
|
||||
func New() string {
|
||||
return NewRandom().String()
|
||||
}
|
||||
|
||||
// Parse decodes s into a UUID or returns nil. Both the UUID form of
|
||||
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and
|
||||
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded.
|
||||
func Parse(s string) UUID {
|
||||
if len(s) == 36+9 {
|
||||
if strings.ToLower(s[:9]) != "urn:uuid:" {
|
||||
return nil
|
||||
}
|
||||
s = s[9:]
|
||||
} else if len(s) != 36 {
|
||||
return nil
|
||||
}
|
||||
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
|
||||
return nil
|
||||
}
|
||||
uuid := make([]byte, 16)
|
||||
for i, x := range []int{
|
||||
0, 2, 4, 6,
|
||||
9, 11,
|
||||
14, 16,
|
||||
19, 21,
|
||||
24, 26, 28, 30, 32, 34} {
|
||||
if v, ok := xtob(s[x:]); !ok {
|
||||
return nil
|
||||
} else {
|
||||
uuid[i] = v
|
||||
}
|
||||
}
|
||||
return uuid
|
||||
}
|
||||
|
||||
// Equal returns true if uuid1 and uuid2 are equal.
|
||||
func Equal(uuid1, uuid2 UUID) bool {
|
||||
return bytes.Equal(uuid1, uuid2)
|
||||
}
|
||||
|
||||
// String returns the string form of uuid, xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
// , or "" if uuid is invalid.
|
||||
func (uuid UUID) String() string {
|
||||
if uuid == nil || len(uuid) != 16 {
|
||||
return ""
|
||||
}
|
||||
b := []byte(uuid)
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
|
||||
b[:4], b[4:6], b[6:8], b[8:10], b[10:])
|
||||
}
|
||||
|
||||
// URN returns the RFC 2141 URN form of uuid,
|
||||
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx, or "" if uuid is invalid.
|
||||
func (uuid UUID) URN() string {
|
||||
if uuid == nil || len(uuid) != 16 {
|
||||
return ""
|
||||
}
|
||||
b := []byte(uuid)
|
||||
return fmt.Sprintf("urn:uuid:%08x-%04x-%04x-%04x-%012x",
|
||||
b[:4], b[4:6], b[6:8], b[8:10], b[10:])
|
||||
}
|
||||
|
||||
// Variant returns the variant encoded in uuid. It returns Invalid if
|
||||
// uuid is invalid.
|
||||
func (uuid UUID) Variant() Variant {
|
||||
if len(uuid) != 16 {
|
||||
return Invalid
|
||||
}
|
||||
switch {
|
||||
case (uuid[8] & 0xc0) == 0x80:
|
||||
return RFC4122
|
||||
case (uuid[8] & 0xe0) == 0xc0:
|
||||
return Microsoft
|
||||
case (uuid[8] & 0xe0) == 0xe0:
|
||||
return Future
|
||||
default:
|
||||
return Reserved
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// Version returns the verison of uuid. It returns false if uuid is not
|
||||
// valid.
|
||||
func (uuid UUID) Version() (Version, bool) {
|
||||
if len(uuid) != 16 {
|
||||
return 0, false
|
||||
}
|
||||
return Version(uuid[6] >> 4), true
|
||||
}
|
||||
|
||||
func (v Version) String() string {
|
||||
if v > 15 {
|
||||
return fmt.Sprintf("BAD_VERSION_%d", v)
|
||||
}
|
||||
return fmt.Sprintf("VERSION_%d", v)
|
||||
}
|
||||
|
||||
func (v Variant) String() string {
|
||||
switch v {
|
||||
case RFC4122:
|
||||
return "RFC4122"
|
||||
case Reserved:
|
||||
return "Reserved"
|
||||
case Microsoft:
|
||||
return "Microsoft"
|
||||
case Future:
|
||||
return "Future"
|
||||
case Invalid:
|
||||
return "Invalid"
|
||||
}
|
||||
return fmt.Sprintf("BadVariant%d", int(v))
|
||||
}
|
||||
|
||||
// SetRand sets the random number generator to r, which implents io.Reader.
|
||||
// If r.Read returns an error when the package requests random data then
|
||||
// a panic will be issued.
|
||||
//
|
||||
// Calling SetRand with nil sets the random number generator to the default
|
||||
// generator.
|
||||
func SetRand(r io.Reader) {
|
||||
if r == nil {
|
||||
rander = rand.Reader
|
||||
return
|
||||
}
|
||||
rander = r
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// NewUUID returns a Version 1 UUID based on the current NodeID and clock
|
||||
// sequence, and the current time. If the NodeID has not been set by SetNodeID
|
||||
// or SetNodeInterface then it will be set automatically. If the NodeID cannot
|
||||
// be set NewUUID returns nil. If clock sequence has not been set by
|
||||
// SetClockSequence then it will be set automatically. If GetTime fails to
|
||||
// return the current NewUUID returns nil.
|
||||
func NewUUID() UUID {
|
||||
if nodeID == nil {
|
||||
SetNodeInterface("")
|
||||
}
|
||||
|
||||
now, err := GetTime()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
uuid := make([]byte, 16)
|
||||
|
||||
time_low := uint32(now & 0xffffffff)
|
||||
time_mid := uint16((now >> 32) & 0xffff)
|
||||
time_hi := uint16((now >> 48) & 0x0fff)
|
||||
time_hi |= 0x1000 // Version 1
|
||||
|
||||
binary.BigEndian.PutUint32(uuid[0:], time_low)
|
||||
binary.BigEndian.PutUint16(uuid[4:], time_mid)
|
||||
binary.BigEndian.PutUint16(uuid[6:], time_hi)
|
||||
binary.BigEndian.PutUint16(uuid[8:], clock_seq)
|
||||
copy(uuid[10:], nodeID)
|
||||
|
||||
return uuid
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
// Random returns a Random (Version 4) UUID or panics.
|
||||
//
|
||||
// The strength of the UUIDs is based on the strength of the crypto/rand
|
||||
// package.
|
||||
//
|
||||
// A note about uniqueness derived from from the UUID Wikipedia entry:
|
||||
//
|
||||
// Randomly generated UUIDs have 122 random bits. One's annual risk of being
|
||||
// hit by a meteorite is estimated to be one chance in 17 billion, that
|
||||
// means the probability is about 0.00000000006 (6 × 10−11),
|
||||
// equivalent to the odds of creating a few tens of trillions of UUIDs in a
|
||||
// year and having one duplicate.
|
||||
func NewRandom() UUID {
|
||||
uuid := make([]byte, 16)
|
||||
randomBits([]byte(uuid))
|
||||
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
|
||||
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
|
||||
return uuid
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
// cmd_authenticate.go - AUTHENTICATE/AUTHCHALLENGE commands.
|
||||
//
|
||||
// To the extent possible under law, Yawning Angel waived all copyright
|
||||
// and related or neighboring rights to bulb, using the creative
|
||||
// commons "cc0" public domain dedication. See LICENSE or
|
||||
// <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
|
||||
|
||||
package bulb
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Authenticate authenticates with the Tor instance using the "best" possible
|
||||
// authentication method. The password argument is entirely optional, and will
|
||||
// only be used if the "SAFECOOKE" and "NULL" authentication methods are not
|
||||
// available and "HASHEDPASSWORD" is.
|
||||
func (c *Conn) Authenticate(password string) error {
|
||||
if c.isAuthenticated {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Determine the supported authentication methods, and the cookie path.
|
||||
pi, err := c.ProtocolInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// "COOKIE" authentication exists, but anything modern supports
|
||||
// "SAFECOOKIE".
|
||||
const (
|
||||
cmdAuthenticate = "AUTHENTICATE"
|
||||
authMethodNull = "NULL"
|
||||
authMethodPassword = "HASHEDPASSWORD"
|
||||
authMethodSafeCookie = "SAFECOOKIE"
|
||||
)
|
||||
if pi.AuthMethods[authMethodNull] {
|
||||
_, err = c.Request(cmdAuthenticate)
|
||||
c.isAuthenticated = err == nil
|
||||
return err
|
||||
} else if pi.AuthMethods[authMethodSafeCookie] {
|
||||
const (
|
||||
authCookieLength = 32
|
||||
authNonceLength = 32
|
||||
authHashLength = 32
|
||||
|
||||
authServerHashKey = "Tor safe cookie authentication server-to-controller hash"
|
||||
authClientHashKey = "Tor safe cookie authentication controller-to-server hash"
|
||||
)
|
||||
|
||||
if pi.CookieFile == "" {
|
||||
return newProtocolError("invalid (empty) COOKIEFILE")
|
||||
}
|
||||
cookie, err := ioutil.ReadFile(pi.CookieFile)
|
||||
if err != nil {
|
||||
return newProtocolError("failed to read COOKIEFILE: %v", err)
|
||||
} else if len(cookie) != authCookieLength {
|
||||
return newProtocolError("invalid cookie file length: %d", len(cookie))
|
||||
}
|
||||
|
||||
// Send an AUTHCHALLENGE command, and parse the response.
|
||||
var clientNonce [authNonceLength]byte
|
||||
if _, err := rand.Read(clientNonce[:]); err != nil {
|
||||
return newProtocolError("failed to generate clientNonce: %v", err)
|
||||
}
|
||||
clientNonceStr := hex.EncodeToString(clientNonce[:])
|
||||
resp, err := c.Request("AUTHCHALLENGE %s %s", authMethodSafeCookie, clientNonceStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
splitResp := strings.Split(resp.Reply, " ")
|
||||
if len(splitResp) != 3 {
|
||||
return newProtocolError("invalid AUTHCHALLENGE response")
|
||||
}
|
||||
serverHashStr := strings.TrimPrefix(splitResp[1], "SERVERHASH=")
|
||||
if serverHashStr == splitResp[1] {
|
||||
return newProtocolError("missing SERVERHASH")
|
||||
}
|
||||
serverHash, err := hex.DecodeString(serverHashStr)
|
||||
if err != nil {
|
||||
return newProtocolError("failed to decode ServerHash: %v", err)
|
||||
}
|
||||
if len(serverHash) != authHashLength {
|
||||
return newProtocolError("invalid ServerHash length: %d", len(serverHash))
|
||||
}
|
||||
serverNonceStr := strings.TrimPrefix(splitResp[2], "SERVERNONCE=")
|
||||
if serverNonceStr == splitResp[2] {
|
||||
return newProtocolError("missing SERVERNONCE")
|
||||
}
|
||||
serverNonce, err := hex.DecodeString(serverNonceStr)
|
||||
if err != nil {
|
||||
return newProtocolError("failed to decode ServerNonce: %v", err)
|
||||
}
|
||||
if len(serverNonce) != authNonceLength {
|
||||
return newProtocolError("invalid ServerNonce length: %d", len(serverNonce))
|
||||
}
|
||||
|
||||
// Validate the ServerHash.
|
||||
m := hmac.New(sha256.New, []byte(authServerHashKey))
|
||||
m.Write(cookie)
|
||||
m.Write(clientNonce[:])
|
||||
m.Write(serverNonce)
|
||||
dervServerHash := m.Sum(nil)
|
||||
if !hmac.Equal(serverHash, dervServerHash) {
|
||||
return newProtocolError("invalid ServerHash: mismatch")
|
||||
}
|
||||
|
||||
// Calculate the ClientHash, and issue the AUTHENTICATE.
|
||||
m = hmac.New(sha256.New, []byte(authClientHashKey))
|
||||
m.Write(cookie)
|
||||
m.Write(clientNonce[:])
|
||||
m.Write(serverNonce)
|
||||
clientHash := m.Sum(nil)
|
||||
clientHashStr := hex.EncodeToString(clientHash)
|
||||
|
||||
_, err = c.Request("%s %s", cmdAuthenticate, clientHashStr)
|
||||
c.isAuthenticated = err == nil
|
||||
return err
|
||||
} else if pi.AuthMethods[authMethodPassword] {
|
||||
// Despite the name HASHEDPASSWORD, the raw password is actually sent.
|
||||
// According to the code, this can either be a QuotedString, or base16
|
||||
// encoded, so go with the later since it's easier to handle.
|
||||
if password == "" {
|
||||
return newProtocolError("password auth needs a password")
|
||||
}
|
||||
passwordStr := hex.EncodeToString([]byte(password))
|
||||
_, err = c.Request("%s %s", cmdAuthenticate, passwordStr)
|
||||
c.isAuthenticated = err == nil
|
||||
return err
|
||||
}
|
||||
return newProtocolError("no supported authentication methods")
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
// cmd_onion.go - various onion service commands: ADD_ONION, DEL_ONION...
|
||||
//
|
||||
// To the extent possible under law, David Stainton waived all copyright
|
||||
// and related or neighboring rights to this module of bulb, using the creative
|
||||
// commons "cc0" public domain dedication. See LICENSE or
|
||||
// <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
|
||||
|
||||
package bulb
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/yawning/bulb/utils/pkcs1"
|
||||
)
|
||||
|
||||
// OnionInfo is the result of the AddOnion command.
|
||||
type OnionInfo struct {
|
||||
OnionID string
|
||||
PrivateKey crypto.PrivateKey
|
||||
|
||||
RawResponse *Response
|
||||
}
|
||||
|
||||
// OnionPrivateKey is a unknown Onion private key (crypto.PublicKey).
|
||||
type OnionPrivateKey struct {
|
||||
KeyType string
|
||||
Key string
|
||||
}
|
||||
|
||||
// OnionPortSpec is a Onion VirtPort/Target pair.
|
||||
type OnionPortSpec struct {
|
||||
VirtPort uint16
|
||||
Target string
|
||||
}
|
||||
|
||||
// AddOnion issues an ADD_ONION command and returns the parsed response.
|
||||
func (c *Conn) AddOnion(ports []OnionPortSpec, key crypto.PrivateKey, oneshot bool) (*OnionInfo, error) {
|
||||
const keyTypeRSA = "RSA1024"
|
||||
var err error
|
||||
|
||||
var portStr string
|
||||
if ports == nil {
|
||||
return nil, newProtocolError("invalid port specification")
|
||||
}
|
||||
for _, v := range ports {
|
||||
portStr += fmt.Sprintf(" Port=%d", v.VirtPort)
|
||||
if v.Target != "" {
|
||||
portStr += "," + v.Target
|
||||
}
|
||||
}
|
||||
|
||||
var hsKeyType, hsKeyStr string
|
||||
if key != nil {
|
||||
switch t := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
rsaPK, _ := key.(*rsa.PrivateKey)
|
||||
if rsaPK.N.BitLen() != 1024 {
|
||||
return nil, newProtocolError("invalid RSA key size")
|
||||
}
|
||||
pkDER, err := pkcs1.EncodePrivateKeyDER(rsaPK)
|
||||
if err != nil {
|
||||
return nil, newProtocolError("failed to serialize RSA key: %v", err)
|
||||
}
|
||||
hsKeyType = keyTypeRSA
|
||||
hsKeyStr = base64.StdEncoding.EncodeToString(pkDER)
|
||||
case *OnionPrivateKey:
|
||||
genericPK, _ := key.(*OnionPrivateKey)
|
||||
hsKeyType = genericPK.KeyType
|
||||
hsKeyStr = genericPK.Key
|
||||
default:
|
||||
return nil, newProtocolError("unsupported private key type: %v", t)
|
||||
}
|
||||
}
|
||||
|
||||
var resp *Response
|
||||
if hsKeyStr == "" {
|
||||
flags := " Flags=DiscardPK"
|
||||
if !oneshot {
|
||||
flags = ""
|
||||
}
|
||||
resp, err = c.Request("ADD_ONION NEW:BEST%s%s", portStr, flags)
|
||||
} else {
|
||||
resp, err = c.Request("ADD_ONION %s:%s%s", hsKeyType, hsKeyStr, portStr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse out the response.
|
||||
var serviceID string
|
||||
var hsPrivateKey crypto.PrivateKey
|
||||
for _, l := range resp.Data {
|
||||
const (
|
||||
serviceIDPrefix = "ServiceID="
|
||||
privateKeyPrefix = "PrivateKey="
|
||||
)
|
||||
|
||||
if strings.HasPrefix(l, serviceIDPrefix) {
|
||||
serviceID = strings.TrimPrefix(l, serviceIDPrefix)
|
||||
} else if strings.HasPrefix(l, privateKeyPrefix) {
|
||||
if oneshot || hsKeyStr != "" {
|
||||
return nil, newProtocolError("received an unexpected private key")
|
||||
}
|
||||
hsKeyStr = strings.TrimPrefix(l, privateKeyPrefix)
|
||||
splitKey := strings.SplitN(hsKeyStr, ":", 2)
|
||||
if len(splitKey) != 2 {
|
||||
return nil, newProtocolError("failed to parse private key type")
|
||||
}
|
||||
|
||||
switch splitKey[0] {
|
||||
case keyTypeRSA:
|
||||
keyBlob, err := base64.StdEncoding.DecodeString(splitKey[1])
|
||||
if err != nil {
|
||||
return nil, newProtocolError("failed to base64 decode RSA key: %v", err)
|
||||
}
|
||||
hsPrivateKey, _, err = pkcs1.DecodePrivateKeyDER(keyBlob)
|
||||
if err != nil {
|
||||
return nil, newProtocolError("failed to deserialize RSA key: %v", err)
|
||||
}
|
||||
default:
|
||||
hsPrivateKey := new(OnionPrivateKey)
|
||||
hsPrivateKey.KeyType = splitKey[0]
|
||||
hsPrivateKey.Key = splitKey[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
if serviceID == "" {
|
||||
// This should *NEVER* happen, since the command succeded, and the spec
|
||||
// guarantees that this will always be present.
|
||||
return nil, newProtocolError("failed to determine service ID")
|
||||
}
|
||||
|
||||
oi := new(OnionInfo)
|
||||
oi.RawResponse = resp
|
||||
oi.OnionID = serviceID
|
||||
oi.PrivateKey = hsPrivateKey
|
||||
|
||||
return oi, nil
|
||||
}
|
||||
|
||||
// DeleteOnion issues a DEL_ONION command and returns the parsed response.
|
||||
func (c *Conn) DeleteOnion(serviceID string) error {
|
||||
_, err := c.Request("DEL_ONION %s", serviceID)
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
// cmd_protocolinfo.go - PROTOCOLINFO command.
|
||||
//
|
||||
// To the extent possible under law, Yawning Angel waived all copyright
|
||||
// and related or neighboring rights to bulb, using the creative
|
||||
// commons "cc0" public domain dedication. See LICENSE or
|
||||
// <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
|
||||
|
||||
package bulb
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/yawning/bulb/utils"
|
||||
)
|
||||
|
||||
// ProtocolInfo is the result of the ProtocolInfo command.
|
||||
type ProtocolInfo struct {
|
||||
AuthMethods map[string]bool
|
||||
CookieFile string
|
||||
TorVersion string
|
||||
|
||||
RawResponse *Response
|
||||
}
|
||||
|
||||
// ProtocolInfo issues a PROTOCOLINFO command and returns the parsed response.
|
||||
func (c *Conn) ProtocolInfo() (*ProtocolInfo, error) {
|
||||
// In the pre-authentication state, only one PROTOCOLINFO command
|
||||
// may be issued. Cache the value returned so that subsequent
|
||||
// calls continue to work.
|
||||
if !c.isAuthenticated && c.cachedPI != nil {
|
||||
return c.cachedPI, nil
|
||||
}
|
||||
|
||||
resp, err := c.Request("PROTOCOLINFO")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse out the PIVERSION to make sure it speaks something we understand.
|
||||
if len(resp.Data) < 1 {
|
||||
return nil, newProtocolError("missing PIVERSION")
|
||||
}
|
||||
switch resp.Data[0] {
|
||||
case "1":
|
||||
return nil, newProtocolError("invalid PIVERSION: '%s'", resp.Reply)
|
||||
default:
|
||||
}
|
||||
|
||||
// Parse out the rest of the lines.
|
||||
pi := new(ProtocolInfo)
|
||||
pi.RawResponse = resp
|
||||
pi.AuthMethods = make(map[string]bool)
|
||||
for i := 1; i < len(resp.Data); i++ {
|
||||
splitLine := utils.SplitQuoted(resp.Data[i], '"', ' ')
|
||||
switch splitLine[0] {
|
||||
case "AUTH":
|
||||
// Parse an AuthLine detailing how to authenticate.
|
||||
if len(splitLine) < 2 {
|
||||
continue
|
||||
}
|
||||
methods := strings.TrimPrefix(splitLine[1], "METHODS=")
|
||||
if methods == splitLine[1] {
|
||||
continue
|
||||
}
|
||||
for _, meth := range strings.Split(methods, ",") {
|
||||
pi.AuthMethods[meth] = true
|
||||
}
|
||||
|
||||
if len(splitLine) < 3 {
|
||||
continue
|
||||
}
|
||||
cookiePath := strings.TrimPrefix(splitLine[2], "COOKIEFILE=")
|
||||
if cookiePath == splitLine[2] {
|
||||
continue
|
||||
}
|
||||
pi.CookieFile, _ = strconv.Unquote(cookiePath)
|
||||
case "VERSION":
|
||||
// Parse a VersionLine detailing the Tor version.
|
||||
if len(splitLine) < 2 {
|
||||
continue
|
||||
}
|
||||
torVersion := strings.TrimPrefix(splitLine[1], "Tor=")
|
||||
if torVersion == splitLine[1] {
|
||||
continue
|
||||
}
|
||||
pi.TorVersion, _ = strconv.Unquote(torVersion)
|
||||
default: // MUST ignore unsupported InfoLines.
|
||||
}
|
||||
}
|
||||
if !c.isAuthenticated {
|
||||
c.cachedPI = pi
|
||||
}
|
||||
return pi, nil
|
||||
}
|
|
@ -0,0 +1,233 @@
|
|||
// conn.go - Controller connection instance.
|
||||
//
|
||||
// To the extent possible under law, Yawning Angel waived all copyright
|
||||
// and related or neighboring rights to bulb, using the creative
|
||||
// commons "cc0" public domain dedication. See LICENSE or
|
||||
// <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
|
||||
|
||||
// Package bulb is a Go language interface to a Tor control port.
|
||||
package bulb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
gofmt "fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
maxEventBacklog = 16
|
||||
maxResponseBacklog = 16
|
||||
)
|
||||
|
||||
// ErrNoAsyncReader is the error returned when the asynchronous event handling
|
||||
// is requested, but the helper go routine has not been started.
|
||||
var ErrNoAsyncReader = errors.New("event requested without an async reader")
|
||||
|
||||
// Conn is a control port connection instance.
|
||||
type Conn struct {
|
||||
conn *textproto.Conn
|
||||
isAuthenticated bool
|
||||
debugLog bool
|
||||
cachedPI *ProtocolInfo
|
||||
|
||||
asyncReaderLock sync.Mutex
|
||||
asyncReaderRunning bool
|
||||
eventChan chan *Response
|
||||
respChan chan *Response
|
||||
closeWg sync.WaitGroup
|
||||
|
||||
rdErrLock sync.Mutex
|
||||
rdErr error
|
||||
}
|
||||
|
||||
func (c *Conn) setRdErr(err error, force bool) {
|
||||
c.rdErrLock.Lock()
|
||||
defer c.rdErrLock.Unlock()
|
||||
if c.rdErr == nil || force {
|
||||
c.rdErr = err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) getRdErr() error {
|
||||
c.rdErrLock.Lock()
|
||||
defer c.rdErrLock.Unlock()
|
||||
return c.rdErr
|
||||
}
|
||||
|
||||
func (c *Conn) isAsyncReaderRunning() bool {
|
||||
c.asyncReaderLock.Lock()
|
||||
defer c.asyncReaderLock.Unlock()
|
||||
return c.asyncReaderRunning
|
||||
}
|
||||
|
||||
func (c *Conn) asyncReader() {
|
||||
for {
|
||||
resp, err := c.ReadResponse()
|
||||
if err != nil {
|
||||
c.setRdErr(err, false)
|
||||
break
|
||||
}
|
||||
if resp.IsAsync() {
|
||||
c.eventChan <- resp
|
||||
} else {
|
||||
c.respChan <- resp
|
||||
}
|
||||
}
|
||||
close(c.eventChan)
|
||||
close(c.respChan)
|
||||
c.closeWg.Done()
|
||||
|
||||
// In theory, we would lock and set asyncReaderRunning to false here, but
|
||||
// once it's started, the only way it returns is if there is a catastrophic
|
||||
// failure, or a graceful shutdown. Changing this will require redoing how
|
||||
// Close() works.
|
||||
}
|
||||
|
||||
// Debug enables/disables debug logging of control port chatter.
|
||||
func (c *Conn) Debug(enable bool) {
|
||||
c.debugLog = enable
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
c.asyncReaderLock.Lock()
|
||||
defer c.asyncReaderLock.Unlock()
|
||||
|
||||
err := c.conn.Close()
|
||||
if err != nil && c.asyncReaderRunning {
|
||||
c.closeWg.Wait()
|
||||
}
|
||||
c.setRdErr(io.ErrClosedPipe, true)
|
||||
return err
|
||||
}
|
||||
|
||||
// StartAsyncReader starts the asynchronous reader go routine that allows
|
||||
// asynchronous events to be handled. It must not be called simultaniously
|
||||
// with Read, Request, or ReadResponse or undefined behavior will occur.
|
||||
func (c *Conn) StartAsyncReader() {
|
||||
c.asyncReaderLock.Lock()
|
||||
defer c.asyncReaderLock.Unlock()
|
||||
if c.asyncReaderRunning {
|
||||
return
|
||||
}
|
||||
|
||||
// Allocate the channels and kick off the read worker.
|
||||
c.eventChan = make(chan *Response, maxEventBacklog)
|
||||
c.respChan = make(chan *Response, maxResponseBacklog)
|
||||
c.closeWg.Add(1)
|
||||
go c.asyncReader()
|
||||
c.asyncReaderRunning = true
|
||||
}
|
||||
|
||||
// NextEvent returns the next asynchronous event received, blocking if
|
||||
// neccecary. In order to enable asynchronous event handling, StartAsyncReader
|
||||
// must be called first.
|
||||
func (c *Conn) NextEvent() (*Response, error) {
|
||||
if err := c.getRdErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !c.isAsyncReaderRunning() {
|
||||
return nil, ErrNoAsyncReader
|
||||
}
|
||||
|
||||
resp, ok := <-c.eventChan
|
||||
if resp != nil {
|
||||
return resp, nil
|
||||
} else if !ok {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
panic("BUG: NextEvent() returned a nil response and error")
|
||||
}
|
||||
|
||||
// Request sends a raw control port request and returns the response.
|
||||
// If the async. reader is not currently running, events received while waiting
|
||||
// for the response will be silently dropped. Calling Request simultaniously
|
||||
// with StartAsyncReader, Read, Write, or ReadResponse will lead to undefined
|
||||
// behavior.
|
||||
func (c *Conn) Request(fmt string, args ...interface{}) (*Response, error) {
|
||||
if err := c.getRdErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
asyncResp := c.isAsyncReaderRunning()
|
||||
|
||||
if c.debugLog {
|
||||
log.Printf("C: %s", gofmt.Sprintf(fmt, args...))
|
||||
}
|
||||
|
||||
id, err := c.conn.Cmd(fmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.conn.StartResponse(id)
|
||||
defer c.conn.EndResponse(id)
|
||||
var resp *Response
|
||||
if asyncResp {
|
||||
var ok bool
|
||||
resp, ok = <-c.respChan
|
||||
if resp == nil && !ok {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
} else {
|
||||
// Event handing requires the asyncReader() goroutine, try to get a
|
||||
// response, while silently swallowing events.
|
||||
for resp == nil || resp.IsAsync() {
|
||||
resp, err = c.ReadResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if resp == nil {
|
||||
panic("BUG: Request() returned a nil response and error")
|
||||
}
|
||||
if resp.IsOk() {
|
||||
return resp, nil
|
||||
}
|
||||
return resp, resp.Err
|
||||
}
|
||||
|
||||
// Read reads directly from the control port connection. Mixing this call
|
||||
// with Request, ReadResponse, or asynchronous events will lead to undefined
|
||||
// behavior.
|
||||
func (c *Conn) Read(p []byte) (int, error) {
|
||||
return c.conn.R.Read(p)
|
||||
}
|
||||
|
||||
// Write writes directly from the control port connection. Mixing this call
|
||||
// with Request will lead to undefined behavior.
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
n, err := c.conn.W.Write(p)
|
||||
if err == nil {
|
||||
// If the write succeeds, but the flush fails, n will be incorrect...
|
||||
return n, c.conn.W.Flush()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Dial connects to a given network/address and returns a new Conn for the
|
||||
// connection.
|
||||
func Dial(network, addr string) (*Conn, error) {
|
||||
c, err := net.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConn(c), nil
|
||||
}
|
||||
|
||||
// NewConn returns a new Conn using c for I/O.
|
||||
func NewConn(c io.ReadWriteCloser) *Conn {
|
||||
conn := new(Conn)
|
||||
conn.conn = textproto.NewConn(c)
|
||||
return conn
|
||||
}
|
||||
|
||||
func newProtocolError(fmt string, args ...interface{}) textproto.ProtocolError {
|
||||
return textproto.ProtocolError(gofmt.Sprintf(fmt, args...))
|
||||
}
|
||||
|
||||
var _ io.ReadWriteCloser = (*Conn)(nil)
|
|
@ -0,0 +1,54 @@
|
|||
// dialer.go - Tor backed proxy.Dialer.
|
||||
//
|
||||
// To the extent possible under law, Yawning Angel waived all copyright
|
||||
// and related or neighboring rights to bulb, using the creative
|
||||
// commons "cc0" public domain dedication. See LICENSE or
|
||||
// <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
|
||||
|
||||
package bulb
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// Dialer returns a proxy.Dialer for the given Tor instance.
|
||||
func (c *Conn) Dialer(auth *proxy.Auth) (proxy.Dialer, error) {
|
||||
const (
|
||||
cmdGetInfo = "GETINFO"
|
||||
socksListeners = "net/listeners/socks"
|
||||
unixPrefix = "unix:"
|
||||
)
|
||||
|
||||
// Query for the SOCKS listeners via a GETINFO request.
|
||||
resp, err := c.Request("%s %s", cmdGetInfo, socksListeners)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(resp.Data) != 1 {
|
||||
return nil, newProtocolError("no SOCKS listeners configured")
|
||||
}
|
||||
splitResp := strings.Split(resp.Data[0], " ")
|
||||
if len(splitResp) < 1 {
|
||||
return nil, newProtocolError("no SOCKS listeners configured")
|
||||
}
|
||||
|
||||
// The first listener will have a "net/listeners/socks=" prefix, and all
|
||||
// entries are QuotedStrings.
|
||||
laddrStr := strings.TrimPrefix(splitResp[0], socksListeners+"=")
|
||||
if laddrStr == splitResp[0] {
|
||||
return nil, newProtocolError("failed to parse SOCKS listener")
|
||||
}
|
||||
laddrStr, _ = strconv.Unquote(laddrStr)
|
||||
|
||||
// Construct the proxyDialer.
|
||||
if strings.HasPrefix(laddrStr, unixPrefix) {
|
||||
unixPath := strings.TrimPrefix(laddrStr, unixPrefix)
|
||||
return proxy.SOCKS5("unix", unixPath, auth, proxy.Direct)
|
||||
}
|
||||
|
||||
return proxy.SOCKS5("tcp", laddrStr, auth, proxy.Direct)
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
// listener.go - Tor backed net.Listener.
|
||||
//
|
||||
// To the extent possible under law, Yawning Angel waived all copyright
|
||||
// and related or neighboring rights to bulb, using the creative
|
||||
// commons "cc0" public domain dedication. See LICENSE or
|
||||
// <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
|
||||
|
||||
package bulb
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type onionAddr struct {
|
||||
info *OnionInfo
|
||||
port uint16
|
||||
}
|
||||
|
||||
func (a *onionAddr) Network() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (a *onionAddr) String() string {
|
||||
return fmt.Sprintf("%s.onion:%d", a.info.OnionID, a.port)
|
||||
}
|
||||
|
||||
type onionListener struct {
|
||||
addr *onionAddr
|
||||
ctrlConn *Conn
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
func (l *onionListener) Accept() (net.Conn, error) {
|
||||
return l.listener.Accept()
|
||||
}
|
||||
|
||||
func (l *onionListener) Close() (err error) {
|
||||
if err = l.listener.Close(); err == nil {
|
||||
// Only delete the onion once.
|
||||
err = l.ctrlConn.DeleteOnion(l.addr.info.OnionID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *onionListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
// Listener returns a net.Listener backed by a Onion Service, optionally
|
||||
// having Tor generate an ephemeral private key. Regardless of the status of
|
||||
// the returned Listener, the Onion Service will be torn down when the control
|
||||
// connection is closed.
|
||||
//
|
||||
// WARNING: Only one port can be listened to per PrivateKey if this interface
|
||||
// is used. To bind to more ports, use the AddOnion call directly.
|
||||
func (c *Conn) Listener(port uint16, key crypto.PrivateKey) (net.Listener, error) {
|
||||
const (
|
||||
loopbackAddr = "127.0.0.1:0"
|
||||
)
|
||||
|
||||
// Listen on the loopback interface.
|
||||
tcpListener, err := net.Listen("tcp4", loopbackAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tAddr, ok := tcpListener.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
tcpListener.Close()
|
||||
return nil, newProtocolError("failed to extract local port")
|
||||
}
|
||||
|
||||
// Create the onion.
|
||||
ports := []OnionPortSpec{{port, strconv.FormatUint((uint64)(tAddr.Port), 10)}}
|
||||
oi, err := c.AddOnion(ports, key, key == nil)
|
||||
if err != nil {
|
||||
tcpListener.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oa := &onionAddr{info: oi, port: port}
|
||||
ol := &onionListener{addr: oa, ctrlConn: c, listener: tcpListener}
|
||||
|
||||
return ol, nil
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
// response.go - Generic response handler
|
||||
//
|
||||
// To the extent possible under law, Yawning Angel waived all copyright
|
||||
// and related or neighboring rights to bulb, using the creative
|
||||
// commons "cc0" public domain dedication. See LICENSE or
|
||||
// <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
|
||||
|
||||
package bulb
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Response is a response to a control port command, or an asyncrhonous event.
|
||||
type Response struct {
|
||||
// Err is the status code and string representation associated with a
|
||||
// response. Responses that have completed successfully will also have
|
||||
// Err set to indicate such.
|
||||
Err *textproto.Error
|
||||
|
||||
// Reply is the text on the EndReplyLine of the response.
|
||||
Reply string
|
||||
|
||||
// Data is the MidReplyLines/DataReplyLines of the response. Dot encoded
|
||||
// data is "decoded" and presented as a single string (terminal ".CRLF"
|
||||
// removed, all intervening CRs stripped).
|
||||
Data []string
|
||||
|
||||
// RawLines is all of the lines of a response, without CRLFs.
|
||||
RawLines []string
|
||||
}
|
||||
|
||||
// IsOk returns true if the response status code indicates success or
|
||||
// an asynchronous event.
|
||||
func (r *Response) IsOk() bool {
|
||||
switch r.Err.Code {
|
||||
case StatusOk, StatusOkUnneccecary, StatusAsyncEvent:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsAsync returns true if the response is an asyncrhonous event.
|
||||
func (r *Response) IsAsync() bool {
|
||||
return r.Err.Code == StatusAsyncEvent
|
||||
}
|
||||
|
||||
// ReadResponse returns the next response object. Calling this
|
||||
// simultaniously with Read, Request, or StartAsyncReader will lead to
|
||||
// undefined behavior
|
||||
func (c *Conn) ReadResponse() (*Response, error) {
|
||||
var resp *Response
|
||||
var statusCode int
|
||||
for {
|
||||
line, err := c.conn.ReadLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.debugLog {
|
||||
log.Printf("S: %s", line)
|
||||
}
|
||||
|
||||
// Parse the line that was just read.
|
||||
if len(line) < 4 {
|
||||
return nil, newProtocolError("truncated response: '%s'", line)
|
||||
}
|
||||
if code, err := strconv.Atoi(line[0:3]); err != nil {
|
||||
return nil, newProtocolError("invalid status code: '%s'", line[0:3])
|
||||
} else if code < 100 {
|
||||
return nil, newProtocolError("invalid status code: '%s'", line[0:3])
|
||||
} else if resp == nil {
|
||||
resp = new(Response)
|
||||
statusCode = code
|
||||
} else if code != statusCode {
|
||||
// The status code should stay fixed for all lines of the
|
||||
// response, since events can't be interleaved with response
|
||||
// lines.
|
||||
return nil, newProtocolError("status code changed: %03d != %03d", code, statusCode)
|
||||
}
|
||||
if resp.RawLines == nil {
|
||||
resp.RawLines = make([]string, 0, 1)
|
||||
}
|
||||
|
||||
if line[3] == ' ' {
|
||||
// Final line in the response.
|
||||
resp.Reply = line[4:]
|
||||
resp.Err = statusCodeToError(statusCode, resp.Reply)
|
||||
resp.RawLines = append(resp.RawLines, line)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if resp.Data == nil {
|
||||
resp.Data = make([]string, 0, 1)
|
||||
}
|
||||
switch line[3] {
|
||||
case '-':
|
||||
// Continuation, keep reading.
|
||||
resp.Data = append(resp.Data, line[4:])
|
||||
resp.RawLines = append(resp.RawLines, line)
|
||||
case '+':
|
||||
// A "dot-encoded" payload follows.
|
||||
resp.Data = append(resp.Data, line[4:])
|
||||
resp.RawLines = append(resp.RawLines, line)
|
||||
dotBody, err := c.conn.ReadDotBytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.debugLog {
|
||||
log.Printf("S: [dot encoded data]")
|
||||
}
|
||||
resp.Data = append(resp.Data, string(dotBody))
|
||||
dotLines := strings.Split(string(dotBody), "\n")
|
||||
for _, dotLine := range dotLines[:len(dotLines)-1] {
|
||||
resp.RawLines = append(resp.RawLines, dotLine)
|
||||
}
|
||||
resp.RawLines = append(resp.RawLines, ".")
|
||||
default:
|
||||
return nil, newProtocolError("invalid separator: '%c'", line[3])
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
// status.go - Status codes.
|
||||
//
|
||||
// To the extent possible under law, Yawning Angel waived all copyright
|
||||
// and related or neighboring rights to bulb, using the creative
|
||||
// commons "cc0" public domain dedication. See LICENSE or
|
||||
// <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
|
||||
|
||||
package bulb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"net/textproto"
|
||||
)
|
||||
|
||||
// The various control port StatusCode constants.
|
||||
const (
|
||||
StatusOk = 250
|
||||
StatusOkUnneccecary = 251
|
||||
|
||||
StatusErrResourceExhausted = 451
|
||||
StatusErrSyntaxError = 500
|
||||
StatusErrUnrecognizedCmd = 510
|
||||
StatusErrUnimplementedCmd = 511
|
||||
StatusErrSyntaxErrorArg = 512
|
||||
StatusErrUnrecognizedCmdArg = 513
|
||||
StatusErrAuthenticationRequired = 514
|
||||
StatusErrBadAuthentication = 515
|
||||
StatusErrUnspecifiedTorError = 550
|
||||
StatusErrInternalError = 551
|
||||
StatusErrUnrecognizedEntity = 552
|
||||
StatusErrInvalidConfigValue = 553
|
||||
StatusErrInvalidDescriptor = 554
|
||||
StatusErrUnmanagedEntity = 555
|
||||
|
||||
StatusAsyncEvent = 650
|
||||
)
|
||||
|
||||
var statusCodeStringMap = map[int]string{
|
||||
StatusOk: "OK",
|
||||
StatusOkUnneccecary: "Operation was unnecessary",
|
||||
|
||||
StatusErrResourceExhausted: "Resource exhausted",
|
||||
StatusErrSyntaxError: "Syntax error: protocol",
|
||||
StatusErrUnrecognizedCmd: "Unrecognized command",
|
||||
StatusErrUnimplementedCmd: "Unimplemented command",
|
||||
StatusErrSyntaxErrorArg: "Syntax error in command argument",
|
||||
StatusErrUnrecognizedCmdArg: "Unrecognized command argument",
|
||||
StatusErrAuthenticationRequired: "Authentication required",
|
||||
StatusErrBadAuthentication: "Bad authentication",
|
||||
StatusErrUnspecifiedTorError: "Unspecified Tor error",
|
||||
StatusErrInternalError: "Internal error",
|
||||
StatusErrUnrecognizedEntity: "Unrecognized entity",
|
||||
StatusErrInvalidConfigValue: "Invalid configuration value",
|
||||
StatusErrInvalidDescriptor: "Invalid descriptor",
|
||||
StatusErrUnmanagedEntity: "Unmanaged entity",
|
||||
|
||||
StatusAsyncEvent: "Asynchronous event notification",
|
||||
}
|
||||
|
||||
func statusCodeToError(code int, reply string) *textproto.Error {
|
||||
err := new(textproto.Error)
|
||||
err.Code = code
|
||||
if msg, ok := statusCodeStringMap[code]; ok {
|
||||
trimmedReply := strings.TrimSpace(strings.TrimPrefix(reply, msg))
|
||||
err.Msg = fmt.Sprintf("%s: %s", msg, trimmedReply)
|
||||
} else {
|
||||
err.Msg = fmt.Sprintf("Unknown status code (%03d): %s", code, reply)
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
// Package backoff implements backoff algorithms for retrying operations.
|
||||
//
|
||||
// Use Retry function for retrying operations that may fail.
|
||||
// If Retry does not meet your needs,
|
||||
// copy/paste the function into your project and modify as you wish.
|
||||
//
|
||||
// There is also Ticker type similar to time.Ticker.
|
||||
// You can use it if you need to work with channels.
|
||||
//
|
||||
// See Examples section below for usage examples.
|
||||
package backoff
|
||||
|
||||
import "time"
|
||||
|
||||
// BackOff is a backoff policy for retrying an operation.
|
||||
type BackOff interface {
|
||||
// NextBackOff returns the duration to wait before retrying the operation,
|
||||
// or backoff.Stop to indicate that no more retries should be made.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// duration := backoff.NextBackOff();
|
||||
// if (duration == backoff.Stop) {
|
||||
// // Do not retry operation.
|
||||
// } else {
|
||||
// // Sleep for duration and retry operation.
|
||||
// }
|
||||
//
|
||||
NextBackOff() time.Duration
|
||||
|
||||
// Reset to initial state.
|
||||
Reset()
|
||||
}
|
||||
|
||||
// Stop indicates that no more retries should be made for use in NextBackOff().
|
||||
const Stop time.Duration = -1
|
||||
|
||||
// ZeroBackOff is a fixed backoff policy whose backoff time is always zero,
|
||||
// meaning that the operation is retried immediately without waiting, indefinitely.
|
||||
type ZeroBackOff struct{}
|
||||
|
||||
func (b *ZeroBackOff) Reset() {}
|
||||
|
||||
func (b *ZeroBackOff) NextBackOff() time.Duration { return 0 }
|
||||
|
||||
// StopBackOff is a fixed backoff policy that always returns backoff.Stop for
|
||||
// NextBackOff(), meaning that the operation should never be retried.
|
||||
type StopBackOff struct{}
|
||||
|
||||
func (b *StopBackOff) Reset() {}
|
||||
|
||||
func (b *StopBackOff) NextBackOff() time.Duration { return Stop }
|
||||
|
||||
// ConstantBackOff is a backoff policy that always returns the same backoff delay.
|
||||
// This is in contrast to an exponential backoff policy,
|
||||
// which returns a delay that grows longer as you call NextBackOff() over and over again.
|
||||
type ConstantBackOff struct {
|
||||
Interval time.Duration
|
||||
}
|
||||
|
||||
func (b *ConstantBackOff) Reset() {}
|
||||
func (b *ConstantBackOff) NextBackOff() time.Duration { return b.Interval }
|
||||
|
||||
func NewConstantBackOff(d time.Duration) *ConstantBackOff {
|
||||
return &ConstantBackOff{Interval: d}
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
package backoff
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
ExponentialBackOff is a backoff implementation that increases the backoff
|
||||
period for each retry attempt using a randomization function that grows exponentially.
|
||||
|
||||
NextBackOff() is calculated using the following formula:
|
||||
|
||||
randomized interval =
|
||||
RetryInterval * (random value in range [1 - RandomizationFactor, 1 + RandomizationFactor])
|
||||
|
||||
In other words NextBackOff() will range between the randomization factor
|
||||
percentage below and above the retry interval.
|
||||
|
||||
For example, given the following parameters:
|
||||
|
||||
RetryInterval = 2
|
||||
RandomizationFactor = 0.5
|
||||
Multiplier = 2
|
||||
|
||||
the actual backoff period used in the next retry attempt will range between 1 and 3 seconds,
|
||||
multiplied by the exponential, that is, between 2 and 6 seconds.
|
||||
|
||||
Note: MaxInterval caps the RetryInterval and not the randomized interval.
|
||||
|
||||
If the time elapsed since an ExponentialBackOff instance is created goes past the
|
||||
MaxElapsedTime, then the method NextBackOff() starts returning backoff.Stop.
|
||||
|
||||
The elapsed time can be reset by calling Reset().
|
||||
|
||||
Example: Given the following default arguments, for 10 tries the sequence will be,
|
||||
and assuming we go over the MaxElapsedTime on the 10th try:
|
||||
|
||||
Request # RetryInterval (seconds) Randomized Interval (seconds)
|
||||
|
||||
1 0.5 [0.25, 0.75]
|
||||
2 0.75 [0.375, 1.125]
|
||||
3 1.125 [0.562, 1.687]
|
||||
4 1.687 [0.8435, 2.53]
|
||||
5 2.53 [1.265, 3.795]
|
||||
6 3.795 [1.897, 5.692]
|
||||
7 5.692 [2.846, 8.538]
|
||||
8 8.538 [4.269, 12.807]
|
||||
9 12.807 [6.403, 19.210]
|
||||
10 19.210 backoff.Stop
|
||||
|
||||
Note: Implementation is not thread-safe.
|
||||
*/
|
||||
type ExponentialBackOff struct {
|
||||
InitialInterval time.Duration
|
||||
RandomizationFactor float64
|
||||
Multiplier float64
|
||||
MaxInterval time.Duration
|
||||
// After MaxElapsedTime the ExponentialBackOff stops.
|
||||
// It never stops if MaxElapsedTime == 0.
|
||||
MaxElapsedTime time.Duration
|
||||
Clock Clock
|
||||
|
||||
currentInterval time.Duration
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// Clock is an interface that returns current time for BackOff.
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
// Default values for ExponentialBackOff.
|
||||
const (
|
||||
DefaultInitialInterval = 500 * time.Millisecond
|
||||
DefaultRandomizationFactor = 0.5
|
||||
DefaultMultiplier = 1.5
|
||||
DefaultMaxInterval = 60 * time.Second
|
||||
DefaultMaxElapsedTime = 15 * time.Minute
|
||||
)
|
||||
|
||||
// NewExponentialBackOff creates an instance of ExponentialBackOff using default values.
|
||||
func NewExponentialBackOff() *ExponentialBackOff {
|
||||
b := &ExponentialBackOff{
|
||||
InitialInterval: DefaultInitialInterval,
|
||||
RandomizationFactor: DefaultRandomizationFactor,
|
||||
Multiplier: DefaultMultiplier,
|
||||
MaxInterval: DefaultMaxInterval,
|
||||
MaxElapsedTime: DefaultMaxElapsedTime,
|
||||
Clock: SystemClock,
|
||||
}
|
||||
if b.RandomizationFactor < 0 {
|
||||
b.RandomizationFactor = 0
|
||||
} else if b.RandomizationFactor > 1 {
|
||||
b.RandomizationFactor = 1
|
||||
}
|
||||
b.Reset()
|
||||
return b
|
||||
}
|
||||
|
||||
type systemClock struct{}
|
||||
|
||||
func (t systemClock) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// SystemClock implements Clock interface that uses time.Now().
|
||||
var SystemClock = systemClock{}
|
||||
|
||||
// Reset the interval back to the initial retry interval and restarts the timer.
|
||||
func (b *ExponentialBackOff) Reset() {
|
||||
b.currentInterval = b.InitialInterval
|
||||
b.startTime = b.Clock.Now()
|
||||
}
|
||||
|
||||
// NextBackOff calculates the next backoff interval using the formula:
|
||||
// Randomized interval = RetryInterval +/- (RandomizationFactor * RetryInterval)
|
||||
func (b *ExponentialBackOff) NextBackOff() time.Duration {
|
||||
// Make sure we have not gone over the maximum elapsed time.
|
||||
if b.MaxElapsedTime != 0 && b.GetElapsedTime() > b.MaxElapsedTime {
|
||||
return Stop
|
||||
}
|
||||
defer b.incrementCurrentInterval()
|
||||
return getRandomValueFromInterval(b.RandomizationFactor, rand.Float64(), b.currentInterval)
|
||||
}
|
||||
|
||||
// GetElapsedTime returns the elapsed time since an ExponentialBackOff instance
|
||||
// is created and is reset when Reset() is called.
|
||||
//
|
||||
// The elapsed time is computed using time.Now().UnixNano().
|
||||
func (b *ExponentialBackOff) GetElapsedTime() time.Duration {
|
||||
return b.Clock.Now().Sub(b.startTime)
|
||||
}
|
||||
|
||||
// Increments the current interval by multiplying it with the multiplier.
|
||||
func (b *ExponentialBackOff) incrementCurrentInterval() {
|
||||
// Check for overflow, if overflow is detected set the current interval to the max interval.
|
||||
if float64(b.currentInterval) >= float64(b.MaxInterval)/b.Multiplier {
|
||||
b.currentInterval = b.MaxInterval
|
||||
} else {
|
||||
b.currentInterval = time.Duration(float64(b.currentInterval) * b.Multiplier)
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a random value from the following interval:
|
||||
// [randomizationFactor * currentInterval, randomizationFactor * currentInterval].
|
||||
func getRandomValueFromInterval(randomizationFactor, random float64, currentInterval time.Duration) time.Duration {
|
||||
var delta = randomizationFactor * float64(currentInterval)
|
||||
var minInterval = float64(currentInterval) - delta
|
||||
var maxInterval = float64(currentInterval) + delta
|
||||
|
||||
// Get a random value from the range [minInterval, maxInterval].
|
||||
// The formula used below has a +1 because if the minInterval is 1 and the maxInterval is 3 then
|
||||
// we want a 33% chance for selecting either 1, 2 or 3.
|
||||
return time.Duration(minInterval + (random * (maxInterval - minInterval + 1)))
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package backoff
|
||||
|
||||
import "time"
|
||||
|
||||
// An Operation is executing by Retry() or RetryNotify().
|
||||
// The operation will be retried using a backoff policy if it returns an error.
|
||||
type Operation func() error
|
||||
|
||||
// Notify is a notify-on-error function. It receives an operation error and
|
||||
// backoff delay if the operation failed (with an error).
|
||||
//
|
||||
// NOTE that if the backoff policy stated to stop retrying,
|
||||
// the notify function isn't called.
|
||||
type Notify func(error, time.Duration)
|
||||
|
||||
// Retry the operation o until it does not return error or BackOff stops.
|
||||
// o is guaranteed to be run at least once.
|
||||
// It is the caller's responsibility to reset b after Retry returns.
|
||||
//
|
||||
// Retry sleeps the goroutine for the duration returned by BackOff after a
|
||||
// failed operation returns.
|
||||
func Retry(o Operation, b BackOff) error { return RetryNotify(o, b, nil) }
|
||||
|
||||
// RetryNotify calls notify function with the error and wait duration
|
||||
// for each failed attempt before sleep.
|
||||
func RetryNotify(operation Operation, b BackOff, notify Notify) error {
|
||||
var err error
|
||||
var next time.Duration
|
||||
|
||||
b.Reset()
|
||||
for {
|
||||
if err = operation(); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if next = b.NextBackOff(); next == Stop {
|
||||
return err
|
||||
}
|
||||
|
||||
if notify != nil {
|
||||
notify(err, next)
|
||||
}
|
||||
|
||||
time.Sleep(next)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
package backoff
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Ticker holds a channel that delivers `ticks' of a clock at times reported by a BackOff.
|
||||
//
|
||||
// Ticks will continue to arrive when the previous operation is still running,
|
||||
// so operations that take a while to fail could run in quick succession.
|
||||
type Ticker struct {
|
||||
C <-chan time.Time
|
||||
c chan time.Time
|
||||
b BackOff
|
||||
stop chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// NewTicker returns a new Ticker containing a channel that will send the time at times
|
||||
// specified by the BackOff argument. Ticker is guaranteed to tick at least once.
|
||||
// The channel is closed when Stop method is called or BackOff stops.
|
||||
func NewTicker(b BackOff) *Ticker {
|
||||
c := make(chan time.Time)
|
||||
t := &Ticker{
|
||||
C: c,
|
||||
c: c,
|
||||
b: b,
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
go t.run()
|
||||
runtime.SetFinalizer(t, (*Ticker).Stop)
|
||||
return t
|
||||
}
|
||||
|
||||
// Stop turns off a ticker. After Stop, no more ticks will be sent.
|
||||
func (t *Ticker) Stop() {
|
||||
t.stopOnce.Do(func() { close(t.stop) })
|
||||
}
|
||||
|
||||
func (t *Ticker) run() {
|
||||
c := t.c
|
||||
defer close(c)
|
||||
t.b.Reset()
|
||||
|
||||
// Ticker is guaranteed to tick at least once.
|
||||
afterC := t.send(time.Now())
|
||||
|
||||
for {
|
||||
if afterC == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case tick := <-afterC:
|
||||
afterC = t.send(tick)
|
||||
case <-t.stop:
|
||||
t.c = nil // Prevent future ticks from being sent to the channel.
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Ticker) send(tick time.Time) <-chan time.Time {
|
||||
select {
|
||||
case t.c <- tick:
|
||||
case <-t.stop:
|
||||
return nil
|
||||
}
|
||||
|
||||
next := t.b.NextBackOff()
|
||||
if next == Stop {
|
||||
t.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
return time.After(next)
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
// Package backoff implements backoff algorithms for retrying operations.
|
||||
//
|
||||
// Use Retry function for retrying operations that may fail.
|
||||
// If Retry does not meet your needs,
|
||||
// copy/paste the function into your project and modify as you wish.
|
||||
//
|
||||
// There is also Ticker type similar to time.Ticker.
|
||||
// You can use it if you need to work with channels.
|
||||
//
|
||||
// See Examples section below for usage examples.
|
||||
package backoff
|
||||
|
||||
import "time"
|
||||
|
||||
// BackOff is a backoff policy for retrying an operation.
|
||||
type BackOff interface {
|
||||
// NextBackOff returns the duration to wait before retrying the operation,
|
||||
// or backoff.Stop to indicate that no more retries should be made.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// duration := backoff.NextBackOff();
|
||||
// if (duration == backoff.Stop) {
|
||||
// // Do not retry operation.
|
||||
// } else {
|
||||
// // Sleep for duration and retry operation.
|
||||
// }
|
||||
//
|
||||
NextBackOff() time.Duration
|
||||
|
||||
// Reset to initial state.
|
||||
Reset()
|
||||
}
|
||||
|
||||
// Stop indicates that no more retries should be made for use in NextBackOff().
|
||||
const Stop time.Duration = -1
|
||||
|
||||
// ZeroBackOff is a fixed backoff policy whose backoff time is always zero,
|
||||
// meaning that the operation is retried immediately without waiting, indefinitely.
|
||||
type ZeroBackOff struct{}
|
||||
|
||||
func (b *ZeroBackOff) Reset() {}
|
||||
|
||||
func (b *ZeroBackOff) NextBackOff() time.Duration { return 0 }
|
||||
|
||||
// StopBackOff is a fixed backoff policy that always returns backoff.Stop for
|
||||
// NextBackOff(), meaning that the operation should never be retried.
|
||||
type StopBackOff struct{}
|
||||
|
||||
func (b *StopBackOff) Reset() {}
|
||||
|
||||
func (b *StopBackOff) NextBackOff() time.Duration { return Stop }
|
||||
|
||||
// ConstantBackOff is a backoff policy that always returns the same backoff delay.
|
||||
// This is in contrast to an exponential backoff policy,
|
||||
// which returns a delay that grows longer as you call NextBackOff() over and over again.
|
||||
type ConstantBackOff struct {
|
||||
Interval time.Duration
|
||||
}
|
||||
|
||||
func (b *ConstantBackOff) Reset() {}
|
||||
func (b *ConstantBackOff) NextBackOff() time.Duration { return b.Interval }
|
||||
|
||||
func NewConstantBackOff(d time.Duration) *ConstantBackOff {
|
||||
return &ConstantBackOff{Interval: d}
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
package backoff
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
ExponentialBackOff is a backoff implementation that increases the backoff
|
||||
period for each retry attempt using a randomization function that grows exponentially.
|
||||
|
||||
NextBackOff() is calculated using the following formula:
|
||||
|
||||
randomized interval =
|
||||
RetryInterval * (random value in range [1 - RandomizationFactor, 1 + RandomizationFactor])
|
||||
|
||||
In other words NextBackOff() will range between the randomization factor
|
||||
percentage below and above the retry interval.
|
||||
|
||||
For example, given the following parameters:
|
||||
|
||||
RetryInterval = 2
|
||||
RandomizationFactor = 0.5
|
||||
Multiplier = 2
|
||||
|
||||
the actual backoff period used in the next retry attempt will range between 1 and 3 seconds,
|
||||
multiplied by the exponential, that is, between 2 and 6 seconds.
|
||||
|
||||
Note: MaxInterval caps the RetryInterval and not the randomized interval.
|
||||
|
||||
If the time elapsed since an ExponentialBackOff instance is created goes past the
|
||||
MaxElapsedTime, then the method NextBackOff() starts returning backoff.Stop.
|
||||
|
||||
The elapsed time can be reset by calling Reset().
|
||||
|
||||
Example: Given the following default arguments, for 10 tries the sequence will be,
|
||||
and assuming we go over the MaxElapsedTime on the 10th try:
|
||||
|
||||
Request # RetryInterval (seconds) Randomized Interval (seconds)
|
||||
|
||||
1 0.5 [0.25, 0.75]
|
||||
2 0.75 [0.375, 1.125]
|
||||
3 1.125 [0.562, 1.687]
|
||||
4 1.687 [0.8435, 2.53]
|
||||
5 2.53 [1.265, 3.795]
|
||||
6 3.795 [1.897, 5.692]
|
||||
7 5.692 [2.846, 8.538]
|
||||
8 8.538 [4.269, 12.807]
|
||||
9 12.807 [6.403, 19.210]
|
||||
10 19.210 backoff.Stop
|
||||
|
||||
Note: Implementation is not thread-safe.
|
||||
*/
|
||||
type ExponentialBackOff struct {
|
||||
InitialInterval time.Duration
|
||||
RandomizationFactor float64
|
||||
Multiplier float64
|
||||
MaxInterval time.Duration
|
||||
// After MaxElapsedTime the ExponentialBackOff stops.
|
||||
// It never stops if MaxElapsedTime == 0.
|
||||
MaxElapsedTime time.Duration
|
||||
Clock Clock
|
||||
|
||||
currentInterval time.Duration
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// Clock is an interface that returns current time for BackOff.
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
// Default values for ExponentialBackOff.
|
||||
const (
|
||||
DefaultInitialInterval = 500 * time.Millisecond
|
||||
DefaultRandomizationFactor = 0.5
|
||||
DefaultMultiplier = 1.5
|
||||
DefaultMaxInterval = 60 * time.Second
|
||||
DefaultMaxElapsedTime = 15 * time.Minute
|
||||
)
|
||||
|
||||
// NewExponentialBackOff creates an instance of ExponentialBackOff using default values.
|
||||
func NewExponentialBackOff() *ExponentialBackOff {
|
||||
b := &ExponentialBackOff{
|
||||
InitialInterval: DefaultInitialInterval,
|
||||
RandomizationFactor: DefaultRandomizationFactor,
|
||||
Multiplier: DefaultMultiplier,
|
||||
MaxInterval: DefaultMaxInterval,
|
||||
MaxElapsedTime: DefaultMaxElapsedTime,
|
||||
Clock: SystemClock,
|
||||
}
|
||||
if b.RandomizationFactor < 0 {
|
||||
b.RandomizationFactor = 0
|
||||
} else if b.RandomizationFactor > 1 {
|
||||
b.RandomizationFactor = 1
|
||||
}
|
||||
b.Reset()
|
||||
return b
|
||||
}
|
||||
|
||||
type systemClock struct{}
|
||||
|
||||
func (t systemClock) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// SystemClock implements Clock interface that uses time.Now().
|
||||
var SystemClock = systemClock{}
|
||||
|
||||
// Reset the interval back to the initial retry interval and restarts the timer.
|
||||
func (b *ExponentialBackOff) Reset() {
|
||||
b.currentInterval = b.InitialInterval
|
||||
b.startTime = b.Clock.Now()
|
||||
}
|
||||
|
||||
// NextBackOff calculates the next backoff interval using the formula:
|
||||
// Randomized interval = RetryInterval +/- (RandomizationFactor * RetryInterval)
|
||||
func (b *ExponentialBackOff) NextBackOff() time.Duration {
|
||||
// Make sure we have not gone over the maximum elapsed time.
|
||||
if b.MaxElapsedTime != 0 && b.GetElapsedTime() > b.MaxElapsedTime {
|
||||
return Stop
|
||||
}
|
||||
defer b.incrementCurrentInterval()
|
||||
return getRandomValueFromInterval(b.RandomizationFactor, rand.Float64(), b.currentInterval)
|
||||
}
|
||||
|
||||
// GetElapsedTime returns the elapsed time since an ExponentialBackOff instance
|
||||
// is created and is reset when Reset() is called.
|
||||
//
|
||||
// The elapsed time is computed using time.Now().UnixNano().
|
||||
func (b *ExponentialBackOff) GetElapsedTime() time.Duration {
|
||||
return b.Clock.Now().Sub(b.startTime)
|
||||
}
|
||||
|
||||
// Increments the current interval by multiplying it with the multiplier.
|
||||
func (b *ExponentialBackOff) incrementCurrentInterval() {
|
||||
// Check for overflow, if overflow is detected set the current interval to the max interval.
|
||||
if float64(b.currentInterval) >= float64(b.MaxInterval)/b.Multiplier {
|
||||
b.currentInterval = b.MaxInterval
|
||||
} else {
|
||||
b.currentInterval = time.Duration(float64(b.currentInterval) * b.Multiplier)
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a random value from the following interval:
|
||||
// [randomizationFactor * currentInterval, randomizationFactor * currentInterval].
|
||||
func getRandomValueFromInterval(randomizationFactor, random float64, currentInterval time.Duration) time.Duration {
|
||||
var delta = randomizationFactor * float64(currentInterval)
|
||||
var minInterval = float64(currentInterval) - delta
|
||||
var maxInterval = float64(currentInterval) + delta
|
||||
|
||||
// Get a random value from the range [minInterval, maxInterval].
|
||||
// The formula used below has a +1 because if the minInterval is 1 and the maxInterval is 3 then
|
||||
// we want a 33% chance for selecting either 1, 2 or 3.
|
||||
return time.Duration(minInterval + (random * (maxInterval - minInterval + 1)))
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package backoff
|
||||
|
||||
import "time"
|
||||
|
||||
// An Operation is executing by Retry() or RetryNotify().
|
||||
// The operation will be retried using a backoff policy if it returns an error.
|
||||
type Operation func() error
|
||||
|
||||
// Notify is a notify-on-error function. It receives an operation error and
|
||||
// backoff delay if the operation failed (with an error).
|
||||
//
|
||||
// NOTE that if the backoff policy stated to stop retrying,
|
||||
// the notify function isn't called.
|
||||
type Notify func(error, time.Duration)
|
||||
|
||||
// Retry the operation o until it does not return error or BackOff stops.
|
||||
// o is guaranteed to be run at least once.
|
||||
// It is the caller's responsibility to reset b after Retry returns.
|
||||
//
|
||||
// Retry sleeps the goroutine for the duration returned by BackOff after a
|
||||
// failed operation returns.
|
||||
func Retry(o Operation, b BackOff) error { return RetryNotify(o, b, nil) }
|
||||
|
||||
// RetryNotify calls notify function with the error and wait duration
|
||||
// for each failed attempt before sleep.
|
||||
func RetryNotify(operation Operation, b BackOff, notify Notify) error {
|
||||
var err error
|
||||
var next time.Duration
|
||||
|
||||
b.Reset()
|
||||
for {
|
||||
if err = operation(); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if next = b.NextBackOff(); next == Stop {
|
||||
return err
|
||||
}
|
||||
|
||||
if notify != nil {
|
||||
notify(err, next)
|
||||
}
|
||||
|
||||
time.Sleep(next)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
package backoff
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Ticker holds a channel that delivers `ticks' of a clock at times reported by a BackOff.
|
||||
//
|
||||
// Ticks will continue to arrive when the previous operation is still running,
|
||||
// so operations that take a while to fail could run in quick succession.
|
||||
type Ticker struct {
|
||||
C <-chan time.Time
|
||||
c chan time.Time
|
||||
b BackOff
|
||||
stop chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// NewTicker returns a new Ticker containing a channel that will send the time at times
|
||||
// specified by the BackOff argument. Ticker is guaranteed to tick at least once.
|
||||
// The channel is closed when Stop method is called or BackOff stops.
|
||||
func NewTicker(b BackOff) *Ticker {
|
||||
c := make(chan time.Time)
|
||||
t := &Ticker{
|
||||
C: c,
|
||||
c: c,
|
||||
b: b,
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
go t.run()
|
||||
runtime.SetFinalizer(t, (*Ticker).Stop)
|
||||
return t
|
||||
}
|
||||
|
||||
// Stop turns off a ticker. After Stop, no more ticks will be sent.
|
||||
func (t *Ticker) Stop() {
|
||||
t.stopOnce.Do(func() { close(t.stop) })
|
||||
}
|
||||
|
||||
func (t *Ticker) run() {
|
||||
c := t.c
|
||||
defer close(c)
|
||||
t.b.Reset()
|
||||
|
||||
// Ticker is guaranteed to tick at least once.
|
||||
afterC := t.send(time.Now())
|
||||
|
||||
for {
|
||||
if afterC == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case tick := <-afterC:
|
||||
afterC = t.send(tick)
|
||||
case <-t.stop:
|
||||
t.c = nil // Prevent future ticks from being sent to the channel.
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Ticker) send(tick time.Time) <-chan time.Time {
|
||||
select {
|
||||
case t.c <- tick:
|
||||
case <-t.stop:
|
||||
return nil
|
||||
}
|
||||
|
||||
next := t.b.NextBackOff()
|
||||
if next == Stop {
|
||||
t.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
return time.After(next)
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
// Package flagenv provides the ability to populate flags from
|
||||
// environment variables.
|
||||
package flagenv
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Specify a prefix for environment variables.
|
||||
var Prefix = ""
|
||||
|
||||
func contains(list []*flag.Flag, f *flag.Flag) bool {
|
||||
for _, i := range list {
|
||||
if i == f {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ParseSet parses the given flagset. The specified prefix will be applied to
|
||||
// the environment variable names.
|
||||
func ParseSet(prefix string, set *flag.FlagSet) error {
|
||||
var explicit []*flag.Flag
|
||||
var all []*flag.Flag
|
||||
set.Visit(func(f *flag.Flag) {
|
||||
explicit = append(explicit, f)
|
||||
})
|
||||
|
||||
var err error
|
||||
set.VisitAll(func(f *flag.Flag) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
all = append(all, f)
|
||||
if !contains(explicit, f) {
|
||||
name := strings.Replace(f.Name, ".", "_", -1)
|
||||
name = strings.Replace(name, "-", "_", -1)
|
||||
if prefix != "" {
|
||||
name = prefix + name
|
||||
}
|
||||
name = strings.ToUpper(name)
|
||||
val := os.Getenv(name)
|
||||
if val != "" {
|
||||
if ferr := f.Value.Set(val); ferr != nil {
|
||||
err = fmt.Errorf("failed to set flag %q with value %q", f.Name, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse will set each defined flag from its corresponding environment
|
||||
// variable . If dots or dash are presents in the flag name, they will be
|
||||
// converted to underscores.
|
||||
//
|
||||
// If Parse fails, a fatal error is issued.
|
||||
func Parse() {
|
||||
if err := ParseSet(Prefix, flag.CommandLine); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,229 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// Protocol buffer deep copy and merge.
|
||||
// TODO: RawMessage.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"log"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Clone returns a deep copy of a protocol buffer.
|
||||
func Clone(pb Message) Message {
|
||||
in := reflect.ValueOf(pb)
|
||||
if in.IsNil() {
|
||||
return pb
|
||||
}
|
||||
|
||||
out := reflect.New(in.Type().Elem())
|
||||
// out is empty so a merge is a deep copy.
|
||||
mergeStruct(out.Elem(), in.Elem())
|
||||
return out.Interface().(Message)
|
||||
}
|
||||
|
||||
// Merge merges src into dst.
|
||||
// Required and optional fields that are set in src will be set to that value in dst.
|
||||
// Elements of repeated fields will be appended.
|
||||
// Merge panics if src and dst are not the same type, or if dst is nil.
|
||||
func Merge(dst, src Message) {
|
||||
in := reflect.ValueOf(src)
|
||||
out := reflect.ValueOf(dst)
|
||||
if out.IsNil() {
|
||||
panic("proto: nil destination")
|
||||
}
|
||||
if in.Type() != out.Type() {
|
||||
// Explicit test prior to mergeStruct so that mistyped nils will fail
|
||||
panic("proto: type mismatch")
|
||||
}
|
||||
if in.IsNil() {
|
||||
// Merging nil into non-nil is a quiet no-op
|
||||
return
|
||||
}
|
||||
mergeStruct(out.Elem(), in.Elem())
|
||||
}
|
||||
|
||||
func mergeStruct(out, in reflect.Value) {
|
||||
sprop := GetProperties(in.Type())
|
||||
for i := 0; i < in.NumField(); i++ {
|
||||
f := in.Type().Field(i)
|
||||
if strings.HasPrefix(f.Name, "XXX_") {
|
||||
continue
|
||||
}
|
||||
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
|
||||
}
|
||||
|
||||
if emIn, ok := extendable(in.Addr().Interface()); ok {
|
||||
emOut, _ := extendable(out.Addr().Interface())
|
||||
mIn, muIn := emIn.extensionsRead()
|
||||
if mIn != nil {
|
||||
mOut := emOut.extensionsWrite()
|
||||
muIn.Lock()
|
||||
mergeExtension(mOut, mIn)
|
||||
muIn.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
uf := in.FieldByName("XXX_unrecognized")
|
||||
if !uf.IsValid() {
|
||||
return
|
||||
}
|
||||
uin := uf.Bytes()
|
||||
if len(uin) > 0 {
|
||||
out.FieldByName("XXX_unrecognized").SetBytes(append([]byte(nil), uin...))
|
||||
}
|
||||
}
|
||||
|
||||
// mergeAny performs a merge between two values of the same type.
|
||||
// viaPtr indicates whether the values were indirected through a pointer (implying proto2).
|
||||
// prop is set if this is a struct field (it may be nil).
|
||||
func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) {
|
||||
if in.Type() == protoMessageType {
|
||||
if !in.IsNil() {
|
||||
if out.IsNil() {
|
||||
out.Set(reflect.ValueOf(Clone(in.Interface().(Message))))
|
||||
} else {
|
||||
Merge(out.Interface().(Message), in.Interface().(Message))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
switch in.Kind() {
|
||||
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
|
||||
reflect.String, reflect.Uint32, reflect.Uint64:
|
||||
if !viaPtr && isProto3Zero(in) {
|
||||
return
|
||||
}
|
||||
out.Set(in)
|
||||
case reflect.Interface:
|
||||
// Probably a oneof field; copy non-nil values.
|
||||
if in.IsNil() {
|
||||
return
|
||||
}
|
||||
// Allocate destination if it is not set, or set to a different type.
|
||||
// Otherwise we will merge as normal.
|
||||
if out.IsNil() || out.Elem().Type() != in.Elem().Type() {
|
||||
out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T)
|
||||
}
|
||||
mergeAny(out.Elem(), in.Elem(), false, nil)
|
||||
case reflect.Map:
|
||||
if in.Len() == 0 {
|
||||
return
|
||||
}
|
||||
if out.IsNil() {
|
||||
out.Set(reflect.MakeMap(in.Type()))
|
||||
}
|
||||
// For maps with value types of *T or []byte we need to deep copy each value.
|
||||
elemKind := in.Type().Elem().Kind()
|
||||
for _, key := range in.MapKeys() {
|
||||
var val reflect.Value
|
||||
switch elemKind {
|
||||
case reflect.Ptr:
|
||||
val = reflect.New(in.Type().Elem().Elem())
|
||||
mergeAny(val, in.MapIndex(key), false, nil)
|
||||
case reflect.Slice:
|
||||
val = in.MapIndex(key)
|
||||
val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
|
||||
default:
|
||||
val = in.MapIndex(key)
|
||||
}
|
||||
out.SetMapIndex(key, val)
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if in.IsNil() {
|
||||
return
|
||||
}
|
||||
if out.IsNil() {
|
||||
out.Set(reflect.New(in.Elem().Type()))
|
||||
}
|
||||
mergeAny(out.Elem(), in.Elem(), true, nil)
|
||||
case reflect.Slice:
|
||||
if in.IsNil() {
|
||||
return
|
||||
}
|
||||
if in.Type().Elem().Kind() == reflect.Uint8 {
|
||||
// []byte is a scalar bytes field, not a repeated field.
|
||||
|
||||
// Edge case: if this is in a proto3 message, a zero length
|
||||
// bytes field is considered the zero value, and should not
|
||||
// be merged.
|
||||
if prop != nil && prop.proto3 && in.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Make a deep copy.
|
||||
// Append to []byte{} instead of []byte(nil) so that we never end up
|
||||
// with a nil result.
|
||||
out.SetBytes(append([]byte{}, in.Bytes()...))
|
||||
return
|
||||
}
|
||||
n := in.Len()
|
||||
if out.IsNil() {
|
||||
out.Set(reflect.MakeSlice(in.Type(), 0, n))
|
||||
}
|
||||
switch in.Type().Elem().Kind() {
|
||||
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
|
||||
reflect.String, reflect.Uint32, reflect.Uint64:
|
||||
out.Set(reflect.AppendSlice(out, in))
|
||||
default:
|
||||
for i := 0; i < n; i++ {
|
||||
x := reflect.Indirect(reflect.New(in.Type().Elem()))
|
||||
mergeAny(x, in.Index(i), false, nil)
|
||||
out.Set(reflect.Append(out, x))
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
mergeStruct(out, in)
|
||||
default:
|
||||
// unknown type, so not a protocol buffer
|
||||
log.Printf("proto: don't know how to copy %v", in)
|
||||
}
|
||||
}
|
||||
|
||||
func mergeExtension(out, in map[int32]Extension) {
|
||||
for extNum, eIn := range in {
|
||||
eOut := Extension{desc: eIn.desc}
|
||||
if eIn.value != nil {
|
||||
v := reflect.New(reflect.TypeOf(eIn.value)).Elem()
|
||||
mergeAny(v, reflect.ValueOf(eIn.value), false, nil)
|
||||
eOut.value = v.Interface()
|
||||
}
|
||||
if eIn.enc != nil {
|
||||
eOut.enc = make([]byte, len(eIn.enc))
|
||||
copy(eOut.enc, eIn.enc)
|
||||
}
|
||||
|
||||
out[extNum] = eOut
|
||||
}
|
||||
}
|
|
@ -0,0 +1,970 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
/*
|
||||
* Routines for decoding protocol buffer data to construct in-memory representations.
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// errOverflow is returned when an integer is too large to be represented.
|
||||
var errOverflow = errors.New("proto: integer overflow")
|
||||
|
||||
// ErrInternalBadWireType is returned by generated code when an incorrect
|
||||
// wire type is encountered. It does not get returned to user code.
|
||||
var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof")
|
||||
|
||||
// The fundamental decoders that interpret bytes on the wire.
|
||||
// Those that take integer types all return uint64 and are
|
||||
// therefore of type valueDecoder.
|
||||
|
||||
// DecodeVarint reads a varint-encoded integer from the slice.
|
||||
// It returns the integer and the number of bytes consumed, or
|
||||
// zero if there is not enough.
|
||||
// This is the format for the
|
||||
// int32, int64, uint32, uint64, bool, and enum
|
||||
// protocol buffer types.
|
||||
func DecodeVarint(buf []byte) (x uint64, n int) {
|
||||
for shift := uint(0); shift < 64; shift += 7 {
|
||||
if n >= len(buf) {
|
||||
return 0, 0
|
||||
}
|
||||
b := uint64(buf[n])
|
||||
n++
|
||||
x |= (b & 0x7F) << shift
|
||||
if (b & 0x80) == 0 {
|
||||
return x, n
|
||||
}
|
||||
}
|
||||
|
||||
// The number is too large to represent in a 64-bit value.
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func (p *Buffer) decodeVarintSlow() (x uint64, err error) {
|
||||
i := p.index
|
||||
l := len(p.buf)
|
||||
|
||||
for shift := uint(0); shift < 64; shift += 7 {
|
||||
if i >= l {
|
||||
err = io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
b := p.buf[i]
|
||||
i++
|
||||
x |= (uint64(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
p.index = i
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// The number is too large to represent in a 64-bit value.
|
||||
err = errOverflow
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeVarint reads a varint-encoded integer from the Buffer.
|
||||
// This is the format for the
|
||||
// int32, int64, uint32, uint64, bool, and enum
|
||||
// protocol buffer types.
|
||||
func (p *Buffer) DecodeVarint() (x uint64, err error) {
|
||||
i := p.index
|
||||
buf := p.buf
|
||||
|
||||
if i >= len(buf) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
} else if buf[i] < 0x80 {
|
||||
p.index++
|
||||
return uint64(buf[i]), nil
|
||||
} else if len(buf)-i < 10 {
|
||||
return p.decodeVarintSlow()
|
||||
}
|
||||
|
||||
var b uint64
|
||||
// we already checked the first byte
|
||||
x = uint64(buf[i]) - 0x80
|
||||
i++
|
||||
|
||||
b = uint64(buf[i])
|
||||
i++
|
||||
x += b << 7
|
||||
if b&0x80 == 0 {
|
||||
goto done
|
||||
}
|
||||
x -= 0x80 << 7
|
||||
|
||||
b = uint64(buf[i])
|
||||
i++
|
||||
x += b << 14
|
||||
if b&0x80 == 0 {
|
||||
goto done
|
||||
}
|
||||
x -= 0x80 << 14
|
||||
|
||||
b = uint64(buf[i])
|
||||
i++
|
||||
x += b << 21
|
||||
if b&0x80 == 0 {
|
||||
goto done
|
||||
}
|
||||
x -= 0x80 << 21
|
||||
|
||||
b = uint64(buf[i])
|
||||
i++
|
||||
x += b << 28
|
||||
if b&0x80 == 0 {
|
||||
goto done
|
||||
}
|
||||
x -= 0x80 << 28
|
||||
|
||||
b = uint64(buf[i])
|
||||
i++
|
||||
x += b << 35
|
||||
if b&0x80 == 0 {
|
||||
goto done
|
||||
}
|
||||
x -= 0x80 << 35
|
||||
|
||||
b = uint64(buf[i])
|
||||
i++
|
||||
x += b << 42
|
||||
if b&0x80 == 0 {
|
||||
goto done
|
||||
}
|
||||
x -= 0x80 << 42
|
||||
|
||||
b = uint64(buf[i])
|
||||
i++
|
||||
x += b << 49
|
||||
if b&0x80 == 0 {
|
||||
goto done
|
||||
}
|
||||
x -= 0x80 << 49
|
||||
|
||||
b = uint64(buf[i])
|
||||
i++
|
||||
x += b << 56
|
||||
if b&0x80 == 0 {
|
||||
goto done
|
||||
}
|
||||
x -= 0x80 << 56
|
||||
|
||||
b = uint64(buf[i])
|
||||
i++
|
||||
x += b << 63
|
||||
if b&0x80 == 0 {
|
||||
goto done
|
||||
}
|
||||
// x -= 0x80 << 63 // Always zero.
|
||||
|
||||
return 0, errOverflow
|
||||
|
||||
done:
|
||||
p.index = i
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// DecodeFixed64 reads a 64-bit integer from the Buffer.
|
||||
// This is the format for the
|
||||
// fixed64, sfixed64, and double protocol buffer types.
|
||||
func (p *Buffer) DecodeFixed64() (x uint64, err error) {
|
||||
// x, err already 0
|
||||
i := p.index + 8
|
||||
if i < 0 || i > len(p.buf) {
|
||||
err = io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
p.index = i
|
||||
|
||||
x = uint64(p.buf[i-8])
|
||||
x |= uint64(p.buf[i-7]) << 8
|
||||
x |= uint64(p.buf[i-6]) << 16
|
||||
x |= uint64(p.buf[i-5]) << 24
|
||||
x |= uint64(p.buf[i-4]) << 32
|
||||
x |= uint64(p.buf[i-3]) << 40
|
||||
x |= uint64(p.buf[i-2]) << 48
|
||||
x |= uint64(p.buf[i-1]) << 56
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeFixed32 reads a 32-bit integer from the Buffer.
|
||||
// This is the format for the
|
||||
// fixed32, sfixed32, and float protocol buffer types.
|
||||
func (p *Buffer) DecodeFixed32() (x uint64, err error) {
|
||||
// x, err already 0
|
||||
i := p.index + 4
|
||||
if i < 0 || i > len(p.buf) {
|
||||
err = io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
p.index = i
|
||||
|
||||
x = uint64(p.buf[i-4])
|
||||
x |= uint64(p.buf[i-3]) << 8
|
||||
x |= uint64(p.buf[i-2]) << 16
|
||||
x |= uint64(p.buf[i-1]) << 24
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeZigzag64 reads a zigzag-encoded 64-bit integer
|
||||
// from the Buffer.
|
||||
// This is the format used for the sint64 protocol buffer type.
|
||||
func (p *Buffer) DecodeZigzag64() (x uint64, err error) {
|
||||
x, err = p.DecodeVarint()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
x = (x >> 1) ^ uint64((int64(x&1)<<63)>>63)
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeZigzag32 reads a zigzag-encoded 32-bit integer
|
||||
// from the Buffer.
|
||||
// This is the format used for the sint32 protocol buffer type.
|
||||
func (p *Buffer) DecodeZigzag32() (x uint64, err error) {
|
||||
x, err = p.DecodeVarint()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
x = uint64((uint32(x) >> 1) ^ uint32((int32(x&1)<<31)>>31))
|
||||
return
|
||||
}
|
||||
|
||||
// These are not ValueDecoders: they produce an array of bytes or a string.
|
||||
// bytes, embedded messages
|
||||
|
||||
// DecodeRawBytes reads a count-delimited byte buffer from the Buffer.
|
||||
// This is the format used for the bytes protocol buffer
|
||||
// type and for embedded messages.
|
||||
func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
|
||||
n, err := p.DecodeVarint()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nb := int(n)
|
||||
if nb < 0 {
|
||||
return nil, fmt.Errorf("proto: bad byte length %d", nb)
|
||||
}
|
||||
end := p.index + nb
|
||||
if end < p.index || end > len(p.buf) {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
if !alloc {
|
||||
// todo: check if can get more uses of alloc=false
|
||||
buf = p.buf[p.index:end]
|
||||
p.index += nb
|
||||
return
|
||||
}
|
||||
|
||||
buf = make([]byte, nb)
|
||||
copy(buf, p.buf[p.index:])
|
||||
p.index += nb
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeStringBytes reads an encoded string from the Buffer.
|
||||
// This is the format used for the proto2 string type.
|
||||
func (p *Buffer) DecodeStringBytes() (s string, err error) {
|
||||
buf, err := p.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
|
||||
// If the protocol buffer has extensions, and the field matches, add it as an extension.
|
||||
// Otherwise, if the XXX_unrecognized field exists, append the skipped data there.
|
||||
func (o *Buffer) skipAndSave(t reflect.Type, tag, wire int, base structPointer, unrecField field) error {
|
||||
oi := o.index
|
||||
|
||||
err := o.skip(t, tag, wire)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !unrecField.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ptr := structPointer_Bytes(base, unrecField)
|
||||
|
||||
// Add the skipped field to struct field
|
||||
obuf := o.buf
|
||||
|
||||
o.buf = *ptr
|
||||
o.EncodeVarint(uint64(tag<<3 | wire))
|
||||
*ptr = append(o.buf, obuf[oi:o.index]...)
|
||||
|
||||
o.buf = obuf
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
|
||||
func (o *Buffer) skip(t reflect.Type, tag, wire int) error {
|
||||
|
||||
var u uint64
|
||||
var err error
|
||||
|
||||
switch wire {
|
||||
case WireVarint:
|
||||
_, err = o.DecodeVarint()
|
||||
case WireFixed64:
|
||||
_, err = o.DecodeFixed64()
|
||||
case WireBytes:
|
||||
_, err = o.DecodeRawBytes(false)
|
||||
case WireFixed32:
|
||||
_, err = o.DecodeFixed32()
|
||||
case WireStartGroup:
|
||||
for {
|
||||
u, err = o.DecodeVarint()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fwire := int(u & 0x7)
|
||||
if fwire == WireEndGroup {
|
||||
break
|
||||
}
|
||||
ftag := int(u >> 3)
|
||||
err = o.skip(t, ftag, fwire)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("proto: can't skip unknown wire type %d for %s", wire, t)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Unmarshaler is the interface representing objects that can
|
||||
// unmarshal themselves. The method should reset the receiver before
|
||||
// decoding starts. The argument points to data that may be
|
||||
// overwritten, so implementations should not keep references to the
|
||||
// buffer.
|
||||
type Unmarshaler interface {
|
||||
Unmarshal([]byte) error
|
||||
}
|
||||
|
||||
// Unmarshal parses the protocol buffer representation in buf and places the
|
||||
// decoded result in pb. If the struct underlying pb does not match
|
||||
// the data in buf, the results can be unpredictable.
|
||||
//
|
||||
// Unmarshal resets pb before starting to unmarshal, so any
|
||||
// existing data in pb is always removed. Use UnmarshalMerge
|
||||
// to preserve and append to existing data.
|
||||
func Unmarshal(buf []byte, pb Message) error {
|
||||
pb.Reset()
|
||||
return UnmarshalMerge(buf, pb)
|
||||
}
|
||||
|
||||
// UnmarshalMerge parses the protocol buffer representation in buf and
|
||||
// writes the decoded result to pb. If the struct underlying pb does not match
|
||||
// the data in buf, the results can be unpredictable.
|
||||
//
|
||||
// UnmarshalMerge merges into existing data in pb.
|
||||
// Most code should use Unmarshal instead.
|
||||
func UnmarshalMerge(buf []byte, pb Message) error {
|
||||
// If the object can unmarshal itself, let it.
|
||||
if u, ok := pb.(Unmarshaler); ok {
|
||||
return u.Unmarshal(buf)
|
||||
}
|
||||
return NewBuffer(buf).Unmarshal(pb)
|
||||
}
|
||||
|
||||
// DecodeMessage reads a count-delimited message from the Buffer.
|
||||
func (p *Buffer) DecodeMessage(pb Message) error {
|
||||
enc, err := p.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return NewBuffer(enc).Unmarshal(pb)
|
||||
}
|
||||
|
||||
// DecodeGroup reads a tag-delimited group from the Buffer.
|
||||
func (p *Buffer) DecodeGroup(pb Message) error {
|
||||
typ, base, err := getbase(pb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base)
|
||||
}
|
||||
|
||||
// Unmarshal parses the protocol buffer representation in the
|
||||
// Buffer and places the decoded result in pb. If the struct
|
||||
// underlying pb does not match the data in the buffer, the results can be
|
||||
// unpredictable.
|
||||
//
|
||||
// Unlike proto.Unmarshal, this does not reset pb before starting to unmarshal.
|
||||
func (p *Buffer) Unmarshal(pb Message) error {
|
||||
// If the object can unmarshal itself, let it.
|
||||
if u, ok := pb.(Unmarshaler); ok {
|
||||
err := u.Unmarshal(p.buf[p.index:])
|
||||
p.index = len(p.buf)
|
||||
return err
|
||||
}
|
||||
|
||||
typ, base, err := getbase(pb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), false, base)
|
||||
|
||||
if collectStats {
|
||||
stats.Decode++
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// unmarshalType does the work of unmarshaling a structure.
|
||||
func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error {
|
||||
var state errorState
|
||||
required, reqFields := prop.reqCount, uint64(0)
|
||||
|
||||
var err error
|
||||
for err == nil && o.index < len(o.buf) {
|
||||
oi := o.index
|
||||
var u uint64
|
||||
u, err = o.DecodeVarint()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
wire := int(u & 0x7)
|
||||
if wire == WireEndGroup {
|
||||
if is_group {
|
||||
if required > 0 {
|
||||
// Not enough information to determine the exact field.
|
||||
// (See below.)
|
||||
return &RequiredNotSetError{"{Unknown}"}
|
||||
}
|
||||
return nil // input is satisfied
|
||||
}
|
||||
return fmt.Errorf("proto: %s: wiretype end group for non-group", st)
|
||||
}
|
||||
tag := int(u >> 3)
|
||||
if tag <= 0 {
|
||||
return fmt.Errorf("proto: %s: illegal tag %d (wire type %d)", st, tag, wire)
|
||||
}
|
||||
fieldnum, ok := prop.decoderTags.get(tag)
|
||||
if !ok {
|
||||
// Maybe it's an extension?
|
||||
if prop.extendable {
|
||||
if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) {
|
||||
if err = o.skip(st, tag, wire); err == nil {
|
||||
extmap := e.extensionsWrite()
|
||||
ext := extmap[int32(tag)] // may be missing
|
||||
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
|
||||
extmap[int32(tag)] = ext
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Maybe it's a oneof?
|
||||
if prop.oneofUnmarshaler != nil {
|
||||
m := structPointer_Interface(base, st).(Message)
|
||||
// First return value indicates whether tag is a oneof field.
|
||||
ok, err = prop.oneofUnmarshaler(m, tag, wire, o)
|
||||
if err == ErrInternalBadWireType {
|
||||
// Map the error to something more descriptive.
|
||||
// Do the formatting here to save generated code space.
|
||||
err = fmt.Errorf("bad wiretype for oneof field in %T", m)
|
||||
}
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
|
||||
continue
|
||||
}
|
||||
p := prop.Prop[fieldnum]
|
||||
|
||||
if p.dec == nil {
|
||||
fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", st, st.Field(fieldnum).Name)
|
||||
continue
|
||||
}
|
||||
dec := p.dec
|
||||
if wire != WireStartGroup && wire != p.WireType {
|
||||
if wire == WireBytes && p.packedDec != nil {
|
||||
// a packable field
|
||||
dec = p.packedDec
|
||||
} else {
|
||||
err = fmt.Errorf("proto: bad wiretype for field %s.%s: got wiretype %d, want %d", st, st.Field(fieldnum).Name, wire, p.WireType)
|
||||
continue
|
||||
}
|
||||
}
|
||||
decErr := dec(o, p, base)
|
||||
if decErr != nil && !state.shouldContinue(decErr, p) {
|
||||
err = decErr
|
||||
}
|
||||
if err == nil && p.Required {
|
||||
// Successfully decoded a required field.
|
||||
if tag <= 64 {
|
||||
// use bitmap for fields 1-64 to catch field reuse.
|
||||
var mask uint64 = 1 << uint64(tag-1)
|
||||
if reqFields&mask == 0 {
|
||||
// new required field
|
||||
reqFields |= mask
|
||||
required--
|
||||
}
|
||||
} else {
|
||||
// This is imprecise. It can be fooled by a required field
|
||||
// with a tag > 64 that is encoded twice; that's very rare.
|
||||
// A fully correct implementation would require allocating
|
||||
// a data structure, which we would like to avoid.
|
||||
required--
|
||||
}
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
if is_group {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
if state.err != nil {
|
||||
return state.err
|
||||
}
|
||||
if required > 0 {
|
||||
// Not enough information to determine the exact field. If we use extra
|
||||
// CPU, we could determine the field only if the missing required field
|
||||
// has a tag <= 64 and we check reqFields.
|
||||
return &RequiredNotSetError{"{Unknown}"}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Individual type decoders
|
||||
// For each,
|
||||
// u is the decoded value,
|
||||
// v is a pointer to the field (pointer) in the struct
|
||||
|
||||
// Sizes of the pools to allocate inside the Buffer.
|
||||
// The goal is modest amortization and allocation
|
||||
// on at least 16-byte boundaries.
|
||||
const (
|
||||
boolPoolSize = 16
|
||||
uint32PoolSize = 8
|
||||
uint64PoolSize = 4
|
||||
)
|
||||
|
||||
// Decode a bool.
|
||||
func (o *Buffer) dec_bool(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(o.bools) == 0 {
|
||||
o.bools = make([]bool, boolPoolSize)
|
||||
}
|
||||
o.bools[0] = u != 0
|
||||
*structPointer_Bool(base, p.field) = &o.bools[0]
|
||||
o.bools = o.bools[1:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Buffer) dec_proto3_bool(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*structPointer_BoolVal(base, p.field) = u != 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode an int32.
|
||||
func (o *Buffer) dec_int32(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
word32_Set(structPointer_Word32(base, p.field), o, uint32(u))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Buffer) dec_proto3_int32(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
word32Val_Set(structPointer_Word32Val(base, p.field), uint32(u))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode an int64.
|
||||
func (o *Buffer) dec_int64(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
word64_Set(structPointer_Word64(base, p.field), o, u)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Buffer) dec_proto3_int64(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
word64Val_Set(structPointer_Word64Val(base, p.field), o, u)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a string.
|
||||
func (o *Buffer) dec_string(p *Properties, base structPointer) error {
|
||||
s, err := o.DecodeStringBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*structPointer_String(base, p.field) = &s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Buffer) dec_proto3_string(p *Properties, base structPointer) error {
|
||||
s, err := o.DecodeStringBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*structPointer_StringVal(base, p.field) = s
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of bytes ([]byte).
|
||||
func (o *Buffer) dec_slice_byte(p *Properties, base structPointer) error {
|
||||
b, err := o.DecodeRawBytes(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*structPointer_Bytes(base, p.field) = b
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of bools ([]bool).
|
||||
func (o *Buffer) dec_slice_bool(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := structPointer_BoolSlice(base, p.field)
|
||||
*v = append(*v, u != 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of bools ([]bool) in packed format.
|
||||
func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error {
|
||||
v := structPointer_BoolSlice(base, p.field)
|
||||
|
||||
nn, err := o.DecodeVarint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nb := int(nn) // number of bytes of encoded bools
|
||||
fin := o.index + nb
|
||||
if fin < o.index {
|
||||
return errOverflow
|
||||
}
|
||||
|
||||
y := *v
|
||||
for o.index < fin {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
y = append(y, u != 0)
|
||||
}
|
||||
|
||||
*v = y
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of int32s ([]int32).
|
||||
func (o *Buffer) dec_slice_int32(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
structPointer_Word32Slice(base, p.field).Append(uint32(u))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of int32s ([]int32) in packed format.
|
||||
func (o *Buffer) dec_slice_packed_int32(p *Properties, base structPointer) error {
|
||||
v := structPointer_Word32Slice(base, p.field)
|
||||
|
||||
nn, err := o.DecodeVarint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nb := int(nn) // number of bytes of encoded int32s
|
||||
|
||||
fin := o.index + nb
|
||||
if fin < o.index {
|
||||
return errOverflow
|
||||
}
|
||||
for o.index < fin {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Append(uint32(u))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of int64s ([]int64).
|
||||
func (o *Buffer) dec_slice_int64(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
structPointer_Word64Slice(base, p.field).Append(u)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of int64s ([]int64) in packed format.
|
||||
func (o *Buffer) dec_slice_packed_int64(p *Properties, base structPointer) error {
|
||||
v := structPointer_Word64Slice(base, p.field)
|
||||
|
||||
nn, err := o.DecodeVarint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nb := int(nn) // number of bytes of encoded int64s
|
||||
|
||||
fin := o.index + nb
|
||||
if fin < o.index {
|
||||
return errOverflow
|
||||
}
|
||||
for o.index < fin {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Append(u)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of strings ([]string).
|
||||
func (o *Buffer) dec_slice_string(p *Properties, base structPointer) error {
|
||||
s, err := o.DecodeStringBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := structPointer_StringSlice(base, p.field)
|
||||
*v = append(*v, s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of slice of bytes ([][]byte).
|
||||
func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error {
|
||||
b, err := o.DecodeRawBytes(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := structPointer_BytesSlice(base, p.field)
|
||||
*v = append(*v, b)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a map field.
|
||||
func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
|
||||
raw, err := o.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oi := o.index // index at the end of this map entry
|
||||
o.index -= len(raw) // move buffer back to start of map entry
|
||||
|
||||
mptr := structPointer_NewAt(base, p.field, p.mtype) // *map[K]V
|
||||
if mptr.Elem().IsNil() {
|
||||
mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem()))
|
||||
}
|
||||
v := mptr.Elem() // map[K]V
|
||||
|
||||
// Prepare addressable doubly-indirect placeholders for the key and value types.
|
||||
// See enc_new_map for why.
|
||||
keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K
|
||||
keybase := toStructPointer(keyptr.Addr()) // **K
|
||||
|
||||
var valbase structPointer
|
||||
var valptr reflect.Value
|
||||
switch p.mtype.Elem().Kind() {
|
||||
case reflect.Slice:
|
||||
// []byte
|
||||
var dummy []byte
|
||||
valptr = reflect.ValueOf(&dummy) // *[]byte
|
||||
valbase = toStructPointer(valptr) // *[]byte
|
||||
case reflect.Ptr:
|
||||
// message; valptr is **Msg; need to allocate the intermediate pointer
|
||||
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
|
||||
valptr.Set(reflect.New(valptr.Type().Elem()))
|
||||
valbase = toStructPointer(valptr)
|
||||
default:
|
||||
// everything else
|
||||
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
|
||||
valbase = toStructPointer(valptr.Addr()) // **V
|
||||
}
|
||||
|
||||
// Decode.
|
||||
// This parses a restricted wire format, namely the encoding of a message
|
||||
// with two fields. See enc_new_map for the format.
|
||||
for o.index < oi {
|
||||
// tagcode for key and value properties are always a single byte
|
||||
// because they have tags 1 and 2.
|
||||
tagcode := o.buf[o.index]
|
||||
o.index++
|
||||
switch tagcode {
|
||||
case p.mkeyprop.tagcode[0]:
|
||||
if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil {
|
||||
return err
|
||||
}
|
||||
case p.mvalprop.tagcode[0]:
|
||||
if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
// TODO: Should we silently skip this instead?
|
||||
return fmt.Errorf("proto: bad map data tag %d", raw[0])
|
||||
}
|
||||
}
|
||||
keyelem, valelem := keyptr.Elem(), valptr.Elem()
|
||||
if !keyelem.IsValid() {
|
||||
keyelem = reflect.Zero(p.mtype.Key())
|
||||
}
|
||||
if !valelem.IsValid() {
|
||||
valelem = reflect.Zero(p.mtype.Elem())
|
||||
}
|
||||
|
||||
v.SetMapIndex(keyelem, valelem)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a group.
|
||||
func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error {
|
||||
bas := structPointer_GetStructPointer(base, p.field)
|
||||
if structPointer_IsNil(bas) {
|
||||
// allocate new nested message
|
||||
bas = toStructPointer(reflect.New(p.stype))
|
||||
structPointer_SetStructPointer(base, p.field, bas)
|
||||
}
|
||||
return o.unmarshalType(p.stype, p.sprop, true, bas)
|
||||
}
|
||||
|
||||
// Decode an embedded message.
|
||||
func (o *Buffer) dec_struct_message(p *Properties, base structPointer) (err error) {
|
||||
raw, e := o.DecodeRawBytes(false)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
bas := structPointer_GetStructPointer(base, p.field)
|
||||
if structPointer_IsNil(bas) {
|
||||
// allocate new nested message
|
||||
bas = toStructPointer(reflect.New(p.stype))
|
||||
structPointer_SetStructPointer(base, p.field, bas)
|
||||
}
|
||||
|
||||
// If the object can unmarshal itself, let it.
|
||||
if p.isUnmarshaler {
|
||||
iv := structPointer_Interface(bas, p.stype)
|
||||
return iv.(Unmarshaler).Unmarshal(raw)
|
||||
}
|
||||
|
||||
obuf := o.buf
|
||||
oi := o.index
|
||||
o.buf = raw
|
||||
o.index = 0
|
||||
|
||||
err = o.unmarshalType(p.stype, p.sprop, false, bas)
|
||||
o.buf = obuf
|
||||
o.index = oi
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode a slice of embedded messages.
|
||||
func (o *Buffer) dec_slice_struct_message(p *Properties, base structPointer) error {
|
||||
return o.dec_slice_struct(p, false, base)
|
||||
}
|
||||
|
||||
// Decode a slice of embedded groups.
|
||||
func (o *Buffer) dec_slice_struct_group(p *Properties, base structPointer) error {
|
||||
return o.dec_slice_struct(p, true, base)
|
||||
}
|
||||
|
||||
// Decode a slice of structs ([]*struct).
|
||||
func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base structPointer) error {
|
||||
v := reflect.New(p.stype)
|
||||
bas := toStructPointer(v)
|
||||
structPointer_StructPointerSlice(base, p.field).Append(bas)
|
||||
|
||||
if is_group {
|
||||
err := o.unmarshalType(p.stype, p.sprop, is_group, bas)
|
||||
return err
|
||||
}
|
||||
|
||||
raw, err := o.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the object can unmarshal itself, let it.
|
||||
if p.isUnmarshaler {
|
||||
iv := v.Interface()
|
||||
return iv.(Unmarshaler).Unmarshal(raw)
|
||||
}
|
||||
|
||||
obuf := o.buf
|
||||
oi := o.index
|
||||
o.buf = raw
|
||||
o.index = 0
|
||||
|
||||
err = o.unmarshalType(p.stype, p.sprop, is_group, bas)
|
||||
|
||||
o.buf = obuf
|
||||
o.index = oi
|
||||
|
||||
return err
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,300 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// Protocol buffer comparison.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
/*
|
||||
Equal returns true iff protocol buffers a and b are equal.
|
||||
The arguments must both be pointers to protocol buffer structs.
|
||||
|
||||
Equality is defined in this way:
|
||||
- Two messages are equal iff they are the same type,
|
||||
corresponding fields are equal, unknown field sets
|
||||
are equal, and extensions sets are equal.
|
||||
- Two set scalar fields are equal iff their values are equal.
|
||||
If the fields are of a floating-point type, remember that
|
||||
NaN != x for all x, including NaN. If the message is defined
|
||||
in a proto3 .proto file, fields are not "set"; specifically,
|
||||
zero length proto3 "bytes" fields are equal (nil == {}).
|
||||
- Two repeated fields are equal iff their lengths are the same,
|
||||
and their corresponding elements are equal. Note a "bytes" field,
|
||||
although represented by []byte, is not a repeated field and the
|
||||
rule for the scalar fields described above applies.
|
||||
- Two unset fields are equal.
|
||||
- Two unknown field sets are equal if their current
|
||||
encoded state is equal.
|
||||
- Two extension sets are equal iff they have corresponding
|
||||
elements that are pairwise equal.
|
||||
- Two map fields are equal iff their lengths are the same,
|
||||
and they contain the same set of elements. Zero-length map
|
||||
fields are equal.
|
||||
- Every other combination of things are not equal.
|
||||
|
||||
The return value is undefined if a and b are not protocol buffers.
|
||||
*/
|
||||
func Equal(a, b Message) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == b
|
||||
}
|
||||
v1, v2 := reflect.ValueOf(a), reflect.ValueOf(b)
|
||||
if v1.Type() != v2.Type() {
|
||||
return false
|
||||
}
|
||||
if v1.Kind() == reflect.Ptr {
|
||||
if v1.IsNil() {
|
||||
return v2.IsNil()
|
||||
}
|
||||
if v2.IsNil() {
|
||||
return false
|
||||
}
|
||||
v1, v2 = v1.Elem(), v2.Elem()
|
||||
}
|
||||
if v1.Kind() != reflect.Struct {
|
||||
return false
|
||||
}
|
||||
return equalStruct(v1, v2)
|
||||
}
|
||||
|
||||
// v1 and v2 are known to have the same type.
|
||||
func equalStruct(v1, v2 reflect.Value) bool {
|
||||
sprop := GetProperties(v1.Type())
|
||||
for i := 0; i < v1.NumField(); i++ {
|
||||
f := v1.Type().Field(i)
|
||||
if strings.HasPrefix(f.Name, "XXX_") {
|
||||
continue
|
||||
}
|
||||
f1, f2 := v1.Field(i), v2.Field(i)
|
||||
if f.Type.Kind() == reflect.Ptr {
|
||||
if n1, n2 := f1.IsNil(), f2.IsNil(); n1 && n2 {
|
||||
// both unset
|
||||
continue
|
||||
} else if n1 != n2 {
|
||||
// set/unset mismatch
|
||||
return false
|
||||
}
|
||||
b1, ok := f1.Interface().(raw)
|
||||
if ok {
|
||||
b2 := f2.Interface().(raw)
|
||||
// RawMessage
|
||||
if !bytes.Equal(b1.Bytes(), b2.Bytes()) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
f1, f2 = f1.Elem(), f2.Elem()
|
||||
}
|
||||
if !equalAny(f1, f2, sprop.Prop[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() {
|
||||
em2 := v2.FieldByName("XXX_InternalExtensions")
|
||||
if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
|
||||
em2 := v2.FieldByName("XXX_extensions")
|
||||
if !equalExtMap(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
uf := v1.FieldByName("XXX_unrecognized")
|
||||
if !uf.IsValid() {
|
||||
return true
|
||||
}
|
||||
|
||||
u1 := uf.Bytes()
|
||||
u2 := v2.FieldByName("XXX_unrecognized").Bytes()
|
||||
if !bytes.Equal(u1, u2) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// v1 and v2 are known to have the same type.
|
||||
// prop may be nil.
|
||||
func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
|
||||
if v1.Type() == protoMessageType {
|
||||
m1, _ := v1.Interface().(Message)
|
||||
m2, _ := v2.Interface().(Message)
|
||||
return Equal(m1, m2)
|
||||
}
|
||||
switch v1.Kind() {
|
||||
case reflect.Bool:
|
||||
return v1.Bool() == v2.Bool()
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v1.Float() == v2.Float()
|
||||
case reflect.Int32, reflect.Int64:
|
||||
return v1.Int() == v2.Int()
|
||||
case reflect.Interface:
|
||||
// Probably a oneof field; compare the inner values.
|
||||
n1, n2 := v1.IsNil(), v2.IsNil()
|
||||
if n1 || n2 {
|
||||
return n1 == n2
|
||||
}
|
||||
e1, e2 := v1.Elem(), v2.Elem()
|
||||
if e1.Type() != e2.Type() {
|
||||
return false
|
||||
}
|
||||
return equalAny(e1, e2, nil)
|
||||
case reflect.Map:
|
||||
if v1.Len() != v2.Len() {
|
||||
return false
|
||||
}
|
||||
for _, key := range v1.MapKeys() {
|
||||
val2 := v2.MapIndex(key)
|
||||
if !val2.IsValid() {
|
||||
// This key was not found in the second map.
|
||||
return false
|
||||
}
|
||||
if !equalAny(v1.MapIndex(key), val2, nil) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case reflect.Ptr:
|
||||
// Maps may have nil values in them, so check for nil.
|
||||
if v1.IsNil() && v2.IsNil() {
|
||||
return true
|
||||
}
|
||||
if v1.IsNil() != v2.IsNil() {
|
||||
return false
|
||||
}
|
||||
return equalAny(v1.Elem(), v2.Elem(), prop)
|
||||
case reflect.Slice:
|
||||
if v1.Type().Elem().Kind() == reflect.Uint8 {
|
||||
// short circuit: []byte
|
||||
|
||||
// Edge case: if this is in a proto3 message, a zero length
|
||||
// bytes field is considered the zero value.
|
||||
if prop != nil && prop.proto3 && v1.Len() == 0 && v2.Len() == 0 {
|
||||
return true
|
||||
}
|
||||
if v1.IsNil() != v2.IsNil() {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(v1.Interface().([]byte), v2.Interface().([]byte))
|
||||
}
|
||||
|
||||
if v1.Len() != v2.Len() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
if !equalAny(v1.Index(i), v2.Index(i), prop) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case reflect.String:
|
||||
return v1.Interface().(string) == v2.Interface().(string)
|
||||
case reflect.Struct:
|
||||
return equalStruct(v1, v2)
|
||||
case reflect.Uint32, reflect.Uint64:
|
||||
return v1.Uint() == v2.Uint()
|
||||
}
|
||||
|
||||
// unknown type, so not a protocol buffer
|
||||
log.Printf("proto: don't know how to compare %v", v1)
|
||||
return false
|
||||
}
|
||||
|
||||
// base is the struct type that the extensions are based on.
|
||||
// x1 and x2 are InternalExtensions.
|
||||
func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool {
|
||||
em1, _ := x1.extensionsRead()
|
||||
em2, _ := x2.extensionsRead()
|
||||
return equalExtMap(base, em1, em2)
|
||||
}
|
||||
|
||||
func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
|
||||
if len(em1) != len(em2) {
|
||||
return false
|
||||
}
|
||||
|
||||
for extNum, e1 := range em1 {
|
||||
e2, ok := em2[extNum]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
m1, m2 := e1.value, e2.value
|
||||
|
||||
if m1 != nil && m2 != nil {
|
||||
// Both are unencoded.
|
||||
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// At least one is encoded. To do a semantically correct comparison
|
||||
// we need to unmarshal them first.
|
||||
var desc *ExtensionDesc
|
||||
if m := extensionMaps[base]; m != nil {
|
||||
desc = m[extNum]
|
||||
}
|
||||
if desc == nil {
|
||||
log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
|
||||
continue
|
||||
}
|
||||
var err error
|
||||
if m1 == nil {
|
||||
m1, err = decodeExtension(e1.enc, desc)
|
||||
}
|
||||
if m2 == nil && err == nil {
|
||||
m2, err = decodeExtension(e2.enc, desc)
|
||||
}
|
||||
if err != nil {
|
||||
// The encoded form is invalid.
|
||||
log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
|
||||
return false
|
||||
}
|
||||
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,587 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
/*
|
||||
* Types and routines for supporting protocol buffer extensions.
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message.
|
||||
var ErrMissingExtension = errors.New("proto: missing extension")
|
||||
|
||||
// ExtensionRange represents a range of message extensions for a protocol buffer.
|
||||
// Used in code generated by the protocol compiler.
|
||||
type ExtensionRange struct {
|
||||
Start, End int32 // both inclusive
|
||||
}
|
||||
|
||||
// extendableProto is an interface implemented by any protocol buffer generated by the current
|
||||
// proto compiler that may be extended.
|
||||
type extendableProto interface {
|
||||
Message
|
||||
ExtensionRangeArray() []ExtensionRange
|
||||
extensionsWrite() map[int32]Extension
|
||||
extensionsRead() (map[int32]Extension, sync.Locker)
|
||||
}
|
||||
|
||||
// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
|
||||
// version of the proto compiler that may be extended.
|
||||
type extendableProtoV1 interface {
|
||||
Message
|
||||
ExtensionRangeArray() []ExtensionRange
|
||||
ExtensionMap() map[int32]Extension
|
||||
}
|
||||
|
||||
// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
|
||||
type extensionAdapter struct {
|
||||
extendableProtoV1
|
||||
}
|
||||
|
||||
func (e extensionAdapter) extensionsWrite() map[int32]Extension {
|
||||
return e.ExtensionMap()
|
||||
}
|
||||
|
||||
func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
|
||||
return e.ExtensionMap(), notLocker{}
|
||||
}
|
||||
|
||||
// notLocker is a sync.Locker whose Lock and Unlock methods are nops.
|
||||
type notLocker struct{}
|
||||
|
||||
func (n notLocker) Lock() {}
|
||||
func (n notLocker) Unlock() {}
|
||||
|
||||
// extendable returns the extendableProto interface for the given generated proto message.
|
||||
// If the proto message has the old extension format, it returns a wrapper that implements
|
||||
// the extendableProto interface.
|
||||
func extendable(p interface{}) (extendableProto, bool) {
|
||||
if ep, ok := p.(extendableProto); ok {
|
||||
return ep, ok
|
||||
}
|
||||
if ep, ok := p.(extendableProtoV1); ok {
|
||||
return extensionAdapter{ep}, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// XXX_InternalExtensions is an internal representation of proto extensions.
|
||||
//
|
||||
// Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
|
||||
// thus gaining the unexported 'extensions' method, which can be called only from the proto package.
|
||||
//
|
||||
// The methods of XXX_InternalExtensions are not concurrency safe in general,
|
||||
// but calls to logically read-only methods such as has and get may be executed concurrently.
|
||||
type XXX_InternalExtensions struct {
|
||||
// The struct must be indirect so that if a user inadvertently copies a
|
||||
// generated message and its embedded XXX_InternalExtensions, they
|
||||
// avoid the mayhem of a copied mutex.
|
||||
//
|
||||
// The mutex serializes all logically read-only operations to p.extensionMap.
|
||||
// It is up to the client to ensure that write operations to p.extensionMap are
|
||||
// mutually exclusive with other accesses.
|
||||
p *struct {
|
||||
mu sync.Mutex
|
||||
extensionMap map[int32]Extension
|
||||
}
|
||||
}
|
||||
|
||||
// extensionsWrite returns the extension map, creating it on first use.
|
||||
func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
|
||||
if e.p == nil {
|
||||
e.p = new(struct {
|
||||
mu sync.Mutex
|
||||
extensionMap map[int32]Extension
|
||||
})
|
||||
e.p.extensionMap = make(map[int32]Extension)
|
||||
}
|
||||
return e.p.extensionMap
|
||||
}
|
||||
|
||||
// extensionsRead returns the extensions map for read-only use. It may be nil.
|
||||
// The caller must hold the returned mutex's lock when accessing Elements within the map.
|
||||
func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
|
||||
if e.p == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return e.p.extensionMap, &e.p.mu
|
||||
}
|
||||
|
||||
var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
|
||||
var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
|
||||
|
||||
// ExtensionDesc represents an extension specification.
|
||||
// Used in generated code from the protocol compiler.
|
||||
type ExtensionDesc struct {
|
||||
ExtendedType Message // nil pointer to the type that is being extended
|
||||
ExtensionType interface{} // nil pointer to the extension type
|
||||
Field int32 // field number
|
||||
Name string // fully-qualified name of extension, for text formatting
|
||||
Tag string // protobuf tag style
|
||||
Filename string // name of the file in which the extension is defined
|
||||
}
|
||||
|
||||
func (ed *ExtensionDesc) repeated() bool {
|
||||
t := reflect.TypeOf(ed.ExtensionType)
|
||||
return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
|
||||
}
|
||||
|
||||
// Extension represents an extension in a message.
|
||||
type Extension struct {
|
||||
// When an extension is stored in a message using SetExtension
|
||||
// only desc and value are set. When the message is marshaled
|
||||
// enc will be set to the encoded form of the message.
|
||||
//
|
||||
// When a message is unmarshaled and contains extensions, each
|
||||
// extension will have only enc set. When such an extension is
|
||||
// accessed using GetExtension (or GetExtensions) desc and value
|
||||
// will be set.
|
||||
desc *ExtensionDesc
|
||||
value interface{}
|
||||
enc []byte
|
||||
}
|
||||
|
||||
// SetRawExtension is for testing only.
|
||||
func SetRawExtension(base Message, id int32, b []byte) {
|
||||
epb, ok := extendable(base)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
extmap := epb.extensionsWrite()
|
||||
extmap[id] = Extension{enc: b}
|
||||
}
|
||||
|
||||
// isExtensionField returns true iff the given field number is in an extension range.
|
||||
func isExtensionField(pb extendableProto, field int32) bool {
|
||||
for _, er := range pb.ExtensionRangeArray() {
|
||||
if er.Start <= field && field <= er.End {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// checkExtensionTypes checks that the given extension is valid for pb.
|
||||
func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
|
||||
var pbi interface{} = pb
|
||||
// Check the extended type.
|
||||
if ea, ok := pbi.(extensionAdapter); ok {
|
||||
pbi = ea.extendableProtoV1
|
||||
}
|
||||
if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
|
||||
return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
|
||||
}
|
||||
// Check the range.
|
||||
if !isExtensionField(pb, extension.Field) {
|
||||
return errors.New("proto: bad extension number; not in declared ranges")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extPropKey is sufficient to uniquely identify an extension.
|
||||
type extPropKey struct {
|
||||
base reflect.Type
|
||||
field int32
|
||||
}
|
||||
|
||||
var extProp = struct {
|
||||
sync.RWMutex
|
||||
m map[extPropKey]*Properties
|
||||
}{
|
||||
m: make(map[extPropKey]*Properties),
|
||||
}
|
||||
|
||||
func extensionProperties(ed *ExtensionDesc) *Properties {
|
||||
key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
|
||||
|
||||
extProp.RLock()
|
||||
if prop, ok := extProp.m[key]; ok {
|
||||
extProp.RUnlock()
|
||||
return prop
|
||||
}
|
||||
extProp.RUnlock()
|
||||
|
||||
extProp.Lock()
|
||||
defer extProp.Unlock()
|
||||
// Check again.
|
||||
if prop, ok := extProp.m[key]; ok {
|
||||
return prop
|
||||
}
|
||||
|
||||
prop := new(Properties)
|
||||
prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
|
||||
extProp.m[key] = prop
|
||||
return prop
|
||||
}
|
||||
|
||||
// encode encodes any unmarshaled (unencoded) extensions in e.
|
||||
func encodeExtensions(e *XXX_InternalExtensions) error {
|
||||
m, mu := e.extensionsRead()
|
||||
if m == nil {
|
||||
return nil // fast path
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return encodeExtensionsMap(m)
|
||||
}
|
||||
|
||||
// encode encodes any unmarshaled (unencoded) extensions in e.
|
||||
func encodeExtensionsMap(m map[int32]Extension) error {
|
||||
for k, e := range m {
|
||||
if e.value == nil || e.desc == nil {
|
||||
// Extension is only in its encoded form.
|
||||
continue
|
||||
}
|
||||
|
||||
// We don't skip extensions that have an encoded form set,
|
||||
// because the extension value may have been mutated after
|
||||
// the last time this function was called.
|
||||
|
||||
et := reflect.TypeOf(e.desc.ExtensionType)
|
||||
props := extensionProperties(e.desc)
|
||||
|
||||
p := NewBuffer(nil)
|
||||
// If e.value has type T, the encoder expects a *struct{ X T }.
|
||||
// Pass a *T with a zero field and hope it all works out.
|
||||
x := reflect.New(et)
|
||||
x.Elem().Set(reflect.ValueOf(e.value))
|
||||
if err := props.enc(p, props, toStructPointer(x)); err != nil {
|
||||
return err
|
||||
}
|
||||
e.enc = p.buf
|
||||
m[k] = e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func extensionsSize(e *XXX_InternalExtensions) (n int) {
|
||||
m, mu := e.extensionsRead()
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return extensionsMapSize(m)
|
||||
}
|
||||
|
||||
func extensionsMapSize(m map[int32]Extension) (n int) {
|
||||
for _, e := range m {
|
||||
if e.value == nil || e.desc == nil {
|
||||
// Extension is only in its encoded form.
|
||||
n += len(e.enc)
|
||||
continue
|
||||
}
|
||||
|
||||
// We don't skip extensions that have an encoded form set,
|
||||
// because the extension value may have been mutated after
|
||||
// the last time this function was called.
|
||||
|
||||
et := reflect.TypeOf(e.desc.ExtensionType)
|
||||
props := extensionProperties(e.desc)
|
||||
|
||||
// If e.value has type T, the encoder expects a *struct{ X T }.
|
||||
// Pass a *T with a zero field and hope it all works out.
|
||||
x := reflect.New(et)
|
||||
x.Elem().Set(reflect.ValueOf(e.value))
|
||||
n += props.size(props, toStructPointer(x))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// HasExtension returns whether the given extension is present in pb.
|
||||
func HasExtension(pb Message, extension *ExtensionDesc) bool {
|
||||
// TODO: Check types, field numbers, etc.?
|
||||
epb, ok := extendable(pb)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
extmap, mu := epb.extensionsRead()
|
||||
if extmap == nil {
|
||||
return false
|
||||
}
|
||||
mu.Lock()
|
||||
_, ok = extmap[extension.Field]
|
||||
mu.Unlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// ClearExtension removes the given extension from pb.
|
||||
func ClearExtension(pb Message, extension *ExtensionDesc) {
|
||||
epb, ok := extendable(pb)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// TODO: Check types, field numbers, etc.?
|
||||
extmap := epb.extensionsWrite()
|
||||
delete(extmap, extension.Field)
|
||||
}
|
||||
|
||||
// GetExtension parses and returns the given extension of pb.
|
||||
// If the extension is not present and has no default value it returns ErrMissingExtension.
|
||||
func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
|
||||
epb, ok := extendable(pb)
|
||||
if !ok {
|
||||
return nil, errors.New("proto: not an extendable proto")
|
||||
}
|
||||
|
||||
if err := checkExtensionTypes(epb, extension); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
emap, mu := epb.extensionsRead()
|
||||
if emap == nil {
|
||||
return defaultExtensionValue(extension)
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
e, ok := emap[extension.Field]
|
||||
if !ok {
|
||||
// defaultExtensionValue returns the default value or
|
||||
// ErrMissingExtension if there is no default.
|
||||
return defaultExtensionValue(extension)
|
||||
}
|
||||
|
||||
if e.value != nil {
|
||||
// Already decoded. Check the descriptor, though.
|
||||
if e.desc != extension {
|
||||
// This shouldn't happen. If it does, it means that
|
||||
// GetExtension was called twice with two different
|
||||
// descriptors with the same field number.
|
||||
return nil, errors.New("proto: descriptor conflict")
|
||||
}
|
||||
return e.value, nil
|
||||
}
|
||||
|
||||
v, err := decodeExtension(e.enc, extension)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remember the decoded version and drop the encoded version.
|
||||
// That way it is safe to mutate what we return.
|
||||
e.value = v
|
||||
e.desc = extension
|
||||
e.enc = nil
|
||||
emap[extension.Field] = e
|
||||
return e.value, nil
|
||||
}
|
||||
|
||||
// defaultExtensionValue returns the default value for extension.
|
||||
// If no default for an extension is defined ErrMissingExtension is returned.
|
||||
func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
|
||||
t := reflect.TypeOf(extension.ExtensionType)
|
||||
props := extensionProperties(extension)
|
||||
|
||||
sf, _, err := fieldDefault(t, props)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sf == nil || sf.value == nil {
|
||||
// There is no default value.
|
||||
return nil, ErrMissingExtension
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Ptr {
|
||||
// We do not need to return a Ptr, we can directly return sf.value.
|
||||
return sf.value, nil
|
||||
}
|
||||
|
||||
// We need to return an interface{} that is a pointer to sf.value.
|
||||
value := reflect.New(t).Elem()
|
||||
value.Set(reflect.New(value.Type().Elem()))
|
||||
if sf.kind == reflect.Int32 {
|
||||
// We may have an int32 or an enum, but the underlying data is int32.
|
||||
// Since we can't set an int32 into a non int32 reflect.value directly
|
||||
// set it as a int32.
|
||||
value.Elem().SetInt(int64(sf.value.(int32)))
|
||||
} else {
|
||||
value.Elem().Set(reflect.ValueOf(sf.value))
|
||||
}
|
||||
return value.Interface(), nil
|
||||
}
|
||||
|
||||
// decodeExtension decodes an extension encoded in b.
|
||||
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
|
||||
o := NewBuffer(b)
|
||||
|
||||
t := reflect.TypeOf(extension.ExtensionType)
|
||||
|
||||
props := extensionProperties(extension)
|
||||
|
||||
// t is a pointer to a struct, pointer to basic type or a slice.
|
||||
// Allocate a "field" to store the pointer/slice itself; the
|
||||
// pointer/slice will be stored here. We pass
|
||||
// the address of this field to props.dec.
|
||||
// This passes a zero field and a *t and lets props.dec
|
||||
// interpret it as a *struct{ x t }.
|
||||
value := reflect.New(t).Elem()
|
||||
|
||||
for {
|
||||
// Discard wire type and field number varint. It isn't needed.
|
||||
if _, err := o.DecodeVarint(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if o.index >= len(o.buf) {
|
||||
break
|
||||
}
|
||||
}
|
||||
return value.Interface(), nil
|
||||
}
|
||||
|
||||
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
|
||||
// The returned slice has the same length as es; missing extensions will appear as nil elements.
|
||||
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
|
||||
epb, ok := extendable(pb)
|
||||
if !ok {
|
||||
return nil, errors.New("proto: not an extendable proto")
|
||||
}
|
||||
extensions = make([]interface{}, len(es))
|
||||
for i, e := range es {
|
||||
extensions[i], err = GetExtension(epb, e)
|
||||
if err == ErrMissingExtension {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
|
||||
// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
|
||||
// just the Field field, which defines the extension's field number.
|
||||
func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
|
||||
epb, ok := extendable(pb)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb)
|
||||
}
|
||||
registeredExtensions := RegisteredExtensions(pb)
|
||||
|
||||
emap, mu := epb.extensionsRead()
|
||||
if emap == nil {
|
||||
return nil, nil
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
extensions := make([]*ExtensionDesc, 0, len(emap))
|
||||
for extid, e := range emap {
|
||||
desc := e.desc
|
||||
if desc == nil {
|
||||
desc = registeredExtensions[extid]
|
||||
if desc == nil {
|
||||
desc = &ExtensionDesc{Field: extid}
|
||||
}
|
||||
}
|
||||
|
||||
extensions = append(extensions, desc)
|
||||
}
|
||||
return extensions, nil
|
||||
}
|
||||
|
||||
// SetExtension sets the specified extension of pb to the specified value.
|
||||
func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
|
||||
epb, ok := extendable(pb)
|
||||
if !ok {
|
||||
return errors.New("proto: not an extendable proto")
|
||||
}
|
||||
if err := checkExtensionTypes(epb, extension); err != nil {
|
||||
return err
|
||||
}
|
||||
typ := reflect.TypeOf(extension.ExtensionType)
|
||||
if typ != reflect.TypeOf(value) {
|
||||
return errors.New("proto: bad extension value type")
|
||||
}
|
||||
// nil extension values need to be caught early, because the
|
||||
// encoder can't distinguish an ErrNil due to a nil extension
|
||||
// from an ErrNil due to a missing field. Extensions are
|
||||
// always optional, so the encoder would just swallow the error
|
||||
// and drop all the extensions from the encoded message.
|
||||
if reflect.ValueOf(value).IsNil() {
|
||||
return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
|
||||
}
|
||||
|
||||
extmap := epb.extensionsWrite()
|
||||
extmap[extension.Field] = Extension{desc: extension, value: value}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearAllExtensions clears all extensions from pb.
|
||||
func ClearAllExtensions(pb Message) {
|
||||
epb, ok := extendable(pb)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
m := epb.extensionsWrite()
|
||||
for k := range m {
|
||||
delete(m, k)
|
||||
}
|
||||
}
|
||||
|
||||
// A global registry of extensions.
|
||||
// The generated code will register the generated descriptors by calling RegisterExtension.
|
||||
|
||||
var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
|
||||
|
||||
// RegisterExtension is called from the generated code.
|
||||
func RegisterExtension(desc *ExtensionDesc) {
|
||||
st := reflect.TypeOf(desc.ExtendedType).Elem()
|
||||
m := extensionMaps[st]
|
||||
if m == nil {
|
||||
m = make(map[int32]*ExtensionDesc)
|
||||
extensionMaps[st] = m
|
||||
}
|
||||
if _, ok := m[desc.Field]; ok {
|
||||
panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
|
||||
}
|
||||
m[desc.Field] = desc
|
||||
}
|
||||
|
||||
// RegisteredExtensions returns a map of the registered extensions of a
|
||||
// protocol buffer struct, indexed by the extension number.
|
||||
// The argument pb should be a nil pointer to the struct type.
|
||||
func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
|
||||
return extensionMaps[reflect.TypeOf(pb).Elem()]
|
||||
}
|
|
@ -0,0 +1,898 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
/*
|
||||
Package proto converts data structures to and from the wire format of
|
||||
protocol buffers. It works in concert with the Go source code generated
|
||||
for .proto files by the protocol compiler.
|
||||
|
||||
A summary of the properties of the protocol buffer interface
|
||||
for a protocol buffer variable v:
|
||||
|
||||
- Names are turned from camel_case to CamelCase for export.
|
||||
- There are no methods on v to set fields; just treat
|
||||
them as structure fields.
|
||||
- There are getters that return a field's value if set,
|
||||
and return the field's default value if unset.
|
||||
The getters work even if the receiver is a nil message.
|
||||
- The zero value for a struct is its correct initialization state.
|
||||
All desired fields must be set before marshaling.
|
||||
- A Reset() method will restore a protobuf struct to its zero state.
|
||||
- Non-repeated fields are pointers to the values; nil means unset.
|
||||
That is, optional or required field int32 f becomes F *int32.
|
||||
- Repeated fields are slices.
|
||||
- Helper functions are available to aid the setting of fields.
|
||||
msg.Foo = proto.String("hello") // set field
|
||||
- Constants are defined to hold the default values of all fields that
|
||||
have them. They have the form Default_StructName_FieldName.
|
||||
Because the getter methods handle defaulted values,
|
||||
direct use of these constants should be rare.
|
||||
- Enums are given type names and maps from names to values.
|
||||
Enum values are prefixed by the enclosing message's name, or by the
|
||||
enum's type name if it is a top-level enum. Enum types have a String
|
||||
method, and a Enum method to assist in message construction.
|
||||
- Nested messages, groups and enums have type names prefixed with the name of
|
||||
the surrounding message type.
|
||||
- Extensions are given descriptor names that start with E_,
|
||||
followed by an underscore-delimited list of the nested messages
|
||||
that contain it (if any) followed by the CamelCased name of the
|
||||
extension field itself. HasExtension, ClearExtension, GetExtension
|
||||
and SetExtension are functions for manipulating extensions.
|
||||
- Oneof field sets are given a single field in their message,
|
||||
with distinguished wrapper types for each possible field value.
|
||||
- Marshal and Unmarshal are functions to encode and decode the wire format.
|
||||
|
||||
When the .proto file specifies `syntax="proto3"`, there are some differences:
|
||||
|
||||
- Non-repeated fields of non-message type are values instead of pointers.
|
||||
- Getters are only generated for message and oneof fields.
|
||||
- Enum types do not get an Enum method.
|
||||
|
||||
The simplest way to describe this is to see an example.
|
||||
Given file test.proto, containing
|
||||
|
||||
package example;
|
||||
|
||||
enum FOO { X = 17; }
|
||||
|
||||
message Test {
|
||||
required string label = 1;
|
||||
optional int32 type = 2 [default=77];
|
||||
repeated int64 reps = 3;
|
||||
optional group OptionalGroup = 4 {
|
||||
required string RequiredField = 5;
|
||||
}
|
||||
oneof union {
|
||||
int32 number = 6;
|
||||
string name = 7;
|
||||
}
|
||||
}
|
||||
|
||||
The resulting file, test.pb.go, is:
|
||||
|
||||
package example
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import math "math"
|
||||
|
||||
type FOO int32
|
||||
const (
|
||||
FOO_X FOO = 17
|
||||
)
|
||||
var FOO_name = map[int32]string{
|
||||
17: "X",
|
||||
}
|
||||
var FOO_value = map[string]int32{
|
||||
"X": 17,
|
||||
}
|
||||
|
||||
func (x FOO) Enum() *FOO {
|
||||
p := new(FOO)
|
||||
*p = x
|
||||
return p
|
||||
}
|
||||
func (x FOO) String() string {
|
||||
return proto.EnumName(FOO_name, int32(x))
|
||||
}
|
||||
func (x *FOO) UnmarshalJSON(data []byte) error {
|
||||
value, err := proto.UnmarshalJSONEnum(FOO_value, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*x = FOO(value)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Test struct {
|
||||
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
|
||||
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
|
||||
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
|
||||
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
|
||||
// Types that are valid to be assigned to Union:
|
||||
// *Test_Number
|
||||
// *Test_Name
|
||||
Union isTest_Union `protobuf_oneof:"union"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
}
|
||||
func (m *Test) Reset() { *m = Test{} }
|
||||
func (m *Test) String() string { return proto.CompactTextString(m) }
|
||||
func (*Test) ProtoMessage() {}
|
||||
|
||||
type isTest_Union interface {
|
||||
isTest_Union()
|
||||
}
|
||||
|
||||
type Test_Number struct {
|
||||
Number int32 `protobuf:"varint,6,opt,name=number"`
|
||||
}
|
||||
type Test_Name struct {
|
||||
Name string `protobuf:"bytes,7,opt,name=name"`
|
||||
}
|
||||
|
||||
func (*Test_Number) isTest_Union() {}
|
||||
func (*Test_Name) isTest_Union() {}
|
||||
|
||||
func (m *Test) GetUnion() isTest_Union {
|
||||
if m != nil {
|
||||
return m.Union
|
||||
}
|
||||
return nil
|
||||
}
|
||||
const Default_Test_Type int32 = 77
|
||||
|
||||
func (m *Test) GetLabel() string {
|
||||
if m != nil && m.Label != nil {
|
||||
return *m.Label
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Test) GetType() int32 {
|
||||
if m != nil && m.Type != nil {
|
||||
return *m.Type
|
||||
}
|
||||
return Default_Test_Type
|
||||
}
|
||||
|
||||
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
|
||||
if m != nil {
|
||||
return m.Optionalgroup
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Test_OptionalGroup struct {
|
||||
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
|
||||
}
|
||||
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
|
||||
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
|
||||
|
||||
func (m *Test_OptionalGroup) GetRequiredField() string {
|
||||
if m != nil && m.RequiredField != nil {
|
||||
return *m.RequiredField
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Test) GetNumber() int32 {
|
||||
if x, ok := m.GetUnion().(*Test_Number); ok {
|
||||
return x.Number
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *Test) GetName() string {
|
||||
if x, ok := m.GetUnion().(*Test_Name); ok {
|
||||
return x.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
|
||||
}
|
||||
|
||||
To create and play with a Test object:
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
pb "./example.pb"
|
||||
)
|
||||
|
||||
func main() {
|
||||
test := &pb.Test{
|
||||
Label: proto.String("hello"),
|
||||
Type: proto.Int32(17),
|
||||
Reps: []int64{1, 2, 3},
|
||||
Optionalgroup: &pb.Test_OptionalGroup{
|
||||
RequiredField: proto.String("good bye"),
|
||||
},
|
||||
Union: &pb.Test_Name{"fred"},
|
||||
}
|
||||
data, err := proto.Marshal(test)
|
||||
if err != nil {
|
||||
log.Fatal("marshaling error: ", err)
|
||||
}
|
||||
newTest := &pb.Test{}
|
||||
err = proto.Unmarshal(data, newTest)
|
||||
if err != nil {
|
||||
log.Fatal("unmarshaling error: ", err)
|
||||
}
|
||||
// Now test and newTest contain the same data.
|
||||
if test.GetLabel() != newTest.GetLabel() {
|
||||
log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel())
|
||||
}
|
||||
// Use a type switch to determine which oneof was set.
|
||||
switch u := test.Union.(type) {
|
||||
case *pb.Test_Number: // u.Number contains the number.
|
||||
case *pb.Test_Name: // u.Name contains the string.
|
||||
}
|
||||
// etc.
|
||||
}
|
||||
*/
|
||||
package proto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Message is implemented by generated protocol buffer messages.
|
||||
type Message interface {
|
||||
Reset()
|
||||
String() string
|
||||
ProtoMessage()
|
||||
}
|
||||
|
||||
// Stats records allocation details about the protocol buffer encoders
|
||||
// and decoders. Useful for tuning the library itself.
|
||||
type Stats struct {
|
||||
Emalloc uint64 // mallocs in encode
|
||||
Dmalloc uint64 // mallocs in decode
|
||||
Encode uint64 // number of encodes
|
||||
Decode uint64 // number of decodes
|
||||
Chit uint64 // number of cache hits
|
||||
Cmiss uint64 // number of cache misses
|
||||
Size uint64 // number of sizes
|
||||
}
|
||||
|
||||
// Set to true to enable stats collection.
|
||||
const collectStats = false
|
||||
|
||||
var stats Stats
|
||||
|
||||
// GetStats returns a copy of the global Stats structure.
|
||||
func GetStats() Stats { return stats }
|
||||
|
||||
// A Buffer is a buffer manager for marshaling and unmarshaling
|
||||
// protocol buffers. It may be reused between invocations to
|
||||
// reduce memory usage. It is not necessary to use a Buffer;
|
||||
// the global functions Marshal and Unmarshal create a
|
||||
// temporary Buffer and are fine for most applications.
|
||||
type Buffer struct {
|
||||
buf []byte // encode/decode byte stream
|
||||
index int // read point
|
||||
|
||||
// pools of basic types to amortize allocation.
|
||||
bools []bool
|
||||
uint32s []uint32
|
||||
uint64s []uint64
|
||||
|
||||
// extra pools, only used with pointer_reflect.go
|
||||
int32s []int32
|
||||
int64s []int64
|
||||
float32s []float32
|
||||
float64s []float64
|
||||
}
|
||||
|
||||
// NewBuffer allocates a new Buffer and initializes its internal data to
|
||||
// the contents of the argument slice.
|
||||
func NewBuffer(e []byte) *Buffer {
|
||||
return &Buffer{buf: e}
|
||||
}
|
||||
|
||||
// Reset resets the Buffer, ready for marshaling a new protocol buffer.
|
||||
func (p *Buffer) Reset() {
|
||||
p.buf = p.buf[0:0] // for reading/writing
|
||||
p.index = 0 // for reading
|
||||
}
|
||||
|
||||
// SetBuf replaces the internal buffer with the slice,
|
||||
// ready for unmarshaling the contents of the slice.
|
||||
func (p *Buffer) SetBuf(s []byte) {
|
||||
p.buf = s
|
||||
p.index = 0
|
||||
}
|
||||
|
||||
// Bytes returns the contents of the Buffer.
|
||||
func (p *Buffer) Bytes() []byte { return p.buf }
|
||||
|
||||
/*
|
||||
* Helper routines for simplifying the creation of optional fields of basic type.
|
||||
*/
|
||||
|
||||
// Bool is a helper routine that allocates a new bool value
|
||||
// to store v and returns a pointer to it.
|
||||
func Bool(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Int32 is a helper routine that allocates a new int32 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Int32(v int32) *int32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Int is a helper routine that allocates a new int32 value
|
||||
// to store v and returns a pointer to it, but unlike Int32
|
||||
// its argument value is an int.
|
||||
func Int(v int) *int32 {
|
||||
p := new(int32)
|
||||
*p = int32(v)
|
||||
return p
|
||||
}
|
||||
|
||||
// Int64 is a helper routine that allocates a new int64 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Int64(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Float32 is a helper routine that allocates a new float32 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Float32(v float32) *float32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Float64 is a helper routine that allocates a new float64 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Float64(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Uint32 is a helper routine that allocates a new uint32 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Uint32(v uint32) *uint32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Uint64 is a helper routine that allocates a new uint64 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Uint64(v uint64) *uint64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// String is a helper routine that allocates a new string value
|
||||
// to store v and returns a pointer to it.
|
||||
func String(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
// EnumName is a helper function to simplify printing protocol buffer enums
|
||||
// by name. Given an enum map and a value, it returns a useful string.
|
||||
func EnumName(m map[int32]string, v int32) string {
|
||||
s, ok := m[v]
|
||||
if ok {
|
||||
return s
|
||||
}
|
||||
return strconv.Itoa(int(v))
|
||||
}
|
||||
|
||||
// UnmarshalJSONEnum is a helper function to simplify recovering enum int values
|
||||
// from their JSON-encoded representation. Given a map from the enum's symbolic
|
||||
// names to its int values, and a byte buffer containing the JSON-encoded
|
||||
// value, it returns an int32 that can be cast to the enum type by the caller.
|
||||
//
|
||||
// The function can deal with both JSON representations, numeric and symbolic.
|
||||
func UnmarshalJSONEnum(m map[string]int32, data []byte, enumName string) (int32, error) {
|
||||
if data[0] == '"' {
|
||||
// New style: enums are strings.
|
||||
var repr string
|
||||
if err := json.Unmarshal(data, &repr); err != nil {
|
||||
return -1, err
|
||||
}
|
||||
val, ok := m[repr]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unrecognized enum %s value %q", enumName, repr)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
// Old style: enums are ints.
|
||||
var val int32
|
||||
if err := json.Unmarshal(data, &val); err != nil {
|
||||
return 0, fmt.Errorf("cannot unmarshal %#q into enum %s", data, enumName)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// DebugPrint dumps the encoded data in b in a debugging format with a header
|
||||
// including the string s. Used in testing but made available for general debugging.
|
||||
func (p *Buffer) DebugPrint(s string, b []byte) {
|
||||
var u uint64
|
||||
|
||||
obuf := p.buf
|
||||
index := p.index
|
||||
p.buf = b
|
||||
p.index = 0
|
||||
depth := 0
|
||||
|
||||
fmt.Printf("\n--- %s ---\n", s)
|
||||
|
||||
out:
|
||||
for {
|
||||
for i := 0; i < depth; i++ {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
|
||||
index := p.index
|
||||
if index == len(p.buf) {
|
||||
break
|
||||
}
|
||||
|
||||
op, err := p.DecodeVarint()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: fetching op err %v\n", index, err)
|
||||
break out
|
||||
}
|
||||
tag := op >> 3
|
||||
wire := op & 7
|
||||
|
||||
switch wire {
|
||||
default:
|
||||
fmt.Printf("%3d: t=%3d unknown wire=%d\n",
|
||||
index, tag, wire)
|
||||
break out
|
||||
|
||||
case WireBytes:
|
||||
var r []byte
|
||||
|
||||
r, err = p.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d bytes [%d]", index, tag, len(r))
|
||||
if len(r) <= 6 {
|
||||
for i := 0; i < len(r); i++ {
|
||||
fmt.Printf(" %.2x", r[i])
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < 3; i++ {
|
||||
fmt.Printf(" %.2x", r[i])
|
||||
}
|
||||
fmt.Printf(" ..")
|
||||
for i := len(r) - 3; i < len(r); i++ {
|
||||
fmt.Printf(" %.2x", r[i])
|
||||
}
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
|
||||
case WireFixed32:
|
||||
u, err = p.DecodeFixed32()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d fix32 err %v\n", index, tag, err)
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d fix32 %d\n", index, tag, u)
|
||||
|
||||
case WireFixed64:
|
||||
u, err = p.DecodeFixed64()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d fix64 err %v\n", index, tag, err)
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u)
|
||||
|
||||
case WireVarint:
|
||||
u, err = p.DecodeVarint()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d varint err %v\n", index, tag, err)
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u)
|
||||
|
||||
case WireStartGroup:
|
||||
fmt.Printf("%3d: t=%3d start\n", index, tag)
|
||||
depth++
|
||||
|
||||
case WireEndGroup:
|
||||
depth--
|
||||
fmt.Printf("%3d: t=%3d end\n", index, tag)
|
||||
}
|
||||
}
|
||||
|
||||
if depth != 0 {
|
||||
fmt.Printf("%3d: start-end not balanced %d\n", p.index, depth)
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
|
||||
p.buf = obuf
|
||||
p.index = index
|
||||
}
|
||||
|
||||
// SetDefaults sets unset protocol buffer fields to their default values.
|
||||
// It only modifies fields that are both unset and have defined defaults.
|
||||
// It recursively sets default values in any non-nil sub-messages.
|
||||
func SetDefaults(pb Message) {
|
||||
setDefaults(reflect.ValueOf(pb), true, false)
|
||||
}
|
||||
|
||||
// v is a pointer to a struct.
|
||||
func setDefaults(v reflect.Value, recur, zeros bool) {
|
||||
v = v.Elem()
|
||||
|
||||
defaultMu.RLock()
|
||||
dm, ok := defaults[v.Type()]
|
||||
defaultMu.RUnlock()
|
||||
if !ok {
|
||||
dm = buildDefaultMessage(v.Type())
|
||||
defaultMu.Lock()
|
||||
defaults[v.Type()] = dm
|
||||
defaultMu.Unlock()
|
||||
}
|
||||
|
||||
for _, sf := range dm.scalars {
|
||||
f := v.Field(sf.index)
|
||||
if !f.IsNil() {
|
||||
// field already set
|
||||
continue
|
||||
}
|
||||
dv := sf.value
|
||||
if dv == nil && !zeros {
|
||||
// no explicit default, and don't want to set zeros
|
||||
continue
|
||||
}
|
||||
fptr := f.Addr().Interface() // **T
|
||||
// TODO: Consider batching the allocations we do here.
|
||||
switch sf.kind {
|
||||
case reflect.Bool:
|
||||
b := new(bool)
|
||||
if dv != nil {
|
||||
*b = dv.(bool)
|
||||
}
|
||||
*(fptr.(**bool)) = b
|
||||
case reflect.Float32:
|
||||
f := new(float32)
|
||||
if dv != nil {
|
||||
*f = dv.(float32)
|
||||
}
|
||||
*(fptr.(**float32)) = f
|
||||
case reflect.Float64:
|
||||
f := new(float64)
|
||||
if dv != nil {
|
||||
*f = dv.(float64)
|
||||
}
|
||||
*(fptr.(**float64)) = f
|
||||
case reflect.Int32:
|
||||
// might be an enum
|
||||
if ft := f.Type(); ft != int32PtrType {
|
||||
// enum
|
||||
f.Set(reflect.New(ft.Elem()))
|
||||
if dv != nil {
|
||||
f.Elem().SetInt(int64(dv.(int32)))
|
||||
}
|
||||
} else {
|
||||
// int32 field
|
||||
i := new(int32)
|
||||
if dv != nil {
|
||||
*i = dv.(int32)
|
||||
}
|
||||
*(fptr.(**int32)) = i
|
||||
}
|
||||
case reflect.Int64:
|
||||
i := new(int64)
|
||||
if dv != nil {
|
||||
*i = dv.(int64)
|
||||
}
|
||||
*(fptr.(**int64)) = i
|
||||
case reflect.String:
|
||||
s := new(string)
|
||||
if dv != nil {
|
||||
*s = dv.(string)
|
||||
}
|
||||
*(fptr.(**string)) = s
|
||||
case reflect.Uint8:
|
||||
// exceptional case: []byte
|
||||
var b []byte
|
||||
if dv != nil {
|
||||
db := dv.([]byte)
|
||||
b = make([]byte, len(db))
|
||||
copy(b, db)
|
||||
} else {
|
||||
b = []byte{}
|
||||
}
|
||||
*(fptr.(*[]byte)) = b
|
||||
case reflect.Uint32:
|
||||
u := new(uint32)
|
||||
if dv != nil {
|
||||
*u = dv.(uint32)
|
||||
}
|
||||
*(fptr.(**uint32)) = u
|
||||
case reflect.Uint64:
|
||||
u := new(uint64)
|
||||
if dv != nil {
|
||||
*u = dv.(uint64)
|
||||
}
|
||||
*(fptr.(**uint64)) = u
|
||||
default:
|
||||
log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind)
|
||||
}
|
||||
}
|
||||
|
||||
for _, ni := range dm.nested {
|
||||
f := v.Field(ni)
|
||||
// f is *T or []*T or map[T]*T
|
||||
switch f.Kind() {
|
||||
case reflect.Ptr:
|
||||
if f.IsNil() {
|
||||
continue
|
||||
}
|
||||
setDefaults(f, recur, zeros)
|
||||
|
||||
case reflect.Slice:
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
e := f.Index(i)
|
||||
if e.IsNil() {
|
||||
continue
|
||||
}
|
||||
setDefaults(e, recur, zeros)
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
for _, k := range f.MapKeys() {
|
||||
e := f.MapIndex(k)
|
||||
if e.IsNil() {
|
||||
continue
|
||||
}
|
||||
setDefaults(e, recur, zeros)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// defaults maps a protocol buffer struct type to a slice of the fields,
|
||||
// with its scalar fields set to their proto-declared non-zero default values.
|
||||
defaultMu sync.RWMutex
|
||||
defaults = make(map[reflect.Type]defaultMessage)
|
||||
|
||||
int32PtrType = reflect.TypeOf((*int32)(nil))
|
||||
)
|
||||
|
||||
// defaultMessage represents information about the default values of a message.
|
||||
type defaultMessage struct {
|
||||
scalars []scalarField
|
||||
nested []int // struct field index of nested messages
|
||||
}
|
||||
|
||||
type scalarField struct {
|
||||
index int // struct field index
|
||||
kind reflect.Kind // element type (the T in *T or []T)
|
||||
value interface{} // the proto-declared default value, or nil
|
||||
}
|
||||
|
||||
// t is a struct type.
|
||||
func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
|
||||
sprop := GetProperties(t)
|
||||
for _, prop := range sprop.Prop {
|
||||
fi, ok := sprop.decoderTags.get(prop.Tag)
|
||||
if !ok {
|
||||
// XXX_unrecognized
|
||||
continue
|
||||
}
|
||||
ft := t.Field(fi).Type
|
||||
|
||||
sf, nested, err := fieldDefault(ft, prop)
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Print(err)
|
||||
case nested:
|
||||
dm.nested = append(dm.nested, fi)
|
||||
case sf != nil:
|
||||
sf.index = fi
|
||||
dm.scalars = append(dm.scalars, *sf)
|
||||
}
|
||||
}
|
||||
|
||||
return dm
|
||||
}
|
||||
|
||||
// fieldDefault returns the scalarField for field type ft.
|
||||
// sf will be nil if the field can not have a default.
|
||||
// nestedMessage will be true if this is a nested message.
|
||||
// Note that sf.index is not set on return.
|
||||
func fieldDefault(ft reflect.Type, prop *Properties) (sf *scalarField, nestedMessage bool, err error) {
|
||||
var canHaveDefault bool
|
||||
switch ft.Kind() {
|
||||
case reflect.Ptr:
|
||||
if ft.Elem().Kind() == reflect.Struct {
|
||||
nestedMessage = true
|
||||
} else {
|
||||
canHaveDefault = true // proto2 scalar field
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
switch ft.Elem().Kind() {
|
||||
case reflect.Ptr:
|
||||
nestedMessage = true // repeated message
|
||||
case reflect.Uint8:
|
||||
canHaveDefault = true // bytes field
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
if ft.Elem().Kind() == reflect.Ptr {
|
||||
nestedMessage = true // map with message values
|
||||
}
|
||||
}
|
||||
|
||||
if !canHaveDefault {
|
||||
if nestedMessage {
|
||||
return nil, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// We now know that ft is a pointer or slice.
|
||||
sf = &scalarField{kind: ft.Elem().Kind()}
|
||||
|
||||
// scalar fields without defaults
|
||||
if !prop.HasDefault {
|
||||
return sf, false, nil
|
||||
}
|
||||
|
||||
// a scalar field: either *T or []byte
|
||||
switch ft.Elem().Kind() {
|
||||
case reflect.Bool:
|
||||
x, err := strconv.ParseBool(prop.Default)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default bool %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.Float32:
|
||||
x, err := strconv.ParseFloat(prop.Default, 32)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default float32 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = float32(x)
|
||||
case reflect.Float64:
|
||||
x, err := strconv.ParseFloat(prop.Default, 64)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default float64 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.Int32:
|
||||
x, err := strconv.ParseInt(prop.Default, 10, 32)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default int32 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = int32(x)
|
||||
case reflect.Int64:
|
||||
x, err := strconv.ParseInt(prop.Default, 10, 64)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default int64 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.String:
|
||||
sf.value = prop.Default
|
||||
case reflect.Uint8:
|
||||
// []byte (not *uint8)
|
||||
sf.value = []byte(prop.Default)
|
||||
case reflect.Uint32:
|
||||
x, err := strconv.ParseUint(prop.Default, 10, 32)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default uint32 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = uint32(x)
|
||||
case reflect.Uint64:
|
||||
x, err := strconv.ParseUint(prop.Default, 10, 64)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default uint64 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
default:
|
||||
return nil, false, fmt.Errorf("proto: unhandled def kind %v", ft.Elem().Kind())
|
||||
}
|
||||
|
||||
return sf, false, nil
|
||||
}
|
||||
|
||||
// Map fields may have key types of non-float scalars, strings and enums.
|
||||
// The easiest way to sort them in some deterministic order is to use fmt.
|
||||
// If this turns out to be inefficient we can always consider other options,
|
||||
// such as doing a Schwartzian transform.
|
||||
|
||||
func mapKeys(vs []reflect.Value) sort.Interface {
|
||||
s := mapKeySorter{
|
||||
vs: vs,
|
||||
// default Less function: textual comparison
|
||||
less: func(a, b reflect.Value) bool {
|
||||
return fmt.Sprint(a.Interface()) < fmt.Sprint(b.Interface())
|
||||
},
|
||||
}
|
||||
|
||||
// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps;
|
||||
// numeric keys are sorted numerically.
|
||||
if len(vs) == 0 {
|
||||
return s
|
||||
}
|
||||
switch vs[0].Kind() {
|
||||
case reflect.Int32, reflect.Int64:
|
||||
s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
|
||||
case reflect.Uint32, reflect.Uint64:
|
||||
s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
type mapKeySorter struct {
|
||||
vs []reflect.Value
|
||||
less func(a, b reflect.Value) bool
|
||||
}
|
||||
|
||||
func (s mapKeySorter) Len() int { return len(s.vs) }
|
||||
func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
|
||||
func (s mapKeySorter) Less(i, j int) bool {
|
||||
return s.less(s.vs[i], s.vs[j])
|
||||
}
|
||||
|
||||
// isProto3Zero reports whether v is a zero proto3 value.
|
||||
func isProto3Zero(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Bool:
|
||||
return !v.Bool()
|
||||
case reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Uint32, reflect.Uint64:
|
||||
return v.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float() == 0
|
||||
case reflect.String:
|
||||
return v.String() == ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ProtoPackageIsVersion2 is referenced from generated protocol buffer files
|
||||
// to assert that that code is compatible with this version of the proto package.
|
||||
const ProtoPackageIsVersion2 = true
|
||||
|
||||
// ProtoPackageIsVersion1 is referenced from generated protocol buffer files
|
||||
// to assert that that code is compatible with this version of the proto package.
|
||||
const ProtoPackageIsVersion1 = true
|
|
@ -0,0 +1,311 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
/*
|
||||
* Support for message sets.
|
||||
*/
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// errNoMessageTypeID occurs when a protocol buffer does not have a message type ID.
|
||||
// A message type ID is required for storing a protocol buffer in a message set.
|
||||
var errNoMessageTypeID = errors.New("proto does not have a message type ID")
|
||||
|
||||
// The first two types (_MessageSet_Item and messageSet)
|
||||
// model what the protocol compiler produces for the following protocol message:
|
||||
// message MessageSet {
|
||||
// repeated group Item = 1 {
|
||||
// required int32 type_id = 2;
|
||||
// required string message = 3;
|
||||
// };
|
||||
// }
|
||||
// That is the MessageSet wire format. We can't use a proto to generate these
|
||||
// because that would introduce a circular dependency between it and this package.
|
||||
|
||||
type _MessageSet_Item struct {
|
||||
TypeId *int32 `protobuf:"varint,2,req,name=type_id"`
|
||||
Message []byte `protobuf:"bytes,3,req,name=message"`
|
||||
}
|
||||
|
||||
type messageSet struct {
|
||||
Item []*_MessageSet_Item `protobuf:"group,1,rep"`
|
||||
XXX_unrecognized []byte
|
||||
// TODO: caching?
|
||||
}
|
||||
|
||||
// Make sure messageSet is a Message.
|
||||
var _ Message = (*messageSet)(nil)
|
||||
|
||||
// messageTypeIder is an interface satisfied by a protocol buffer type
|
||||
// that may be stored in a MessageSet.
|
||||
type messageTypeIder interface {
|
||||
MessageTypeId() int32
|
||||
}
|
||||
|
||||
func (ms *messageSet) find(pb Message) *_MessageSet_Item {
|
||||
mti, ok := pb.(messageTypeIder)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
id := mti.MessageTypeId()
|
||||
for _, item := range ms.Item {
|
||||
if *item.TypeId == id {
|
||||
return item
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *messageSet) Has(pb Message) bool {
|
||||
if ms.find(pb) != nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ms *messageSet) Unmarshal(pb Message) error {
|
||||
if item := ms.find(pb); item != nil {
|
||||
return Unmarshal(item.Message, pb)
|
||||
}
|
||||
if _, ok := pb.(messageTypeIder); !ok {
|
||||
return errNoMessageTypeID
|
||||
}
|
||||
return nil // TODO: return error instead?
|
||||
}
|
||||
|
||||
func (ms *messageSet) Marshal(pb Message) error {
|
||||
msg, err := Marshal(pb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if item := ms.find(pb); item != nil {
|
||||
// reuse existing item
|
||||
item.Message = msg
|
||||
return nil
|
||||
}
|
||||
|
||||
mti, ok := pb.(messageTypeIder)
|
||||
if !ok {
|
||||
return errNoMessageTypeID
|
||||
}
|
||||
|
||||
mtid := mti.MessageTypeId()
|
||||
ms.Item = append(ms.Item, &_MessageSet_Item{
|
||||
TypeId: &mtid,
|
||||
Message: msg,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *messageSet) Reset() { *ms = messageSet{} }
|
||||
func (ms *messageSet) String() string { return CompactTextString(ms) }
|
||||
func (*messageSet) ProtoMessage() {}
|
||||
|
||||
// Support for the message_set_wire_format message option.
|
||||
|
||||
func skipVarint(buf []byte) []byte {
|
||||
i := 0
|
||||
for ; buf[i]&0x80 != 0; i++ {
|
||||
}
|
||||
return buf[i+1:]
|
||||
}
|
||||
|
||||
// MarshalMessageSet encodes the extension map represented by m in the message set wire format.
|
||||
// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
|
||||
func MarshalMessageSet(exts interface{}) ([]byte, error) {
|
||||
var m map[int32]Extension
|
||||
switch exts := exts.(type) {
|
||||
case *XXX_InternalExtensions:
|
||||
if err := encodeExtensions(exts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m, _ = exts.extensionsRead()
|
||||
case map[int32]Extension:
|
||||
if err := encodeExtensionsMap(exts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m = exts
|
||||
default:
|
||||
return nil, errors.New("proto: not an extension map")
|
||||
}
|
||||
|
||||
// Sort extension IDs to provide a deterministic encoding.
|
||||
// See also enc_map in encode.go.
|
||||
ids := make([]int, 0, len(m))
|
||||
for id := range m {
|
||||
ids = append(ids, int(id))
|
||||
}
|
||||
sort.Ints(ids)
|
||||
|
||||
ms := &messageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
|
||||
for _, id := range ids {
|
||||
e := m[int32(id)]
|
||||
// Remove the wire type and field number varint, as well as the length varint.
|
||||
msg := skipVarint(skipVarint(e.enc))
|
||||
|
||||
ms.Item = append(ms.Item, &_MessageSet_Item{
|
||||
TypeId: Int32(int32(id)),
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
return Marshal(ms)
|
||||
}
|
||||
|
||||
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
|
||||
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
|
||||
func UnmarshalMessageSet(buf []byte, exts interface{}) error {
|
||||
var m map[int32]Extension
|
||||
switch exts := exts.(type) {
|
||||
case *XXX_InternalExtensions:
|
||||
m = exts.extensionsWrite()
|
||||
case map[int32]Extension:
|
||||
m = exts
|
||||
default:
|
||||
return errors.New("proto: not an extension map")
|
||||
}
|
||||
|
||||
ms := new(messageSet)
|
||||
if err := Unmarshal(buf, ms); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, item := range ms.Item {
|
||||
id := *item.TypeId
|
||||
msg := item.Message
|
||||
|
||||
// Restore wire type and field number varint, plus length varint.
|
||||
// Be careful to preserve duplicate items.
|
||||
b := EncodeVarint(uint64(id)<<3 | WireBytes)
|
||||
if ext, ok := m[id]; ok {
|
||||
// Existing data; rip off the tag and length varint
|
||||
// so we join the new data correctly.
|
||||
// We can assume that ext.enc is set because we are unmarshaling.
|
||||
o := ext.enc[len(b):] // skip wire type and field number
|
||||
_, n := DecodeVarint(o) // calculate length of length varint
|
||||
o = o[n:] // skip length varint
|
||||
msg = append(o, msg...) // join old data and new data
|
||||
}
|
||||
b = append(b, EncodeVarint(uint64(len(msg)))...)
|
||||
b = append(b, msg...)
|
||||
|
||||
m[id] = Extension{enc: b}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
|
||||
// It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
|
||||
func MarshalMessageSetJSON(exts interface{}) ([]byte, error) {
|
||||
var m map[int32]Extension
|
||||
switch exts := exts.(type) {
|
||||
case *XXX_InternalExtensions:
|
||||
m, _ = exts.extensionsRead()
|
||||
case map[int32]Extension:
|
||||
m = exts
|
||||
default:
|
||||
return nil, errors.New("proto: not an extension map")
|
||||
}
|
||||
var b bytes.Buffer
|
||||
b.WriteByte('{')
|
||||
|
||||
// Process the map in key order for deterministic output.
|
||||
ids := make([]int32, 0, len(m))
|
||||
for id := range m {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
sort.Sort(int32Slice(ids)) // int32Slice defined in text.go
|
||||
|
||||
for i, id := range ids {
|
||||
ext := m[id]
|
||||
if i > 0 {
|
||||
b.WriteByte(',')
|
||||
}
|
||||
|
||||
msd, ok := messageSetMap[id]
|
||||
if !ok {
|
||||
// Unknown type; we can't render it, so skip it.
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(&b, `"[%s]":`, msd.name)
|
||||
|
||||
x := ext.value
|
||||
if x == nil {
|
||||
x = reflect.New(msd.t.Elem()).Interface()
|
||||
if err := Unmarshal(ext.enc, x.(Message)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
d, err := json.Marshal(x)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Write(d)
|
||||
}
|
||||
b.WriteByte('}')
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
|
||||
// It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
|
||||
func UnmarshalMessageSetJSON(buf []byte, exts interface{}) error {
|
||||
// Common-case fast path.
|
||||
if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This is fairly tricky, and it's not clear that it is needed.
|
||||
return errors.New("TODO: UnmarshalMessageSetJSON not yet implemented")
|
||||
}
|
||||
|
||||
// A global registry of types that can be used in a MessageSet.
|
||||
|
||||
var messageSetMap = make(map[int32]messageSetDesc)
|
||||
|
||||
type messageSetDesc struct {
|
||||
t reflect.Type // pointer to struct
|
||||
name string
|
||||
}
|
||||
|
||||
// RegisterMessageSetType is called from the generated code.
|
||||
func RegisterMessageSetType(m Message, fieldNum int32, name string) {
|
||||
messageSetMap[fieldNum] = messageSetDesc{
|
||||
t: reflect.TypeOf(m),
|
||||
name: name,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,484 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// +build appengine js
|
||||
|
||||
// This file contains an implementation of proto field accesses using package reflect.
|
||||
// It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can
|
||||
// be used on App Engine.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// A structPointer is a pointer to a struct.
|
||||
type structPointer struct {
|
||||
v reflect.Value
|
||||
}
|
||||
|
||||
// toStructPointer returns a structPointer equivalent to the given reflect value.
|
||||
// The reflect value must itself be a pointer to a struct.
|
||||
func toStructPointer(v reflect.Value) structPointer {
|
||||
return structPointer{v}
|
||||
}
|
||||
|
||||
// IsNil reports whether p is nil.
|
||||
func structPointer_IsNil(p structPointer) bool {
|
||||
return p.v.IsNil()
|
||||
}
|
||||
|
||||
// Interface returns the struct pointer as an interface value.
|
||||
func structPointer_Interface(p structPointer, _ reflect.Type) interface{} {
|
||||
return p.v.Interface()
|
||||
}
|
||||
|
||||
// A field identifies a field in a struct, accessible from a structPointer.
|
||||
// In this implementation, a field is identified by the sequence of field indices
|
||||
// passed to reflect's FieldByIndex.
|
||||
type field []int
|
||||
|
||||
// toField returns a field equivalent to the given reflect field.
|
||||
func toField(f *reflect.StructField) field {
|
||||
return f.Index
|
||||
}
|
||||
|
||||
// invalidField is an invalid field identifier.
|
||||
var invalidField = field(nil)
|
||||
|
||||
// IsValid reports whether the field identifier is valid.
|
||||
func (f field) IsValid() bool { return f != nil }
|
||||
|
||||
// field returns the given field in the struct as a reflect value.
|
||||
func structPointer_field(p structPointer, f field) reflect.Value {
|
||||
// Special case: an extension map entry with a value of type T
|
||||
// passes a *T to the struct-handling code with a zero field,
|
||||
// expecting that it will be treated as equivalent to *struct{ X T },
|
||||
// which has the same memory layout. We have to handle that case
|
||||
// specially, because reflect will panic if we call FieldByIndex on a
|
||||
// non-struct.
|
||||
if f == nil {
|
||||
return p.v.Elem()
|
||||
}
|
||||
|
||||
return p.v.Elem().FieldByIndex(f)
|
||||
}
|
||||
|
||||
// ifield returns the given field in the struct as an interface value.
|
||||
func structPointer_ifield(p structPointer, f field) interface{} {
|
||||
return structPointer_field(p, f).Addr().Interface()
|
||||
}
|
||||
|
||||
// Bytes returns the address of a []byte field in the struct.
|
||||
func structPointer_Bytes(p structPointer, f field) *[]byte {
|
||||
return structPointer_ifield(p, f).(*[]byte)
|
||||
}
|
||||
|
||||
// BytesSlice returns the address of a [][]byte field in the struct.
|
||||
func structPointer_BytesSlice(p structPointer, f field) *[][]byte {
|
||||
return structPointer_ifield(p, f).(*[][]byte)
|
||||
}
|
||||
|
||||
// Bool returns the address of a *bool field in the struct.
|
||||
func structPointer_Bool(p structPointer, f field) **bool {
|
||||
return structPointer_ifield(p, f).(**bool)
|
||||
}
|
||||
|
||||
// BoolVal returns the address of a bool field in the struct.
|
||||
func structPointer_BoolVal(p structPointer, f field) *bool {
|
||||
return structPointer_ifield(p, f).(*bool)
|
||||
}
|
||||
|
||||
// BoolSlice returns the address of a []bool field in the struct.
|
||||
func structPointer_BoolSlice(p structPointer, f field) *[]bool {
|
||||
return structPointer_ifield(p, f).(*[]bool)
|
||||
}
|
||||
|
||||
// String returns the address of a *string field in the struct.
|
||||
func structPointer_String(p structPointer, f field) **string {
|
||||
return structPointer_ifield(p, f).(**string)
|
||||
}
|
||||
|
||||
// StringVal returns the address of a string field in the struct.
|
||||
func structPointer_StringVal(p structPointer, f field) *string {
|
||||
return structPointer_ifield(p, f).(*string)
|
||||
}
|
||||
|
||||
// StringSlice returns the address of a []string field in the struct.
|
||||
func structPointer_StringSlice(p structPointer, f field) *[]string {
|
||||
return structPointer_ifield(p, f).(*[]string)
|
||||
}
|
||||
|
||||
// Extensions returns the address of an extension map field in the struct.
|
||||
func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
|
||||
return structPointer_ifield(p, f).(*XXX_InternalExtensions)
|
||||
}
|
||||
|
||||
// ExtMap returns the address of an extension map field in the struct.
|
||||
func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
|
||||
return structPointer_ifield(p, f).(*map[int32]Extension)
|
||||
}
|
||||
|
||||
// NewAt returns the reflect.Value for a pointer to a field in the struct.
|
||||
func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value {
|
||||
return structPointer_field(p, f).Addr()
|
||||
}
|
||||
|
||||
// SetStructPointer writes a *struct field in the struct.
|
||||
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
|
||||
structPointer_field(p, f).Set(q.v)
|
||||
}
|
||||
|
||||
// GetStructPointer reads a *struct field in the struct.
|
||||
func structPointer_GetStructPointer(p structPointer, f field) structPointer {
|
||||
return structPointer{structPointer_field(p, f)}
|
||||
}
|
||||
|
||||
// StructPointerSlice the address of a []*struct field in the struct.
|
||||
func structPointer_StructPointerSlice(p structPointer, f field) structPointerSlice {
|
||||
return structPointerSlice{structPointer_field(p, f)}
|
||||
}
|
||||
|
||||
// A structPointerSlice represents the address of a slice of pointers to structs
|
||||
// (themselves messages or groups). That is, v.Type() is *[]*struct{...}.
|
||||
type structPointerSlice struct {
|
||||
v reflect.Value
|
||||
}
|
||||
|
||||
func (p structPointerSlice) Len() int { return p.v.Len() }
|
||||
func (p structPointerSlice) Index(i int) structPointer { return structPointer{p.v.Index(i)} }
|
||||
func (p structPointerSlice) Append(q structPointer) {
|
||||
p.v.Set(reflect.Append(p.v, q.v))
|
||||
}
|
||||
|
||||
var (
|
||||
int32Type = reflect.TypeOf(int32(0))
|
||||
uint32Type = reflect.TypeOf(uint32(0))
|
||||
float32Type = reflect.TypeOf(float32(0))
|
||||
int64Type = reflect.TypeOf(int64(0))
|
||||
uint64Type = reflect.TypeOf(uint64(0))
|
||||
float64Type = reflect.TypeOf(float64(0))
|
||||
)
|
||||
|
||||
// A word32 represents a field of type *int32, *uint32, *float32, or *enum.
|
||||
// That is, v.Type() is *int32, *uint32, *float32, or *enum and v is assignable.
|
||||
type word32 struct {
|
||||
v reflect.Value
|
||||
}
|
||||
|
||||
// IsNil reports whether p is nil.
|
||||
func word32_IsNil(p word32) bool {
|
||||
return p.v.IsNil()
|
||||
}
|
||||
|
||||
// Set sets p to point at a newly allocated word with bits set to x.
|
||||
func word32_Set(p word32, o *Buffer, x uint32) {
|
||||
t := p.v.Type().Elem()
|
||||
switch t {
|
||||
case int32Type:
|
||||
if len(o.int32s) == 0 {
|
||||
o.int32s = make([]int32, uint32PoolSize)
|
||||
}
|
||||
o.int32s[0] = int32(x)
|
||||
p.v.Set(reflect.ValueOf(&o.int32s[0]))
|
||||
o.int32s = o.int32s[1:]
|
||||
return
|
||||
case uint32Type:
|
||||
if len(o.uint32s) == 0 {
|
||||
o.uint32s = make([]uint32, uint32PoolSize)
|
||||
}
|
||||
o.uint32s[0] = x
|
||||
p.v.Set(reflect.ValueOf(&o.uint32s[0]))
|
||||
o.uint32s = o.uint32s[1:]
|
||||
return
|
||||
case float32Type:
|
||||
if len(o.float32s) == 0 {
|
||||
o.float32s = make([]float32, uint32PoolSize)
|
||||
}
|
||||
o.float32s[0] = math.Float32frombits(x)
|
||||
p.v.Set(reflect.ValueOf(&o.float32s[0]))
|
||||
o.float32s = o.float32s[1:]
|
||||
return
|
||||
}
|
||||
|
||||
// must be enum
|
||||
p.v.Set(reflect.New(t))
|
||||
p.v.Elem().SetInt(int64(int32(x)))
|
||||
}
|
||||
|
||||
// Get gets the bits pointed at by p, as a uint32.
|
||||
func word32_Get(p word32) uint32 {
|
||||
elem := p.v.Elem()
|
||||
switch elem.Kind() {
|
||||
case reflect.Int32:
|
||||
return uint32(elem.Int())
|
||||
case reflect.Uint32:
|
||||
return uint32(elem.Uint())
|
||||
case reflect.Float32:
|
||||
return math.Float32bits(float32(elem.Float()))
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// Word32 returns a reference to a *int32, *uint32, *float32, or *enum field in the struct.
|
||||
func structPointer_Word32(p structPointer, f field) word32 {
|
||||
return word32{structPointer_field(p, f)}
|
||||
}
|
||||
|
||||
// A word32Val represents a field of type int32, uint32, float32, or enum.
|
||||
// That is, v.Type() is int32, uint32, float32, or enum and v is assignable.
|
||||
type word32Val struct {
|
||||
v reflect.Value
|
||||
}
|
||||
|
||||
// Set sets *p to x.
|
||||
func word32Val_Set(p word32Val, x uint32) {
|
||||
switch p.v.Type() {
|
||||
case int32Type:
|
||||
p.v.SetInt(int64(x))
|
||||
return
|
||||
case uint32Type:
|
||||
p.v.SetUint(uint64(x))
|
||||
return
|
||||
case float32Type:
|
||||
p.v.SetFloat(float64(math.Float32frombits(x)))
|
||||
return
|
||||
}
|
||||
|
||||
// must be enum
|
||||
p.v.SetInt(int64(int32(x)))
|
||||
}
|
||||
|
||||
// Get gets the bits pointed at by p, as a uint32.
|
||||
func word32Val_Get(p word32Val) uint32 {
|
||||
elem := p.v
|
||||
switch elem.Kind() {
|
||||
case reflect.Int32:
|
||||
return uint32(elem.Int())
|
||||
case reflect.Uint32:
|
||||
return uint32(elem.Uint())
|
||||
case reflect.Float32:
|
||||
return math.Float32bits(float32(elem.Float()))
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// Word32Val returns a reference to a int32, uint32, float32, or enum field in the struct.
|
||||
func structPointer_Word32Val(p structPointer, f field) word32Val {
|
||||
return word32Val{structPointer_field(p, f)}
|
||||
}
|
||||
|
||||
// A word32Slice is a slice of 32-bit values.
|
||||
// That is, v.Type() is []int32, []uint32, []float32, or []enum.
|
||||
type word32Slice struct {
|
||||
v reflect.Value
|
||||
}
|
||||
|
||||
func (p word32Slice) Append(x uint32) {
|
||||
n, m := p.v.Len(), p.v.Cap()
|
||||
if n < m {
|
||||
p.v.SetLen(n + 1)
|
||||
} else {
|
||||
t := p.v.Type().Elem()
|
||||
p.v.Set(reflect.Append(p.v, reflect.Zero(t)))
|
||||
}
|
||||
elem := p.v.Index(n)
|
||||
switch elem.Kind() {
|
||||
case reflect.Int32:
|
||||
elem.SetInt(int64(int32(x)))
|
||||
case reflect.Uint32:
|
||||
elem.SetUint(uint64(x))
|
||||
case reflect.Float32:
|
||||
elem.SetFloat(float64(math.Float32frombits(x)))
|
||||
}
|
||||
}
|
||||
|
||||
func (p word32Slice) Len() int {
|
||||
return p.v.Len()
|
||||
}
|
||||
|
||||
func (p word32Slice) Index(i int) uint32 {
|
||||
elem := p.v.Index(i)
|
||||
switch elem.Kind() {
|
||||
case reflect.Int32:
|
||||
return uint32(elem.Int())
|
||||
case reflect.Uint32:
|
||||
return uint32(elem.Uint())
|
||||
case reflect.Float32:
|
||||
return math.Float32bits(float32(elem.Float()))
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// Word32Slice returns a reference to a []int32, []uint32, []float32, or []enum field in the struct.
|
||||
func structPointer_Word32Slice(p structPointer, f field) word32Slice {
|
||||
return word32Slice{structPointer_field(p, f)}
|
||||
}
|
||||
|
||||
// word64 is like word32 but for 64-bit values.
|
||||
type word64 struct {
|
||||
v reflect.Value
|
||||
}
|
||||
|
||||
func word64_Set(p word64, o *Buffer, x uint64) {
|
||||
t := p.v.Type().Elem()
|
||||
switch t {
|
||||
case int64Type:
|
||||
if len(o.int64s) == 0 {
|
||||
o.int64s = make([]int64, uint64PoolSize)
|
||||
}
|
||||
o.int64s[0] = int64(x)
|
||||
p.v.Set(reflect.ValueOf(&o.int64s[0]))
|
||||
o.int64s = o.int64s[1:]
|
||||
return
|
||||
case uint64Type:
|
||||
if len(o.uint64s) == 0 {
|
||||
o.uint64s = make([]uint64, uint64PoolSize)
|
||||
}
|
||||
o.uint64s[0] = x
|
||||
p.v.Set(reflect.ValueOf(&o.uint64s[0]))
|
||||
o.uint64s = o.uint64s[1:]
|
||||
return
|
||||
case float64Type:
|
||||
if len(o.float64s) == 0 {
|
||||
o.float64s = make([]float64, uint64PoolSize)
|
||||
}
|
||||
o.float64s[0] = math.Float64frombits(x)
|
||||
p.v.Set(reflect.ValueOf(&o.float64s[0]))
|
||||
o.float64s = o.float64s[1:]
|
||||
return
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func word64_IsNil(p word64) bool {
|
||||
return p.v.IsNil()
|
||||
}
|
||||
|
||||
func word64_Get(p word64) uint64 {
|
||||
elem := p.v.Elem()
|
||||
switch elem.Kind() {
|
||||
case reflect.Int64:
|
||||
return uint64(elem.Int())
|
||||
case reflect.Uint64:
|
||||
return elem.Uint()
|
||||
case reflect.Float64:
|
||||
return math.Float64bits(elem.Float())
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func structPointer_Word64(p structPointer, f field) word64 {
|
||||
return word64{structPointer_field(p, f)}
|
||||
}
|
||||
|
||||
// word64Val is like word32Val but for 64-bit values.
|
||||
type word64Val struct {
|
||||
v reflect.Value
|
||||
}
|
||||
|
||||
func word64Val_Set(p word64Val, o *Buffer, x uint64) {
|
||||
switch p.v.Type() {
|
||||
case int64Type:
|
||||
p.v.SetInt(int64(x))
|
||||
return
|
||||
case uint64Type:
|
||||
p.v.SetUint(x)
|
||||
return
|
||||
case float64Type:
|
||||
p.v.SetFloat(math.Float64frombits(x))
|
||||
return
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func word64Val_Get(p word64Val) uint64 {
|
||||
elem := p.v
|
||||
switch elem.Kind() {
|
||||
case reflect.Int64:
|
||||
return uint64(elem.Int())
|
||||
case reflect.Uint64:
|
||||
return elem.Uint()
|
||||
case reflect.Float64:
|
||||
return math.Float64bits(elem.Float())
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func structPointer_Word64Val(p structPointer, f field) word64Val {
|
||||
return word64Val{structPointer_field(p, f)}
|
||||
}
|
||||
|
||||
type word64Slice struct {
|
||||
v reflect.Value
|
||||
}
|
||||
|
||||
func (p word64Slice) Append(x uint64) {
|
||||
n, m := p.v.Len(), p.v.Cap()
|
||||
if n < m {
|
||||
p.v.SetLen(n + 1)
|
||||
} else {
|
||||
t := p.v.Type().Elem()
|
||||
p.v.Set(reflect.Append(p.v, reflect.Zero(t)))
|
||||
}
|
||||
elem := p.v.Index(n)
|
||||
switch elem.Kind() {
|
||||
case reflect.Int64:
|
||||
elem.SetInt(int64(int64(x)))
|
||||
case reflect.Uint64:
|
||||
elem.SetUint(uint64(x))
|
||||
case reflect.Float64:
|
||||
elem.SetFloat(float64(math.Float64frombits(x)))
|
||||
}
|
||||
}
|
||||
|
||||
func (p word64Slice) Len() int {
|
||||
return p.v.Len()
|
||||
}
|
||||
|
||||
func (p word64Slice) Index(i int) uint64 {
|
||||
elem := p.v.Index(i)
|
||||
switch elem.Kind() {
|
||||
case reflect.Int64:
|
||||
return uint64(elem.Int())
|
||||
case reflect.Uint64:
|
||||
return uint64(elem.Uint())
|
||||
case reflect.Float64:
|
||||
return math.Float64bits(float64(elem.Float()))
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func structPointer_Word64Slice(p structPointer, f field) word64Slice {
|
||||
return word64Slice{structPointer_field(p, f)}
|
||||
}
|
|
@ -0,0 +1,270 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// +build !appengine,!js
|
||||
|
||||
// This file contains the implementation of the proto field accesses using package unsafe.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// NOTE: These type_Foo functions would more idiomatically be methods,
|
||||
// but Go does not allow methods on pointer types, and we must preserve
|
||||
// some pointer type for the garbage collector. We use these
|
||||
// funcs with clunky names as our poor approximation to methods.
|
||||
//
|
||||
// An alternative would be
|
||||
// type structPointer struct { p unsafe.Pointer }
|
||||
// but that does not registerize as well.
|
||||
|
||||
// A structPointer is a pointer to a struct.
|
||||
type structPointer unsafe.Pointer
|
||||
|
||||
// toStructPointer returns a structPointer equivalent to the given reflect value.
|
||||
func toStructPointer(v reflect.Value) structPointer {
|
||||
return structPointer(unsafe.Pointer(v.Pointer()))
|
||||
}
|
||||
|
||||
// IsNil reports whether p is nil.
|
||||
func structPointer_IsNil(p structPointer) bool {
|
||||
return p == nil
|
||||
}
|
||||
|
||||
// Interface returns the struct pointer, assumed to have element type t,
|
||||
// as an interface value.
|
||||
func structPointer_Interface(p structPointer, t reflect.Type) interface{} {
|
||||
return reflect.NewAt(t, unsafe.Pointer(p)).Interface()
|
||||
}
|
||||
|
||||
// A field identifies a field in a struct, accessible from a structPointer.
|
||||
// In this implementation, a field is identified by its byte offset from the start of the struct.
|
||||
type field uintptr
|
||||
|
||||
// toField returns a field equivalent to the given reflect field.
|
||||
func toField(f *reflect.StructField) field {
|
||||
return field(f.Offset)
|
||||
}
|
||||
|
||||
// invalidField is an invalid field identifier.
|
||||
const invalidField = ^field(0)
|
||||
|
||||
// IsValid reports whether the field identifier is valid.
|
||||
func (f field) IsValid() bool {
|
||||
return f != ^field(0)
|
||||
}
|
||||
|
||||
// Bytes returns the address of a []byte field in the struct.
|
||||
func structPointer_Bytes(p structPointer, f field) *[]byte {
|
||||
return (*[]byte)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// BytesSlice returns the address of a [][]byte field in the struct.
|
||||
func structPointer_BytesSlice(p structPointer, f field) *[][]byte {
|
||||
return (*[][]byte)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// Bool returns the address of a *bool field in the struct.
|
||||
func structPointer_Bool(p structPointer, f field) **bool {
|
||||
return (**bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// BoolVal returns the address of a bool field in the struct.
|
||||
func structPointer_BoolVal(p structPointer, f field) *bool {
|
||||
return (*bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// BoolSlice returns the address of a []bool field in the struct.
|
||||
func structPointer_BoolSlice(p structPointer, f field) *[]bool {
|
||||
return (*[]bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// String returns the address of a *string field in the struct.
|
||||
func structPointer_String(p structPointer, f field) **string {
|
||||
return (**string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// StringVal returns the address of a string field in the struct.
|
||||
func structPointer_StringVal(p structPointer, f field) *string {
|
||||
return (*string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// StringSlice returns the address of a []string field in the struct.
|
||||
func structPointer_StringSlice(p structPointer, f field) *[]string {
|
||||
return (*[]string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// ExtMap returns the address of an extension map field in the struct.
|
||||
func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
|
||||
return (*XXX_InternalExtensions)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
|
||||
return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// NewAt returns the reflect.Value for a pointer to a field in the struct.
|
||||
func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value {
|
||||
return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f)))
|
||||
}
|
||||
|
||||
// SetStructPointer writes a *struct field in the struct.
|
||||
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
|
||||
*(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q
|
||||
}
|
||||
|
||||
// GetStructPointer reads a *struct field in the struct.
|
||||
func structPointer_GetStructPointer(p structPointer, f field) structPointer {
|
||||
return *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// StructPointerSlice the address of a []*struct field in the struct.
|
||||
func structPointer_StructPointerSlice(p structPointer, f field) *structPointerSlice {
|
||||
return (*structPointerSlice)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// A structPointerSlice represents a slice of pointers to structs (themselves submessages or groups).
|
||||
type structPointerSlice []structPointer
|
||||
|
||||
func (v *structPointerSlice) Len() int { return len(*v) }
|
||||
func (v *structPointerSlice) Index(i int) structPointer { return (*v)[i] }
|
||||
func (v *structPointerSlice) Append(p structPointer) { *v = append(*v, p) }
|
||||
|
||||
// A word32 is the address of a "pointer to 32-bit value" field.
|
||||
type word32 **uint32
|
||||
|
||||
// IsNil reports whether *v is nil.
|
||||
func word32_IsNil(p word32) bool {
|
||||
return *p == nil
|
||||
}
|
||||
|
||||
// Set sets *v to point at a newly allocated word set to x.
|
||||
func word32_Set(p word32, o *Buffer, x uint32) {
|
||||
if len(o.uint32s) == 0 {
|
||||
o.uint32s = make([]uint32, uint32PoolSize)
|
||||
}
|
||||
o.uint32s[0] = x
|
||||
*p = &o.uint32s[0]
|
||||
o.uint32s = o.uint32s[1:]
|
||||
}
|
||||
|
||||
// Get gets the value pointed at by *v.
|
||||
func word32_Get(p word32) uint32 {
|
||||
return **p
|
||||
}
|
||||
|
||||
// Word32 returns the address of a *int32, *uint32, *float32, or *enum field in the struct.
|
||||
func structPointer_Word32(p structPointer, f field) word32 {
|
||||
return word32((**uint32)(unsafe.Pointer(uintptr(p) + uintptr(f))))
|
||||
}
|
||||
|
||||
// A word32Val is the address of a 32-bit value field.
|
||||
type word32Val *uint32
|
||||
|
||||
// Set sets *p to x.
|
||||
func word32Val_Set(p word32Val, x uint32) {
|
||||
*p = x
|
||||
}
|
||||
|
||||
// Get gets the value pointed at by p.
|
||||
func word32Val_Get(p word32Val) uint32 {
|
||||
return *p
|
||||
}
|
||||
|
||||
// Word32Val returns the address of a *int32, *uint32, *float32, or *enum field in the struct.
|
||||
func structPointer_Word32Val(p structPointer, f field) word32Val {
|
||||
return word32Val((*uint32)(unsafe.Pointer(uintptr(p) + uintptr(f))))
|
||||
}
|
||||
|
||||
// A word32Slice is a slice of 32-bit values.
|
||||
type word32Slice []uint32
|
||||
|
||||
func (v *word32Slice) Append(x uint32) { *v = append(*v, x) }
|
||||
func (v *word32Slice) Len() int { return len(*v) }
|
||||
func (v *word32Slice) Index(i int) uint32 { return (*v)[i] }
|
||||
|
||||
// Word32Slice returns the address of a []int32, []uint32, []float32, or []enum field in the struct.
|
||||
func structPointer_Word32Slice(p structPointer, f field) *word32Slice {
|
||||
return (*word32Slice)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// word64 is like word32 but for 64-bit values.
|
||||
type word64 **uint64
|
||||
|
||||
func word64_Set(p word64, o *Buffer, x uint64) {
|
||||
if len(o.uint64s) == 0 {
|
||||
o.uint64s = make([]uint64, uint64PoolSize)
|
||||
}
|
||||
o.uint64s[0] = x
|
||||
*p = &o.uint64s[0]
|
||||
o.uint64s = o.uint64s[1:]
|
||||
}
|
||||
|
||||
func word64_IsNil(p word64) bool {
|
||||
return *p == nil
|
||||
}
|
||||
|
||||
func word64_Get(p word64) uint64 {
|
||||
return **p
|
||||
}
|
||||
|
||||
func structPointer_Word64(p structPointer, f field) word64 {
|
||||
return word64((**uint64)(unsafe.Pointer(uintptr(p) + uintptr(f))))
|
||||
}
|
||||
|
||||
// word64Val is like word32Val but for 64-bit values.
|
||||
type word64Val *uint64
|
||||
|
||||
func word64Val_Set(p word64Val, o *Buffer, x uint64) {
|
||||
*p = x
|
||||
}
|
||||
|
||||
func word64Val_Get(p word64Val) uint64 {
|
||||
return *p
|
||||
}
|
||||
|
||||
func structPointer_Word64Val(p structPointer, f field) word64Val {
|
||||
return word64Val((*uint64)(unsafe.Pointer(uintptr(p) + uintptr(f))))
|
||||
}
|
||||
|
||||
// word64Slice is like word32Slice but for 64-bit values.
|
||||
type word64Slice []uint64
|
||||
|
||||
func (v *word64Slice) Append(x uint64) { *v = append(*v, x) }
|
||||
func (v *word64Slice) Len() int { return len(*v) }
|
||||
func (v *word64Slice) Index(i int) uint64 { return (*v)[i] }
|
||||
|
||||
func structPointer_Word64Slice(p structPointer, f field) *word64Slice {
|
||||
return (*word64Slice)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
|
@ -0,0 +1,872 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
/*
|
||||
* Routines for encoding data into the wire format for protocol buffers.
|
||||
*/
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const debug bool = false
|
||||
|
||||
// Constants that identify the encoding of a value on the wire.
|
||||
const (
|
||||
WireVarint = 0
|
||||
WireFixed64 = 1
|
||||
WireBytes = 2
|
||||
WireStartGroup = 3
|
||||
WireEndGroup = 4
|
||||
WireFixed32 = 5
|
||||
)
|
||||
|
||||
const startSize = 10 // initial slice/string sizes
|
||||
|
||||
// Encoders are defined in encode.go
|
||||
// An encoder outputs the full representation of a field, including its
|
||||
// tag and encoder type.
|
||||
type encoder func(p *Buffer, prop *Properties, base structPointer) error
|
||||
|
||||
// A valueEncoder encodes a single integer in a particular encoding.
|
||||
type valueEncoder func(o *Buffer, x uint64) error
|
||||
|
||||
// Sizers are defined in encode.go
|
||||
// A sizer returns the encoded size of a field, including its tag and encoder
|
||||
// type.
|
||||
type sizer func(prop *Properties, base structPointer) int
|
||||
|
||||
// A valueSizer returns the encoded size of a single integer in a particular
|
||||
// encoding.
|
||||
type valueSizer func(x uint64) int
|
||||
|
||||
// Decoders are defined in decode.go
|
||||
// A decoder creates a value from its wire representation.
|
||||
// Unrecognized subelements are saved in unrec.
|
||||
type decoder func(p *Buffer, prop *Properties, base structPointer) error
|
||||
|
||||
// A valueDecoder decodes a single integer in a particular encoding.
|
||||
type valueDecoder func(o *Buffer) (x uint64, err error)
|
||||
|
||||
// A oneofMarshaler does the marshaling for all oneof fields in a message.
|
||||
type oneofMarshaler func(Message, *Buffer) error
|
||||
|
||||
// A oneofUnmarshaler does the unmarshaling for a oneof field in a message.
|
||||
type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error)
|
||||
|
||||
// A oneofSizer does the sizing for all oneof fields in a message.
|
||||
type oneofSizer func(Message) int
|
||||
|
||||
// tagMap is an optimization over map[int]int for typical protocol buffer
|
||||
// use-cases. Encoded protocol buffers are often in tag order with small tag
|
||||
// numbers.
|
||||
type tagMap struct {
|
||||
fastTags []int
|
||||
slowTags map[int]int
|
||||
}
|
||||
|
||||
// tagMapFastLimit is the upper bound on the tag number that will be stored in
|
||||
// the tagMap slice rather than its map.
|
||||
const tagMapFastLimit = 1024
|
||||
|
||||
func (p *tagMap) get(t int) (int, bool) {
|
||||
if t > 0 && t < tagMapFastLimit {
|
||||
if t >= len(p.fastTags) {
|
||||
return 0, false
|
||||
}
|
||||
fi := p.fastTags[t]
|
||||
return fi, fi >= 0
|
||||
}
|
||||
fi, ok := p.slowTags[t]
|
||||
return fi, ok
|
||||
}
|
||||
|
||||
func (p *tagMap) put(t int, fi int) {
|
||||
if t > 0 && t < tagMapFastLimit {
|
||||
for len(p.fastTags) < t+1 {
|
||||
p.fastTags = append(p.fastTags, -1)
|
||||
}
|
||||
p.fastTags[t] = fi
|
||||
return
|
||||
}
|
||||
if p.slowTags == nil {
|
||||
p.slowTags = make(map[int]int)
|
||||
}
|
||||
p.slowTags[t] = fi
|
||||
}
|
||||
|
||||
// StructProperties represents properties for all the fields of a struct.
|
||||
// decoderTags and decoderOrigNames should only be used by the decoder.
|
||||
type StructProperties struct {
|
||||
Prop []*Properties // properties for each field
|
||||
reqCount int // required count
|
||||
decoderTags tagMap // map from proto tag to struct field number
|
||||
decoderOrigNames map[string]int // map from original name to struct field number
|
||||
order []int // list of struct field numbers in tag order
|
||||
unrecField field // field id of the XXX_unrecognized []byte field
|
||||
extendable bool // is this an extendable proto
|
||||
|
||||
oneofMarshaler oneofMarshaler
|
||||
oneofUnmarshaler oneofUnmarshaler
|
||||
oneofSizer oneofSizer
|
||||
stype reflect.Type
|
||||
|
||||
// OneofTypes contains information about the oneof fields in this message.
|
||||
// It is keyed by the original name of a field.
|
||||
OneofTypes map[string]*OneofProperties
|
||||
}
|
||||
|
||||
// OneofProperties represents information about a specific field in a oneof.
|
||||
type OneofProperties struct {
|
||||
Type reflect.Type // pointer to generated struct type for this oneof field
|
||||
Field int // struct field number of the containing oneof in the message
|
||||
Prop *Properties
|
||||
}
|
||||
|
||||
// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec.
|
||||
// See encode.go, (*Buffer).enc_struct.
|
||||
|
||||
func (sp *StructProperties) Len() int { return len(sp.order) }
|
||||
func (sp *StructProperties) Less(i, j int) bool {
|
||||
return sp.Prop[sp.order[i]].Tag < sp.Prop[sp.order[j]].Tag
|
||||
}
|
||||
func (sp *StructProperties) Swap(i, j int) { sp.order[i], sp.order[j] = sp.order[j], sp.order[i] }
|
||||
|
||||
// Properties represents the protocol-specific behavior of a single struct field.
|
||||
type Properties struct {
|
||||
Name string // name of the field, for error messages
|
||||
OrigName string // original name before protocol compiler (always set)
|
||||
JSONName string // name to use for JSON; determined by protoc
|
||||
Wire string
|
||||
WireType int
|
||||
Tag int
|
||||
Required bool
|
||||
Optional bool
|
||||
Repeated bool
|
||||
Packed bool // relevant for repeated primitives only
|
||||
Enum string // set for enum types only
|
||||
proto3 bool // whether this is known to be a proto3 field; set for []byte only
|
||||
oneof bool // whether this is a oneof field
|
||||
|
||||
Default string // default value
|
||||
HasDefault bool // whether an explicit default was provided
|
||||
def_uint64 uint64
|
||||
|
||||
enc encoder
|
||||
valEnc valueEncoder // set for bool and numeric types only
|
||||
field field
|
||||
tagcode []byte // encoding of EncodeVarint((Tag<<3)|WireType)
|
||||
tagbuf [8]byte
|
||||
stype reflect.Type // set for struct types only
|
||||
sprop *StructProperties // set for struct types only
|
||||
isMarshaler bool
|
||||
isUnmarshaler bool
|
||||
|
||||
mtype reflect.Type // set for map types only
|
||||
mkeyprop *Properties // set for map types only
|
||||
mvalprop *Properties // set for map types only
|
||||
|
||||
size sizer
|
||||
valSize valueSizer // set for bool and numeric types only
|
||||
|
||||
dec decoder
|
||||
valDec valueDecoder // set for bool and numeric types only
|
||||
|
||||
// If this is a packable field, this will be the decoder for the packed version of the field.
|
||||
packedDec decoder
|
||||
}
|
||||
|
||||
// String formats the properties in the protobuf struct field tag style.
|
||||
func (p *Properties) String() string {
|
||||
s := p.Wire
|
||||
s = ","
|
||||
s += strconv.Itoa(p.Tag)
|
||||
if p.Required {
|
||||
s += ",req"
|
||||
}
|
||||
if p.Optional {
|
||||
s += ",opt"
|
||||
}
|
||||
if p.Repeated {
|
||||
s += ",rep"
|
||||
}
|
||||
if p.Packed {
|
||||
s += ",packed"
|
||||
}
|
||||
s += ",name=" + p.OrigName
|
||||
if p.JSONName != p.OrigName {
|
||||
s += ",json=" + p.JSONName
|
||||
}
|
||||
if p.proto3 {
|
||||
s += ",proto3"
|
||||
}
|
||||
if p.oneof {
|
||||
s += ",oneof"
|
||||
}
|
||||
if len(p.Enum) > 0 {
|
||||
s += ",enum=" + p.Enum
|
||||
}
|
||||
if p.HasDefault {
|
||||
s += ",def=" + p.Default
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Parse populates p by parsing a string in the protobuf struct field tag style.
|
||||
func (p *Properties) Parse(s string) {
|
||||
// "bytes,49,opt,name=foo,def=hello!"
|
||||
fields := strings.Split(s, ",") // breaks def=, but handled below.
|
||||
if len(fields) < 2 {
|
||||
fmt.Fprintf(os.Stderr, "proto: tag has too few fields: %q\n", s)
|
||||
return
|
||||
}
|
||||
|
||||
p.Wire = fields[0]
|
||||
switch p.Wire {
|
||||
case "varint":
|
||||
p.WireType = WireVarint
|
||||
p.valEnc = (*Buffer).EncodeVarint
|
||||
p.valDec = (*Buffer).DecodeVarint
|
||||
p.valSize = sizeVarint
|
||||
case "fixed32":
|
||||
p.WireType = WireFixed32
|
||||
p.valEnc = (*Buffer).EncodeFixed32
|
||||
p.valDec = (*Buffer).DecodeFixed32
|
||||
p.valSize = sizeFixed32
|
||||
case "fixed64":
|
||||
p.WireType = WireFixed64
|
||||
p.valEnc = (*Buffer).EncodeFixed64
|
||||
p.valDec = (*Buffer).DecodeFixed64
|
||||
p.valSize = sizeFixed64
|
||||
case "zigzag32":
|
||||
p.WireType = WireVarint
|
||||
p.valEnc = (*Buffer).EncodeZigzag32
|
||||
p.valDec = (*Buffer).DecodeZigzag32
|
||||
p.valSize = sizeZigzag32
|
||||
case "zigzag64":
|
||||
p.WireType = WireVarint
|
||||
p.valEnc = (*Buffer).EncodeZigzag64
|
||||
p.valDec = (*Buffer).DecodeZigzag64
|
||||
p.valSize = sizeZigzag64
|
||||
case "bytes", "group":
|
||||
p.WireType = WireBytes
|
||||
// no numeric converter for non-numeric types
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "proto: tag has unknown wire type: %q\n", s)
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
p.Tag, err = strconv.Atoi(fields[1])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i := 2; i < len(fields); i++ {
|
||||
f := fields[i]
|
||||
switch {
|
||||
case f == "req":
|
||||
p.Required = true
|
||||
case f == "opt":
|
||||
p.Optional = true
|
||||
case f == "rep":
|
||||
p.Repeated = true
|
||||
case f == "packed":
|
||||
p.Packed = true
|
||||
case strings.HasPrefix(f, "name="):
|
||||
p.OrigName = f[5:]
|
||||
case strings.HasPrefix(f, "json="):
|
||||
p.JSONName = f[5:]
|
||||
case strings.HasPrefix(f, "enum="):
|
||||
p.Enum = f[5:]
|
||||
case f == "proto3":
|
||||
p.proto3 = true
|
||||
case f == "oneof":
|
||||
p.oneof = true
|
||||
case strings.HasPrefix(f, "def="):
|
||||
p.HasDefault = true
|
||||
p.Default = f[4:] // rest of string
|
||||
if i+1 < len(fields) {
|
||||
// Commas aren't escaped, and def is always last.
|
||||
p.Default += "," + strings.Join(fields[i+1:], ",")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func logNoSliceEnc(t1, t2 reflect.Type) {
|
||||
fmt.Fprintf(os.Stderr, "proto: no slice oenc for %T = []%T\n", t1, t2)
|
||||
}
|
||||
|
||||
var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem()
|
||||
|
||||
// Initialize the fields for encoding and decoding.
|
||||
func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lockGetProp bool) {
|
||||
p.enc = nil
|
||||
p.dec = nil
|
||||
p.size = nil
|
||||
|
||||
switch t1 := typ; t1.Kind() {
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "proto: no coders for %v\n", t1)
|
||||
|
||||
// proto3 scalar types
|
||||
|
||||
case reflect.Bool:
|
||||
p.enc = (*Buffer).enc_proto3_bool
|
||||
p.dec = (*Buffer).dec_proto3_bool
|
||||
p.size = size_proto3_bool
|
||||
case reflect.Int32:
|
||||
p.enc = (*Buffer).enc_proto3_int32
|
||||
p.dec = (*Buffer).dec_proto3_int32
|
||||
p.size = size_proto3_int32
|
||||
case reflect.Uint32:
|
||||
p.enc = (*Buffer).enc_proto3_uint32
|
||||
p.dec = (*Buffer).dec_proto3_int32 // can reuse
|
||||
p.size = size_proto3_uint32
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
p.enc = (*Buffer).enc_proto3_int64
|
||||
p.dec = (*Buffer).dec_proto3_int64
|
||||
p.size = size_proto3_int64
|
||||
case reflect.Float32:
|
||||
p.enc = (*Buffer).enc_proto3_uint32 // can just treat them as bits
|
||||
p.dec = (*Buffer).dec_proto3_int32
|
||||
p.size = size_proto3_uint32
|
||||
case reflect.Float64:
|
||||
p.enc = (*Buffer).enc_proto3_int64 // can just treat them as bits
|
||||
p.dec = (*Buffer).dec_proto3_int64
|
||||
p.size = size_proto3_int64
|
||||
case reflect.String:
|
||||
p.enc = (*Buffer).enc_proto3_string
|
||||
p.dec = (*Buffer).dec_proto3_string
|
||||
p.size = size_proto3_string
|
||||
|
||||
case reflect.Ptr:
|
||||
switch t2 := t1.Elem(); t2.Kind() {
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "proto: no encoder function for %v -> %v\n", t1, t2)
|
||||
break
|
||||
case reflect.Bool:
|
||||
p.enc = (*Buffer).enc_bool
|
||||
p.dec = (*Buffer).dec_bool
|
||||
p.size = size_bool
|
||||
case reflect.Int32:
|
||||
p.enc = (*Buffer).enc_int32
|
||||
p.dec = (*Buffer).dec_int32
|
||||
p.size = size_int32
|
||||
case reflect.Uint32:
|
||||
p.enc = (*Buffer).enc_uint32
|
||||
p.dec = (*Buffer).dec_int32 // can reuse
|
||||
p.size = size_uint32
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
p.enc = (*Buffer).enc_int64
|
||||
p.dec = (*Buffer).dec_int64
|
||||
p.size = size_int64
|
||||
case reflect.Float32:
|
||||
p.enc = (*Buffer).enc_uint32 // can just treat them as bits
|
||||
p.dec = (*Buffer).dec_int32
|
||||
p.size = size_uint32
|
||||
case reflect.Float64:
|
||||
p.enc = (*Buffer).enc_int64 // can just treat them as bits
|
||||
p.dec = (*Buffer).dec_int64
|
||||
p.size = size_int64
|
||||
case reflect.String:
|
||||
p.enc = (*Buffer).enc_string
|
||||
p.dec = (*Buffer).dec_string
|
||||
p.size = size_string
|
||||
case reflect.Struct:
|
||||
p.stype = t1.Elem()
|
||||
p.isMarshaler = isMarshaler(t1)
|
||||
p.isUnmarshaler = isUnmarshaler(t1)
|
||||
if p.Wire == "bytes" {
|
||||
p.enc = (*Buffer).enc_struct_message
|
||||
p.dec = (*Buffer).dec_struct_message
|
||||
p.size = size_struct_message
|
||||
} else {
|
||||
p.enc = (*Buffer).enc_struct_group
|
||||
p.dec = (*Buffer).dec_struct_group
|
||||
p.size = size_struct_group
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
switch t2 := t1.Elem(); t2.Kind() {
|
||||
default:
|
||||
logNoSliceEnc(t1, t2)
|
||||
break
|
||||
case reflect.Bool:
|
||||
if p.Packed {
|
||||
p.enc = (*Buffer).enc_slice_packed_bool
|
||||
p.size = size_slice_packed_bool
|
||||
} else {
|
||||
p.enc = (*Buffer).enc_slice_bool
|
||||
p.size = size_slice_bool
|
||||
}
|
||||
p.dec = (*Buffer).dec_slice_bool
|
||||
p.packedDec = (*Buffer).dec_slice_packed_bool
|
||||
case reflect.Int32:
|
||||
if p.Packed {
|
||||
p.enc = (*Buffer).enc_slice_packed_int32
|
||||
p.size = size_slice_packed_int32
|
||||
} else {
|
||||
p.enc = (*Buffer).enc_slice_int32
|
||||
p.size = size_slice_int32
|
||||
}
|
||||
p.dec = (*Buffer).dec_slice_int32
|
||||
p.packedDec = (*Buffer).dec_slice_packed_int32
|
||||
case reflect.Uint32:
|
||||
if p.Packed {
|
||||
p.enc = (*Buffer).enc_slice_packed_uint32
|
||||
p.size = size_slice_packed_uint32
|
||||
} else {
|
||||
p.enc = (*Buffer).enc_slice_uint32
|
||||
p.size = size_slice_uint32
|
||||
}
|
||||
p.dec = (*Buffer).dec_slice_int32
|
||||
p.packedDec = (*Buffer).dec_slice_packed_int32
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if p.Packed {
|
||||
p.enc = (*Buffer).enc_slice_packed_int64
|
||||
p.size = size_slice_packed_int64
|
||||
} else {
|
||||
p.enc = (*Buffer).enc_slice_int64
|
||||
p.size = size_slice_int64
|
||||
}
|
||||
p.dec = (*Buffer).dec_slice_int64
|
||||
p.packedDec = (*Buffer).dec_slice_packed_int64
|
||||
case reflect.Uint8:
|
||||
p.dec = (*Buffer).dec_slice_byte
|
||||
if p.proto3 {
|
||||
p.enc = (*Buffer).enc_proto3_slice_byte
|
||||
p.size = size_proto3_slice_byte
|
||||
} else {
|
||||
p.enc = (*Buffer).enc_slice_byte
|
||||
p.size = size_slice_byte
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
switch t2.Bits() {
|
||||
case 32:
|
||||
// can just treat them as bits
|
||||
if p.Packed {
|
||||
p.enc = (*Buffer).enc_slice_packed_uint32
|
||||
p.size = size_slice_packed_uint32
|
||||
} else {
|
||||
p.enc = (*Buffer).enc_slice_uint32
|
||||
p.size = size_slice_uint32
|
||||
}
|
||||
p.dec = (*Buffer).dec_slice_int32
|
||||
p.packedDec = (*Buffer).dec_slice_packed_int32
|
||||
case 64:
|
||||
// can just treat them as bits
|
||||
if p.Packed {
|
||||
p.enc = (*Buffer).enc_slice_packed_int64
|
||||
p.size = size_slice_packed_int64
|
||||
} else {
|
||||
p.enc = (*Buffer).enc_slice_int64
|
||||
p.size = size_slice_int64
|
||||
}
|
||||
p.dec = (*Buffer).dec_slice_int64
|
||||
p.packedDec = (*Buffer).dec_slice_packed_int64
|
||||
default:
|
||||
logNoSliceEnc(t1, t2)
|
||||
break
|
||||
}
|
||||
case reflect.String:
|
||||
p.enc = (*Buffer).enc_slice_string
|
||||
p.dec = (*Buffer).dec_slice_string
|
||||
p.size = size_slice_string
|
||||
case reflect.Ptr:
|
||||
switch t3 := t2.Elem(); t3.Kind() {
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "proto: no ptr oenc for %T -> %T -> %T\n", t1, t2, t3)
|
||||
break
|
||||
case reflect.Struct:
|
||||
p.stype = t2.Elem()
|
||||
p.isMarshaler = isMarshaler(t2)
|
||||
p.isUnmarshaler = isUnmarshaler(t2)
|
||||
if p.Wire == "bytes" {
|
||||
p.enc = (*Buffer).enc_slice_struct_message
|
||||
p.dec = (*Buffer).dec_slice_struct_message
|
||||
p.size = size_slice_struct_message
|
||||
} else {
|
||||
p.enc = (*Buffer).enc_slice_struct_group
|
||||
p.dec = (*Buffer).dec_slice_struct_group
|
||||
p.size = size_slice_struct_group
|
||||
}
|
||||
}
|
||||
case reflect.Slice:
|
||||
switch t2.Elem().Kind() {
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "proto: no slice elem oenc for %T -> %T -> %T\n", t1, t2, t2.Elem())
|
||||
break
|
||||
case reflect.Uint8:
|
||||
p.enc = (*Buffer).enc_slice_slice_byte
|
||||
p.dec = (*Buffer).dec_slice_slice_byte
|
||||
p.size = size_slice_slice_byte
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
p.enc = (*Buffer).enc_new_map
|
||||
p.dec = (*Buffer).dec_new_map
|
||||
p.size = size_new_map
|
||||
|
||||
p.mtype = t1
|
||||
p.mkeyprop = &Properties{}
|
||||
p.mkeyprop.init(reflect.PtrTo(p.mtype.Key()), "Key", f.Tag.Get("protobuf_key"), nil, lockGetProp)
|
||||
p.mvalprop = &Properties{}
|
||||
vtype := p.mtype.Elem()
|
||||
if vtype.Kind() != reflect.Ptr && vtype.Kind() != reflect.Slice {
|
||||
// The value type is not a message (*T) or bytes ([]byte),
|
||||
// so we need encoders for the pointer to this type.
|
||||
vtype = reflect.PtrTo(vtype)
|
||||
}
|
||||
p.mvalprop.init(vtype, "Value", f.Tag.Get("protobuf_val"), nil, lockGetProp)
|
||||
}
|
||||
|
||||
// precalculate tag code
|
||||
wire := p.WireType
|
||||
if p.Packed {
|
||||
wire = WireBytes
|
||||
}
|
||||
x := uint32(p.Tag)<<3 | uint32(wire)
|
||||
i := 0
|
||||
for i = 0; x > 127; i++ {
|
||||
p.tagbuf[i] = 0x80 | uint8(x&0x7F)
|
||||
x >>= 7
|
||||
}
|
||||
p.tagbuf[i] = uint8(x)
|
||||
p.tagcode = p.tagbuf[0 : i+1]
|
||||
|
||||
if p.stype != nil {
|
||||
if lockGetProp {
|
||||
p.sprop = GetProperties(p.stype)
|
||||
} else {
|
||||
p.sprop = getPropertiesLocked(p.stype)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
||||
unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
|
||||
)
|
||||
|
||||
// isMarshaler reports whether type t implements Marshaler.
|
||||
func isMarshaler(t reflect.Type) bool {
|
||||
// We're checking for (likely) pointer-receiver methods
|
||||
// so if t is not a pointer, something is very wrong.
|
||||
// The calls above only invoke isMarshaler on pointer types.
|
||||
if t.Kind() != reflect.Ptr {
|
||||
panic("proto: misuse of isMarshaler")
|
||||
}
|
||||
return t.Implements(marshalerType)
|
||||
}
|
||||
|
||||
// isUnmarshaler reports whether type t implements Unmarshaler.
|
||||
func isUnmarshaler(t reflect.Type) bool {
|
||||
// We're checking for (likely) pointer-receiver methods
|
||||
// so if t is not a pointer, something is very wrong.
|
||||
// The calls above only invoke isUnmarshaler on pointer types.
|
||||
if t.Kind() != reflect.Ptr {
|
||||
panic("proto: misuse of isUnmarshaler")
|
||||
}
|
||||
return t.Implements(unmarshalerType)
|
||||
}
|
||||
|
||||
// Init populates the properties from a protocol buffer struct tag.
|
||||
func (p *Properties) Init(typ reflect.Type, name, tag string, f *reflect.StructField) {
|
||||
p.init(typ, name, tag, f, true)
|
||||
}
|
||||
|
||||
func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructField, lockGetProp bool) {
|
||||
// "bytes,49,opt,def=hello!"
|
||||
p.Name = name
|
||||
p.OrigName = name
|
||||
if f != nil {
|
||||
p.field = toField(f)
|
||||
}
|
||||
if tag == "" {
|
||||
return
|
||||
}
|
||||
p.Parse(tag)
|
||||
p.setEncAndDec(typ, f, lockGetProp)
|
||||
}
|
||||
|
||||
var (
|
||||
propertiesMu sync.RWMutex
|
||||
propertiesMap = make(map[reflect.Type]*StructProperties)
|
||||
)
|
||||
|
||||
// GetProperties returns the list of properties for the type represented by t.
|
||||
// t must represent a generated struct type of a protocol message.
|
||||
func GetProperties(t reflect.Type) *StructProperties {
|
||||
if t.Kind() != reflect.Struct {
|
||||
panic("proto: type must have kind struct")
|
||||
}
|
||||
|
||||
// Most calls to GetProperties in a long-running program will be
|
||||
// retrieving details for types we have seen before.
|
||||
propertiesMu.RLock()
|
||||
sprop, ok := propertiesMap[t]
|
||||
propertiesMu.RUnlock()
|
||||
if ok {
|
||||
if collectStats {
|
||||
stats.Chit++
|
||||
}
|
||||
return sprop
|
||||
}
|
||||
|
||||
propertiesMu.Lock()
|
||||
sprop = getPropertiesLocked(t)
|
||||
propertiesMu.Unlock()
|
||||
return sprop
|
||||
}
|
||||
|
||||
// getPropertiesLocked requires that propertiesMu is held.
|
||||
func getPropertiesLocked(t reflect.Type) *StructProperties {
|
||||
if prop, ok := propertiesMap[t]; ok {
|
||||
if collectStats {
|
||||
stats.Chit++
|
||||
}
|
||||
return prop
|
||||
}
|
||||
if collectStats {
|
||||
stats.Cmiss++
|
||||
}
|
||||
|
||||
prop := new(StructProperties)
|
||||
// in case of recursive protos, fill this in now.
|
||||
propertiesMap[t] = prop
|
||||
|
||||
// build properties
|
||||
prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) ||
|
||||
reflect.PtrTo(t).Implements(extendableProtoV1Type)
|
||||
prop.unrecField = invalidField
|
||||
prop.Prop = make([]*Properties, t.NumField())
|
||||
prop.order = make([]int, t.NumField())
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
p := new(Properties)
|
||||
name := f.Name
|
||||
p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false)
|
||||
|
||||
if f.Name == "XXX_InternalExtensions" { // special case
|
||||
p.enc = (*Buffer).enc_exts
|
||||
p.dec = nil // not needed
|
||||
p.size = size_exts
|
||||
} else if f.Name == "XXX_extensions" { // special case
|
||||
p.enc = (*Buffer).enc_map
|
||||
p.dec = nil // not needed
|
||||
p.size = size_map
|
||||
} else if f.Name == "XXX_unrecognized" { // special case
|
||||
prop.unrecField = toField(&f)
|
||||
}
|
||||
oneof := f.Tag.Get("protobuf_oneof") // special case
|
||||
if oneof != "" {
|
||||
// Oneof fields don't use the traditional protobuf tag.
|
||||
p.OrigName = oneof
|
||||
}
|
||||
prop.Prop[i] = p
|
||||
prop.order[i] = i
|
||||
if debug {
|
||||
print(i, " ", f.Name, " ", t.String(), " ")
|
||||
if p.Tag > 0 {
|
||||
print(p.String())
|
||||
}
|
||||
print("\n")
|
||||
}
|
||||
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && oneof == "" {
|
||||
fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]")
|
||||
}
|
||||
}
|
||||
|
||||
// Re-order prop.order.
|
||||
sort.Sort(prop)
|
||||
|
||||
type oneofMessage interface {
|
||||
XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), func(Message) int, []interface{})
|
||||
}
|
||||
if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); ok {
|
||||
var oots []interface{}
|
||||
prop.oneofMarshaler, prop.oneofUnmarshaler, prop.oneofSizer, oots = om.XXX_OneofFuncs()
|
||||
prop.stype = t
|
||||
|
||||
// Interpret oneof metadata.
|
||||
prop.OneofTypes = make(map[string]*OneofProperties)
|
||||
for _, oot := range oots {
|
||||
oop := &OneofProperties{
|
||||
Type: reflect.ValueOf(oot).Type(), // *T
|
||||
Prop: new(Properties),
|
||||
}
|
||||
sft := oop.Type.Elem().Field(0)
|
||||
oop.Prop.Name = sft.Name
|
||||
oop.Prop.Parse(sft.Tag.Get("protobuf"))
|
||||
// There will be exactly one interface field that
|
||||
// this new value is assignable to.
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.Type.Kind() != reflect.Interface {
|
||||
continue
|
||||
}
|
||||
if !oop.Type.AssignableTo(f.Type) {
|
||||
continue
|
||||
}
|
||||
oop.Field = i
|
||||
break
|
||||
}
|
||||
prop.OneofTypes[oop.Prop.OrigName] = oop
|
||||
}
|
||||
}
|
||||
|
||||
// build required counts
|
||||
// build tags
|
||||
reqCount := 0
|
||||
prop.decoderOrigNames = make(map[string]int)
|
||||
for i, p := range prop.Prop {
|
||||
if strings.HasPrefix(p.Name, "XXX_") {
|
||||
// Internal fields should not appear in tags/origNames maps.
|
||||
// They are handled specially when encoding and decoding.
|
||||
continue
|
||||
}
|
||||
if p.Required {
|
||||
reqCount++
|
||||
}
|
||||
prop.decoderTags.put(p.Tag, i)
|
||||
prop.decoderOrigNames[p.OrigName] = i
|
||||
}
|
||||
prop.reqCount = reqCount
|
||||
|
||||
return prop
|
||||
}
|
||||
|
||||
// Return the Properties object for the x[0]'th field of the structure.
|
||||
func propByIndex(t reflect.Type, x []int) *Properties {
|
||||
if len(x) != 1 {
|
||||
fmt.Fprintf(os.Stderr, "proto: field index dimension %d (not 1) for type %s\n", len(x), t)
|
||||
return nil
|
||||
}
|
||||
prop := GetProperties(t)
|
||||
return prop.Prop[x[0]]
|
||||
}
|
||||
|
||||
// Get the address and type of a pointer to a struct from an interface.
|
||||
func getbase(pb Message) (t reflect.Type, b structPointer, err error) {
|
||||
if pb == nil {
|
||||
err = ErrNil
|
||||
return
|
||||
}
|
||||
// get the reflect type of the pointer to the struct.
|
||||
t = reflect.TypeOf(pb)
|
||||
// get the address of the struct.
|
||||
value := reflect.ValueOf(pb)
|
||||
b = toStructPointer(value)
|
||||
return
|
||||
}
|
||||
|
||||
// A global registry of enum types.
|
||||
// The generated code will register the generated maps by calling RegisterEnum.
|
||||
|
||||
var enumValueMaps = make(map[string]map[string]int32)
|
||||
|
||||
// RegisterEnum is called from the generated code to install the enum descriptor
|
||||
// maps into the global table to aid parsing text format protocol buffers.
|
||||
func RegisterEnum(typeName string, unusedNameMap map[int32]string, valueMap map[string]int32) {
|
||||
if _, ok := enumValueMaps[typeName]; ok {
|
||||
panic("proto: duplicate enum registered: " + typeName)
|
||||
}
|
||||
enumValueMaps[typeName] = valueMap
|
||||
}
|
||||
|
||||
// EnumValueMap returns the mapping from names to integers of the
|
||||
// enum type enumType, or a nil if not found.
|
||||
func EnumValueMap(enumType string) map[string]int32 {
|
||||
return enumValueMaps[enumType]
|
||||
}
|
||||
|
||||
// A registry of all linked message types.
|
||||
// The string is a fully-qualified proto name ("pkg.Message").
|
||||
var (
|
||||
protoTypes = make(map[string]reflect.Type)
|
||||
revProtoTypes = make(map[reflect.Type]string)
|
||||
)
|
||||
|
||||
// RegisterType is called from generated code and maps from the fully qualified
|
||||
// proto name to the type (pointer to struct) of the protocol buffer.
|
||||
func RegisterType(x Message, name string) {
|
||||
if _, ok := protoTypes[name]; ok {
|
||||
// TODO: Some day, make this a panic.
|
||||
log.Printf("proto: duplicate proto type registered: %s", name)
|
||||
return
|
||||
}
|
||||
t := reflect.TypeOf(x)
|
||||
protoTypes[name] = t
|
||||
revProtoTypes[t] = name
|
||||
}
|
||||
|
||||
// MessageName returns the fully-qualified proto name for the given message type.
|
||||
func MessageName(x Message) string {
|
||||
type xname interface {
|
||||
XXX_MessageName() string
|
||||
}
|
||||
if m, ok := x.(xname); ok {
|
||||
return m.XXX_MessageName()
|
||||
}
|
||||
return revProtoTypes[reflect.TypeOf(x)]
|
||||
}
|
||||
|
||||
// MessageType returns the message type (pointer to struct) for a named message.
|
||||
func MessageType(name string) reflect.Type { return protoTypes[name] }
|
||||
|
||||
// A registry of all linked proto files.
|
||||
var (
|
||||
protoFiles = make(map[string][]byte) // file name => fileDescriptor
|
||||
)
|
||||
|
||||
// RegisterFile is called from generated code and maps from the
|
||||
// full file name of a .proto file to its compressed FileDescriptorProto.
|
||||
func RegisterFile(filename string, fileDescriptor []byte) {
|
||||
protoFiles[filename] = fileDescriptor
|
||||
}
|
||||
|
||||
// FileDescriptor returns the compressed FileDescriptorProto for a .proto file.
|
||||
func FileDescriptor(filename string) []byte { return protoFiles[filename] }
|
|
@ -0,0 +1,854 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
// Functions for writing the text protocol buffer format.
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
newline = []byte("\n")
|
||||
spaces = []byte(" ")
|
||||
gtNewline = []byte(">\n")
|
||||
endBraceNewline = []byte("}\n")
|
||||
backslashN = []byte{'\\', 'n'}
|
||||
backslashR = []byte{'\\', 'r'}
|
||||
backslashT = []byte{'\\', 't'}
|
||||
backslashDQ = []byte{'\\', '"'}
|
||||
backslashBS = []byte{'\\', '\\'}
|
||||
posInf = []byte("inf")
|
||||
negInf = []byte("-inf")
|
||||
nan = []byte("nan")
|
||||
)
|
||||
|
||||
type writer interface {
|
||||
io.Writer
|
||||
WriteByte(byte) error
|
||||
}
|
||||
|
||||
// textWriter is an io.Writer that tracks its indentation level.
|
||||
type textWriter struct {
|
||||
ind int
|
||||
complete bool // if the current position is a complete line
|
||||
compact bool // whether to write out as a one-liner
|
||||
w writer
|
||||
}
|
||||
|
||||
func (w *textWriter) WriteString(s string) (n int, err error) {
|
||||
if !strings.Contains(s, "\n") {
|
||||
if !w.compact && w.complete {
|
||||
w.writeIndent()
|
||||
}
|
||||
w.complete = false
|
||||
return io.WriteString(w.w, s)
|
||||
}
|
||||
// WriteString is typically called without newlines, so this
|
||||
// codepath and its copy are rare. We copy to avoid
|
||||
// duplicating all of Write's logic here.
|
||||
return w.Write([]byte(s))
|
||||
}
|
||||
|
||||
func (w *textWriter) Write(p []byte) (n int, err error) {
|
||||
newlines := bytes.Count(p, newline)
|
||||
if newlines == 0 {
|
||||
if !w.compact && w.complete {
|
||||
w.writeIndent()
|
||||
}
|
||||
n, err = w.w.Write(p)
|
||||
w.complete = false
|
||||
return n, err
|
||||
}
|
||||
|
||||
frags := bytes.SplitN(p, newline, newlines+1)
|
||||
if w.compact {
|
||||
for i, frag := range frags {
|
||||
if i > 0 {
|
||||
if err := w.w.WriteByte(' '); err != nil {
|
||||
return n, err
|
||||
}
|
||||
n++
|
||||
}
|
||||
nn, err := w.w.Write(frag)
|
||||
n += nn
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
for i, frag := range frags {
|
||||
if w.complete {
|
||||
w.writeIndent()
|
||||
}
|
||||
nn, err := w.w.Write(frag)
|
||||
n += nn
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if i+1 < len(frags) {
|
||||
if err := w.w.WriteByte('\n'); err != nil {
|
||||
return n, err
|
||||
}
|
||||
n++
|
||||
}
|
||||
}
|
||||
w.complete = len(frags[len(frags)-1]) == 0
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *textWriter) WriteByte(c byte) error {
|
||||
if w.compact && c == '\n' {
|
||||
c = ' '
|
||||
}
|
||||
if !w.compact && w.complete {
|
||||
w.writeIndent()
|
||||
}
|
||||
err := w.w.WriteByte(c)
|
||||
w.complete = c == '\n'
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *textWriter) indent() { w.ind++ }
|
||||
|
||||
func (w *textWriter) unindent() {
|
||||
if w.ind == 0 {
|
||||
log.Print("proto: textWriter unindented too far")
|
||||
return
|
||||
}
|
||||
w.ind--
|
||||
}
|
||||
|
||||
func writeName(w *textWriter, props *Properties) error {
|
||||
if _, err := w.WriteString(props.OrigName); err != nil {
|
||||
return err
|
||||
}
|
||||
if props.Wire != "group" {
|
||||
return w.WriteByte(':')
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// raw is the interface satisfied by RawMessage.
|
||||
type raw interface {
|
||||
Bytes() []byte
|
||||
}
|
||||
|
||||
func requiresQuotes(u string) bool {
|
||||
// When type URL contains any characters except [0-9A-Za-z./\-]*, it must be quoted.
|
||||
for _, ch := range u {
|
||||
switch {
|
||||
case ch == '.' || ch == '/' || ch == '_':
|
||||
continue
|
||||
case '0' <= ch && ch <= '9':
|
||||
continue
|
||||
case 'A' <= ch && ch <= 'Z':
|
||||
continue
|
||||
case 'a' <= ch && ch <= 'z':
|
||||
continue
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isAny reports whether sv is a google.protobuf.Any message
|
||||
func isAny(sv reflect.Value) bool {
|
||||
type wkt interface {
|
||||
XXX_WellKnownType() string
|
||||
}
|
||||
t, ok := sv.Addr().Interface().(wkt)
|
||||
return ok && t.XXX_WellKnownType() == "Any"
|
||||
}
|
||||
|
||||
// writeProto3Any writes an expanded google.protobuf.Any message.
|
||||
//
|
||||
// It returns (false, nil) if sv value can't be unmarshaled (e.g. because
|
||||
// required messages are not linked in).
|
||||
//
|
||||
// It returns (true, error) when sv was written in expanded format or an error
|
||||
// was encountered.
|
||||
func (tm *TextMarshaler) writeProto3Any(w *textWriter, sv reflect.Value) (bool, error) {
|
||||
turl := sv.FieldByName("TypeUrl")
|
||||
val := sv.FieldByName("Value")
|
||||
if !turl.IsValid() || !val.IsValid() {
|
||||
return true, errors.New("proto: invalid google.protobuf.Any message")
|
||||
}
|
||||
|
||||
b, ok := val.Interface().([]byte)
|
||||
if !ok {
|
||||
return true, errors.New("proto: invalid google.protobuf.Any message")
|
||||
}
|
||||
|
||||
parts := strings.Split(turl.String(), "/")
|
||||
mt := MessageType(parts[len(parts)-1])
|
||||
if mt == nil {
|
||||
return false, nil
|
||||
}
|
||||
m := reflect.New(mt.Elem())
|
||||
if err := Unmarshal(b, m.Interface().(Message)); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
w.Write([]byte("["))
|
||||
u := turl.String()
|
||||
if requiresQuotes(u) {
|
||||
writeString(w, u)
|
||||
} else {
|
||||
w.Write([]byte(u))
|
||||
}
|
||||
if w.compact {
|
||||
w.Write([]byte("]:<"))
|
||||
} else {
|
||||
w.Write([]byte("]: <\n"))
|
||||
w.ind++
|
||||
}
|
||||
if err := tm.writeStruct(w, m.Elem()); err != nil {
|
||||
return true, err
|
||||
}
|
||||
if w.compact {
|
||||
w.Write([]byte("> "))
|
||||
} else {
|
||||
w.ind--
|
||||
w.Write([]byte(">\n"))
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
|
||||
if tm.ExpandAny && isAny(sv) {
|
||||
if canExpand, err := tm.writeProto3Any(w, sv); canExpand {
|
||||
return err
|
||||
}
|
||||
}
|
||||
st := sv.Type()
|
||||
sprops := GetProperties(st)
|
||||
for i := 0; i < sv.NumField(); i++ {
|
||||
fv := sv.Field(i)
|
||||
props := sprops.Prop[i]
|
||||
name := st.Field(i).Name
|
||||
|
||||
if strings.HasPrefix(name, "XXX_") {
|
||||
// There are two XXX_ fields:
|
||||
// XXX_unrecognized []byte
|
||||
// XXX_extensions map[int32]proto.Extension
|
||||
// The first is handled here;
|
||||
// the second is handled at the bottom of this function.
|
||||
if name == "XXX_unrecognized" && !fv.IsNil() {
|
||||
if err := writeUnknownStruct(w, fv.Interface().([]byte)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if fv.Kind() == reflect.Ptr && fv.IsNil() {
|
||||
// Field not filled in. This could be an optional field or
|
||||
// a required field that wasn't filled in. Either way, there
|
||||
// isn't anything we can show for it.
|
||||
continue
|
||||
}
|
||||
if fv.Kind() == reflect.Slice && fv.IsNil() {
|
||||
// Repeated field that is empty, or a bytes field that is unused.
|
||||
continue
|
||||
}
|
||||
|
||||
if props.Repeated && fv.Kind() == reflect.Slice {
|
||||
// Repeated field.
|
||||
for j := 0; j < fv.Len(); j++ {
|
||||
if err := writeName(w, props); err != nil {
|
||||
return err
|
||||
}
|
||||
if !w.compact {
|
||||
if err := w.WriteByte(' '); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v := fv.Index(j)
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
// A nil message in a repeated field is not valid,
|
||||
// but we can handle that more gracefully than panicking.
|
||||
if _, err := w.Write([]byte("<nil>\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := tm.writeAny(w, v, props); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if fv.Kind() == reflect.Map {
|
||||
// Map fields are rendered as a repeated struct with key/value fields.
|
||||
keys := fv.MapKeys()
|
||||
sort.Sort(mapKeys(keys))
|
||||
for _, key := range keys {
|
||||
val := fv.MapIndex(key)
|
||||
if err := writeName(w, props); err != nil {
|
||||
return err
|
||||
}
|
||||
if !w.compact {
|
||||
if err := w.WriteByte(' '); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// open struct
|
||||
if err := w.WriteByte('<'); err != nil {
|
||||
return err
|
||||
}
|
||||
if !w.compact {
|
||||
if err := w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.indent()
|
||||
// key
|
||||
if _, err := w.WriteString("key:"); err != nil {
|
||||
return err
|
||||
}
|
||||
if !w.compact {
|
||||
if err := w.WriteByte(' '); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tm.writeAny(w, key, props.mkeyprop); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
// nil values aren't legal, but we can avoid panicking because of them.
|
||||
if val.Kind() != reflect.Ptr || !val.IsNil() {
|
||||
// value
|
||||
if _, err := w.WriteString("value:"); err != nil {
|
||||
return err
|
||||
}
|
||||
if !w.compact {
|
||||
if err := w.WriteByte(' '); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tm.writeAny(w, val, props.mvalprop); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// close struct
|
||||
w.unindent()
|
||||
if err := w.WriteByte('>'); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if props.proto3 && fv.Kind() == reflect.Slice && fv.Len() == 0 {
|
||||
// empty bytes field
|
||||
continue
|
||||
}
|
||||
if fv.Kind() != reflect.Ptr && fv.Kind() != reflect.Slice {
|
||||
// proto3 non-repeated scalar field; skip if zero value
|
||||
if isProto3Zero(fv) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if fv.Kind() == reflect.Interface {
|
||||
// Check if it is a oneof.
|
||||
if st.Field(i).Tag.Get("protobuf_oneof") != "" {
|
||||
// fv is nil, or holds a pointer to generated struct.
|
||||
// That generated struct has exactly one field,
|
||||
// which has a protobuf struct tag.
|
||||
if fv.IsNil() {
|
||||
continue
|
||||
}
|
||||
inner := fv.Elem().Elem() // interface -> *T -> T
|
||||
tag := inner.Type().Field(0).Tag.Get("protobuf")
|
||||
props = new(Properties) // Overwrite the outer props var, but not its pointee.
|
||||
props.Parse(tag)
|
||||
// Write the value in the oneof, not the oneof itself.
|
||||
fv = inner.Field(0)
|
||||
|
||||
// Special case to cope with malformed messages gracefully:
|
||||
// If the value in the oneof is a nil pointer, don't panic
|
||||
// in writeAny.
|
||||
if fv.Kind() == reflect.Ptr && fv.IsNil() {
|
||||
// Use errors.New so writeAny won't render quotes.
|
||||
msg := errors.New("/* nil */")
|
||||
fv = reflect.ValueOf(&msg).Elem()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeName(w, props); err != nil {
|
||||
return err
|
||||
}
|
||||
if !w.compact {
|
||||
if err := w.WriteByte(' '); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if b, ok := fv.Interface().(raw); ok {
|
||||
if err := writeRaw(w, b.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Enums have a String method, so writeAny will work fine.
|
||||
if err := tm.writeAny(w, fv, props); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Extensions (the XXX_extensions field).
|
||||
pv := sv.Addr()
|
||||
if _, ok := extendable(pv.Interface()); ok {
|
||||
if err := tm.writeExtensions(w, pv); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeRaw writes an uninterpreted raw message.
|
||||
func writeRaw(w *textWriter, b []byte) error {
|
||||
if err := w.WriteByte('<'); err != nil {
|
||||
return err
|
||||
}
|
||||
if !w.compact {
|
||||
if err := w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.indent()
|
||||
if err := writeUnknownStruct(w, b); err != nil {
|
||||
return err
|
||||
}
|
||||
w.unindent()
|
||||
if err := w.WriteByte('>'); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeAny writes an arbitrary field.
|
||||
func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Properties) error {
|
||||
v = reflect.Indirect(v)
|
||||
|
||||
// Floats have special cases.
|
||||
if v.Kind() == reflect.Float32 || v.Kind() == reflect.Float64 {
|
||||
x := v.Float()
|
||||
var b []byte
|
||||
switch {
|
||||
case math.IsInf(x, 1):
|
||||
b = posInf
|
||||
case math.IsInf(x, -1):
|
||||
b = negInf
|
||||
case math.IsNaN(x):
|
||||
b = nan
|
||||
}
|
||||
if b != nil {
|
||||
_, err := w.Write(b)
|
||||
return err
|
||||
}
|
||||
// Other values are handled below.
|
||||
}
|
||||
|
||||
// We don't attempt to serialise every possible value type; only those
|
||||
// that can occur in protocol buffers.
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
// Should only be a []byte; repeated fields are handled in writeStruct.
|
||||
if err := writeString(w, string(v.Bytes())); err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.String:
|
||||
if err := writeString(w, v.String()); err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Struct:
|
||||
// Required/optional group/message.
|
||||
var bra, ket byte = '<', '>'
|
||||
if props != nil && props.Wire == "group" {
|
||||
bra, ket = '{', '}'
|
||||
}
|
||||
if err := w.WriteByte(bra); err != nil {
|
||||
return err
|
||||
}
|
||||
if !w.compact {
|
||||
if err := w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.indent()
|
||||
if etm, ok := v.Interface().(encoding.TextMarshaler); ok {
|
||||
text, err := etm.MarshalText()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = w.Write(text); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := tm.writeStruct(w, v); err != nil {
|
||||
return err
|
||||
}
|
||||
w.unindent()
|
||||
if err := w.WriteByte(ket); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
_, err := fmt.Fprint(w, v.Interface())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// equivalent to C's isprint.
|
||||
func isprint(c byte) bool {
|
||||
return c >= 0x20 && c < 0x7f
|
||||
}
|
||||
|
||||
// writeString writes a string in the protocol buffer text format.
|
||||
// It is similar to strconv.Quote except we don't use Go escape sequences,
|
||||
// we treat the string as a byte sequence, and we use octal escapes.
|
||||
// These differences are to maintain interoperability with the other
|
||||
// languages' implementations of the text format.
|
||||
func writeString(w *textWriter, s string) error {
|
||||
// use WriteByte here to get any needed indent
|
||||
if err := w.WriteByte('"'); err != nil {
|
||||
return err
|
||||
}
|
||||
// Loop over the bytes, not the runes.
|
||||
for i := 0; i < len(s); i++ {
|
||||
var err error
|
||||
// Divergence from C++: we don't escape apostrophes.
|
||||
// There's no need to escape them, and the C++ parser
|
||||
// copes with a naked apostrophe.
|
||||
switch c := s[i]; c {
|
||||
case '\n':
|
||||
_, err = w.w.Write(backslashN)
|
||||
case '\r':
|
||||
_, err = w.w.Write(backslashR)
|
||||
case '\t':
|
||||
_, err = w.w.Write(backslashT)
|
||||
case '"':
|
||||
_, err = w.w.Write(backslashDQ)
|
||||
case '\\':
|
||||
_, err = w.w.Write(backslashBS)
|
||||
default:
|
||||
if isprint(c) {
|
||||
err = w.w.WriteByte(c)
|
||||
} else {
|
||||
_, err = fmt.Fprintf(w.w, "\\%03o", c)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return w.WriteByte('"')
|
||||
}
|
||||
|
||||
func writeUnknownStruct(w *textWriter, data []byte) (err error) {
|
||||
if !w.compact {
|
||||
if _, err := fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
b := NewBuffer(data)
|
||||
for b.index < len(b.buf) {
|
||||
x, err := b.DecodeVarint()
|
||||
if err != nil {
|
||||
_, err := fmt.Fprintf(w, "/* %v */\n", err)
|
||||
return err
|
||||
}
|
||||
wire, tag := x&7, x>>3
|
||||
if wire == WireEndGroup {
|
||||
w.unindent()
|
||||
if _, err := w.Write(endBraceNewline); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(w, tag); err != nil {
|
||||
return err
|
||||
}
|
||||
if wire != WireStartGroup {
|
||||
if err := w.WriteByte(':'); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !w.compact || wire == WireStartGroup {
|
||||
if err := w.WriteByte(' '); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
switch wire {
|
||||
case WireBytes:
|
||||
buf, e := b.DecodeRawBytes(false)
|
||||
if e == nil {
|
||||
_, err = fmt.Fprintf(w, "%q", buf)
|
||||
} else {
|
||||
_, err = fmt.Fprintf(w, "/* %v */", e)
|
||||
}
|
||||
case WireFixed32:
|
||||
x, err = b.DecodeFixed32()
|
||||
err = writeUnknownInt(w, x, err)
|
||||
case WireFixed64:
|
||||
x, err = b.DecodeFixed64()
|
||||
err = writeUnknownInt(w, x, err)
|
||||
case WireStartGroup:
|
||||
err = w.WriteByte('{')
|
||||
w.indent()
|
||||
case WireVarint:
|
||||
x, err = b.DecodeVarint()
|
||||
err = writeUnknownInt(w, x, err)
|
||||
default:
|
||||
_, err = fmt.Fprintf(w, "/* unknown wire type %d */", wire)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeUnknownInt(w *textWriter, x uint64, err error) error {
|
||||
if err == nil {
|
||||
_, err = fmt.Fprint(w, x)
|
||||
} else {
|
||||
_, err = fmt.Fprintf(w, "/* %v */", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type int32Slice []int32
|
||||
|
||||
func (s int32Slice) Len() int { return len(s) }
|
||||
func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
// writeExtensions writes all the extensions in pv.
|
||||
// pv is assumed to be a pointer to a protocol message struct that is extendable.
|
||||
func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error {
|
||||
emap := extensionMaps[pv.Type().Elem()]
|
||||
ep, _ := extendable(pv.Interface())
|
||||
|
||||
// Order the extensions by ID.
|
||||
// This isn't strictly necessary, but it will give us
|
||||
// canonical output, which will also make testing easier.
|
||||
m, mu := ep.extensionsRead()
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
mu.Lock()
|
||||
ids := make([]int32, 0, len(m))
|
||||
for id := range m {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
sort.Sort(int32Slice(ids))
|
||||
mu.Unlock()
|
||||
|
||||
for _, extNum := range ids {
|
||||
ext := m[extNum]
|
||||
var desc *ExtensionDesc
|
||||
if emap != nil {
|
||||
desc = emap[extNum]
|
||||
}
|
||||
if desc == nil {
|
||||
// Unknown extension.
|
||||
if err := writeUnknownStruct(w, ext.enc); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
pb, err := GetExtension(ep, desc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed getting extension: %v", err)
|
||||
}
|
||||
|
||||
// Repeated extensions will appear as a slice.
|
||||
if !desc.repeated() {
|
||||
if err := tm.writeExtension(w, desc.Name, pb); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
v := reflect.ValueOf(pb)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if err := tm.writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *TextMarshaler) writeExtension(w *textWriter, name string, pb interface{}) error {
|
||||
if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil {
|
||||
return err
|
||||
}
|
||||
if !w.compact {
|
||||
if err := w.WriteByte(' '); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tm.writeAny(w, reflect.ValueOf(pb), nil); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.WriteByte('\n'); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *textWriter) writeIndent() {
|
||||
if !w.complete {
|
||||
return
|
||||
}
|
||||
remain := w.ind * 2
|
||||
for remain > 0 {
|
||||
n := remain
|
||||
if n > len(spaces) {
|
||||
n = len(spaces)
|
||||
}
|
||||
w.w.Write(spaces[:n])
|
||||
remain -= n
|
||||
}
|
||||
w.complete = false
|
||||
}
|
||||
|
||||
// TextMarshaler is a configurable text format marshaler.
|
||||
type TextMarshaler struct {
|
||||
Compact bool // use compact text format (one line).
|
||||
ExpandAny bool // expand google.protobuf.Any messages of known types
|
||||
}
|
||||
|
||||
// Marshal writes a given protocol buffer in text format.
|
||||
// The only errors returned are from w.
|
||||
func (tm *TextMarshaler) Marshal(w io.Writer, pb Message) error {
|
||||
val := reflect.ValueOf(pb)
|
||||
if pb == nil || val.IsNil() {
|
||||
w.Write([]byte("<nil>"))
|
||||
return nil
|
||||
}
|
||||
var bw *bufio.Writer
|
||||
ww, ok := w.(writer)
|
||||
if !ok {
|
||||
bw = bufio.NewWriter(w)
|
||||
ww = bw
|
||||
}
|
||||
aw := &textWriter{
|
||||
w: ww,
|
||||
complete: true,
|
||||
compact: tm.Compact,
|
||||
}
|
||||
|
||||
if etm, ok := pb.(encoding.TextMarshaler); ok {
|
||||
text, err := etm.MarshalText()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = aw.Write(text); err != nil {
|
||||
return err
|
||||
}
|
||||
if bw != nil {
|
||||
return bw.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Dereference the received pointer so we don't have outer < and >.
|
||||
v := reflect.Indirect(val)
|
||||
if err := tm.writeStruct(aw, v); err != nil {
|
||||
return err
|
||||
}
|
||||
if bw != nil {
|
||||
return bw.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Text is the same as Marshal, but returns the string directly.
|
||||
func (tm *TextMarshaler) Text(pb Message) string {
|
||||
var buf bytes.Buffer
|
||||
tm.Marshal(&buf, pb)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
var (
|
||||
defaultTextMarshaler = TextMarshaler{}
|
||||
compactTextMarshaler = TextMarshaler{Compact: true}
|
||||
)
|
||||
|
||||
// TODO: consider removing some of the Marshal functions below.
|
||||
|
||||
// MarshalText writes a given protocol buffer in text format.
|
||||
// The only errors returned are from w.
|
||||
func MarshalText(w io.Writer, pb Message) error { return defaultTextMarshaler.Marshal(w, pb) }
|
||||
|
||||
// MarshalTextString is the same as MarshalText, but returns the string directly.
|
||||
func MarshalTextString(pb Message) string { return defaultTextMarshaler.Text(pb) }
|
||||
|
||||
// CompactText writes a given protocol buffer in compact text format (one line).
|
||||
func CompactText(w io.Writer, pb Message) error { return compactTextMarshaler.Marshal(w, pb) }
|
||||
|
||||
// CompactTextString is the same as CompactText, but returns the string directly.
|
||||
func CompactTextString(pb Message) string { return compactTextMarshaler.Text(pb) }
|
|
@ -0,0 +1,895 @@
|
|||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
// Functions for parsing the Text protocol buffer format.
|
||||
// TODO: message sets.
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Error string emitted when deserializing Any and fields are already set
|
||||
const anyRepeatedlyUnpacked = "Any message unpacked multiple times, or %q already set"
|
||||
|
||||
type ParseError struct {
|
||||
Message string
|
||||
Line int // 1-based line number
|
||||
Offset int // 0-based byte offset from start of input
|
||||
}
|
||||
|
||||
func (p *ParseError) Error() string {
|
||||
if p.Line == 1 {
|
||||
// show offset only for first line
|
||||
return fmt.Sprintf("line 1.%d: %v", p.Offset, p.Message)
|
||||
}
|
||||
return fmt.Sprintf("line %d: %v", p.Line, p.Message)
|
||||
}
|
||||
|
||||
type token struct {
|
||||
value string
|
||||
err *ParseError
|
||||
line int // line number
|
||||
offset int // byte number from start of input, not start of line
|
||||
unquoted string // the unquoted version of value, if it was a quoted string
|
||||
}
|
||||
|
||||
func (t *token) String() string {
|
||||
if t.err == nil {
|
||||
return fmt.Sprintf("%q (line=%d, offset=%d)", t.value, t.line, t.offset)
|
||||
}
|
||||
return fmt.Sprintf("parse error: %v", t.err)
|
||||
}
|
||||
|
||||
type textParser struct {
|
||||
s string // remaining input
|
||||
done bool // whether the parsing is finished (success or error)
|
||||
backed bool // whether back() was called
|
||||
offset, line int
|
||||
cur token
|
||||
}
|
||||
|
||||
func newTextParser(s string) *textParser {
|
||||
p := new(textParser)
|
||||
p.s = s
|
||||
p.line = 1
|
||||
p.cur.line = 1
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *textParser) errorf(format string, a ...interface{}) *ParseError {
|
||||
pe := &ParseError{fmt.Sprintf(format, a...), p.cur.line, p.cur.offset}
|
||||
p.cur.err = pe
|
||||
p.done = true
|
||||
return pe
|
||||
}
|
||||
|
||||
// Numbers and identifiers are matched by [-+._A-Za-z0-9]
|
||||
func isIdentOrNumberChar(c byte) bool {
|
||||
switch {
|
||||
case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z':
|
||||
return true
|
||||
case '0' <= c && c <= '9':
|
||||
return true
|
||||
}
|
||||
switch c {
|
||||
case '-', '+', '.', '_':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isWhitespace(c byte) bool {
|
||||
switch c {
|
||||
case ' ', '\t', '\n', '\r':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isQuote(c byte) bool {
|
||||
switch c {
|
||||
case '"', '\'':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *textParser) skipWhitespace() {
|
||||
i := 0
|
||||
for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') {
|
||||
if p.s[i] == '#' {
|
||||
// comment; skip to end of line or input
|
||||
for i < len(p.s) && p.s[i] != '\n' {
|
||||
i++
|
||||
}
|
||||
if i == len(p.s) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if p.s[i] == '\n' {
|
||||
p.line++
|
||||
}
|
||||
i++
|
||||
}
|
||||
p.offset += i
|
||||
p.s = p.s[i:len(p.s)]
|
||||
if len(p.s) == 0 {
|
||||
p.done = true
|
||||
}
|
||||
}
|
||||
|
||||
func (p *textParser) advance() {
|
||||
// Skip whitespace
|
||||
p.skipWhitespace()
|
||||
if p.done {
|
||||
return
|
||||
}
|
||||
|
||||
// Start of non-whitespace
|
||||
p.cur.err = nil
|
||||
p.cur.offset, p.cur.line = p.offset, p.line
|
||||
p.cur.unquoted = ""
|
||||
switch p.s[0] {
|
||||
case '<', '>', '{', '}', ':', '[', ']', ';', ',', '/':
|
||||
// Single symbol
|
||||
p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)]
|
||||
case '"', '\'':
|
||||
// Quoted string
|
||||
i := 1
|
||||
for i < len(p.s) && p.s[i] != p.s[0] && p.s[i] != '\n' {
|
||||
if p.s[i] == '\\' && i+1 < len(p.s) {
|
||||
// skip escaped char
|
||||
i++
|
||||
}
|
||||
i++
|
||||
}
|
||||
if i >= len(p.s) || p.s[i] != p.s[0] {
|
||||
p.errorf("unmatched quote")
|
||||
return
|
||||
}
|
||||
unq, err := unquoteC(p.s[1:i], rune(p.s[0]))
|
||||
if err != nil {
|
||||
p.errorf("invalid quoted string %s: %v", p.s[0:i+1], err)
|
||||
return
|
||||
}
|
||||
p.cur.value, p.s = p.s[0:i+1], p.s[i+1:len(p.s)]
|
||||
p.cur.unquoted = unq
|
||||
default:
|
||||
i := 0
|
||||
for i < len(p.s) && isIdentOrNumberChar(p.s[i]) {
|
||||
i++
|
||||
}
|
||||
if i == 0 {
|
||||
p.errorf("unexpected byte %#x", p.s[0])
|
||||
return
|
||||
}
|
||||
p.cur.value, p.s = p.s[0:i], p.s[i:len(p.s)]
|
||||
}
|
||||
p.offset += len(p.cur.value)
|
||||
}
|
||||
|
||||
var (
|
||||
errBadUTF8 = errors.New("proto: bad UTF-8")
|
||||
errBadHex = errors.New("proto: bad hexadecimal")
|
||||
)
|
||||
|
||||
func unquoteC(s string, quote rune) (string, error) {
|
||||
// This is based on C++'s tokenizer.cc.
|
||||
// Despite its name, this is *not* parsing C syntax.
|
||||
// For instance, "\0" is an invalid quoted string.
|
||||
|
||||
// Avoid allocation in trivial cases.
|
||||
simple := true
|
||||
for _, r := range s {
|
||||
if r == '\\' || r == quote {
|
||||
simple = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if simple {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, 3*len(s)/2)
|
||||
for len(s) > 0 {
|
||||
r, n := utf8.DecodeRuneInString(s)
|
||||
if r == utf8.RuneError && n == 1 {
|
||||
return "", errBadUTF8
|
||||
}
|
||||
s = s[n:]
|
||||
if r != '\\' {
|
||||
if r < utf8.RuneSelf {
|
||||
buf = append(buf, byte(r))
|
||||
} else {
|
||||
buf = append(buf, string(r)...)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
ch, tail, err := unescape(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf = append(buf, ch...)
|
||||
s = tail
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
func unescape(s string) (ch string, tail string, err error) {
|
||||
r, n := utf8.DecodeRuneInString(s)
|
||||
if r == utf8.RuneError && n == 1 {
|
||||
return "", "", errBadUTF8
|
||||
}
|
||||
s = s[n:]
|
||||
switch r {
|
||||
case 'a':
|
||||
return "\a", s, nil
|
||||
case 'b':
|
||||
return "\b", s, nil
|
||||
case 'f':
|
||||
return "\f", s, nil
|
||||
case 'n':
|
||||
return "\n", s, nil
|
||||
case 'r':
|
||||
return "\r", s, nil
|
||||
case 't':
|
||||
return "\t", s, nil
|
||||
case 'v':
|
||||
return "\v", s, nil
|
||||
case '?':
|
||||
return "?", s, nil // trigraph workaround
|
||||
case '\'', '"', '\\':
|
||||
return string(r), s, nil
|
||||
case '0', '1', '2', '3', '4', '5', '6', '7', 'x', 'X':
|
||||
if len(s) < 2 {
|
||||
return "", "", fmt.Errorf(`\%c requires 2 following digits`, r)
|
||||
}
|
||||
base := 8
|
||||
ss := s[:2]
|
||||
s = s[2:]
|
||||
if r == 'x' || r == 'X' {
|
||||
base = 16
|
||||
} else {
|
||||
ss = string(r) + ss
|
||||
}
|
||||
i, err := strconv.ParseUint(ss, base, 8)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return string([]byte{byte(i)}), s, nil
|
||||
case 'u', 'U':
|
||||
n := 4
|
||||
if r == 'U' {
|
||||
n = 8
|
||||
}
|
||||
if len(s) < n {
|
||||
return "", "", fmt.Errorf(`\%c requires %d digits`, r, n)
|
||||
}
|
||||
|
||||
bs := make([]byte, n/2)
|
||||
for i := 0; i < n; i += 2 {
|
||||
a, ok1 := unhex(s[i])
|
||||
b, ok2 := unhex(s[i+1])
|
||||
if !ok1 || !ok2 {
|
||||
return "", "", errBadHex
|
||||
}
|
||||
bs[i/2] = a<<4 | b
|
||||
}
|
||||
s = s[n:]
|
||||
return string(bs), s, nil
|
||||
}
|
||||
return "", "", fmt.Errorf(`unknown escape \%c`, r)
|
||||
}
|
||||
|
||||
// Adapted from src/pkg/strconv/quote.go.
|
||||
func unhex(b byte) (v byte, ok bool) {
|
||||
switch {
|
||||
case '0' <= b && b <= '9':
|
||||
return b - '0', true
|
||||
case 'a' <= b && b <= 'f':
|
||||
return b - 'a' + 10, true
|
||||
case 'A' <= b && b <= 'F':
|
||||
return b - 'A' + 10, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Back off the parser by one token. Can only be done between calls to next().
|
||||
// It makes the next advance() a no-op.
|
||||
func (p *textParser) back() { p.backed = true }
|
||||
|
||||
// Advances the parser and returns the new current token.
|
||||
func (p *textParser) next() *token {
|
||||
if p.backed || p.done {
|
||||
p.backed = false
|
||||
return &p.cur
|
||||
}
|
||||
p.advance()
|
||||
if p.done {
|
||||
p.cur.value = ""
|
||||
} else if len(p.cur.value) > 0 && isQuote(p.cur.value[0]) {
|
||||
// Look for multiple quoted strings separated by whitespace,
|
||||
// and concatenate them.
|
||||
cat := p.cur
|
||||
for {
|
||||
p.skipWhitespace()
|
||||
if p.done || !isQuote(p.s[0]) {
|
||||
break
|
||||
}
|
||||
p.advance()
|
||||
if p.cur.err != nil {
|
||||
return &p.cur
|
||||
}
|
||||
cat.value += " " + p.cur.value
|
||||
cat.unquoted += p.cur.unquoted
|
||||
}
|
||||
p.done = false // parser may have seen EOF, but we want to return cat
|
||||
p.cur = cat
|
||||
}
|
||||
return &p.cur
|
||||
}
|
||||
|
||||
func (p *textParser) consumeToken(s string) error {
|
||||
tok := p.next()
|
||||
if tok.err != nil {
|
||||
return tok.err
|
||||
}
|
||||
if tok.value != s {
|
||||
p.back()
|
||||
return p.errorf("expected %q, found %q", s, tok.value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return a RequiredNotSetError indicating which required field was not set.
|
||||
func (p *textParser) missingRequiredFieldError(sv reflect.Value) *RequiredNotSetError {
|
||||
st := sv.Type()
|
||||
sprops := GetProperties(st)
|
||||
for i := 0; i < st.NumField(); i++ {
|
||||
if !isNil(sv.Field(i)) {
|
||||
continue
|
||||
}
|
||||
|
||||
props := sprops.Prop[i]
|
||||
if props.Required {
|
||||
return &RequiredNotSetError{fmt.Sprintf("%v.%v", st, props.OrigName)}
|
||||
}
|
||||
}
|
||||
return &RequiredNotSetError{fmt.Sprintf("%v.<unknown field name>", st)} // should not happen
|
||||
}
|
||||
|
||||
// Returns the index in the struct for the named field, as well as the parsed tag properties.
|
||||
func structFieldByName(sprops *StructProperties, name string) (int, *Properties, bool) {
|
||||
i, ok := sprops.decoderOrigNames[name]
|
||||
if ok {
|
||||
return i, sprops.Prop[i], true
|
||||
}
|
||||
return -1, nil, false
|
||||
}
|
||||
|
||||
// Consume a ':' from the input stream (if the next token is a colon),
|
||||
// returning an error if a colon is needed but not present.
|
||||
func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseError {
|
||||
tok := p.next()
|
||||
if tok.err != nil {
|
||||
return tok.err
|
||||
}
|
||||
if tok.value != ":" {
|
||||
// Colon is optional when the field is a group or message.
|
||||
needColon := true
|
||||
switch props.Wire {
|
||||
case "group":
|
||||
needColon = false
|
||||
case "bytes":
|
||||
// A "bytes" field is either a message, a string, or a repeated field;
|
||||
// those three become *T, *string and []T respectively, so we can check for
|
||||
// this field being a pointer to a non-string.
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
// *T or *string
|
||||
if typ.Elem().Kind() == reflect.String {
|
||||
break
|
||||
}
|
||||
} else if typ.Kind() == reflect.Slice {
|
||||
// []T or []*T
|
||||
if typ.Elem().Kind() != reflect.Ptr {
|
||||
break
|
||||
}
|
||||
} else if typ.Kind() == reflect.String {
|
||||
// The proto3 exception is for a string field,
|
||||
// which requires a colon.
|
||||
break
|
||||
}
|
||||
needColon = false
|
||||
}
|
||||
if needColon {
|
||||
return p.errorf("expected ':', found %q", tok.value)
|
||||
}
|
||||
p.back()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
|
||||
st := sv.Type()
|
||||
sprops := GetProperties(st)
|
||||
reqCount := sprops.reqCount
|
||||
var reqFieldErr error
|
||||
fieldSet := make(map[string]bool)
|
||||
// A struct is a sequence of "name: value", terminated by one of
|
||||
// '>' or '}', or the end of the input. A name may also be
|
||||
// "[extension]" or "[type/url]".
|
||||
//
|
||||
// The whole struct can also be an expanded Any message, like:
|
||||
// [type/url] < ... struct contents ... >
|
||||
for {
|
||||
tok := p.next()
|
||||
if tok.err != nil {
|
||||
return tok.err
|
||||
}
|
||||
if tok.value == terminator {
|
||||
break
|
||||
}
|
||||
if tok.value == "[" {
|
||||
// Looks like an extension or an Any.
|
||||
//
|
||||
// TODO: Check whether we need to handle
|
||||
// namespace rooted names (e.g. ".something.Foo").
|
||||
extName, err := p.consumeExtName()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s := strings.LastIndex(extName, "/"); s >= 0 {
|
||||
// If it contains a slash, it's an Any type URL.
|
||||
messageName := extName[s+1:]
|
||||
mt := MessageType(messageName)
|
||||
if mt == nil {
|
||||
return p.errorf("unrecognized message %q in google.protobuf.Any", messageName)
|
||||
}
|
||||
tok = p.next()
|
||||
if tok.err != nil {
|
||||
return tok.err
|
||||
}
|
||||
// consume an optional colon
|
||||
if tok.value == ":" {
|
||||
tok = p.next()
|
||||
if tok.err != nil {
|
||||
return tok.err
|
||||
}
|
||||
}
|
||||
var terminator string
|
||||
switch tok.value {
|
||||
case "<":
|
||||
terminator = ">"
|
||||
case "{":
|
||||
terminator = "}"
|
||||
default:
|
||||
return p.errorf("expected '{' or '<', found %q", tok.value)
|
||||
}
|
||||
v := reflect.New(mt.Elem())
|
||||
if pe := p.readStruct(v.Elem(), terminator); pe != nil {
|
||||
return pe
|
||||
}
|
||||
b, err := Marshal(v.Interface().(Message))
|
||||
if err != nil {
|
||||
return p.errorf("failed to marshal message of type %q: %v", messageName, err)
|
||||
}
|
||||
if fieldSet["type_url"] {
|
||||
return p.errorf(anyRepeatedlyUnpacked, "type_url")
|
||||
}
|
||||
if fieldSet["value"] {
|
||||
return p.errorf(anyRepeatedlyUnpacked, "value")
|
||||
}
|
||||
sv.FieldByName("TypeUrl").SetString(extName)
|
||||
sv.FieldByName("Value").SetBytes(b)
|
||||
fieldSet["type_url"] = true
|
||||
fieldSet["value"] = true
|
||||
continue
|
||||
}
|
||||
|
||||
var desc *ExtensionDesc
|
||||
// This could be faster, but it's functional.
|
||||
// TODO: Do something smarter than a linear scan.
|
||||
for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) {
|
||||
if d.Name == extName {
|
||||
desc = d
|
||||
break
|
||||
}
|
||||
}
|
||||
if desc == nil {
|
||||
return p.errorf("unrecognized extension %q", extName)
|
||||
}
|
||||
|
||||
props := &Properties{}
|
||||
props.Parse(desc.Tag)
|
||||
|
||||
typ := reflect.TypeOf(desc.ExtensionType)
|
||||
if err := p.checkForColon(props, typ); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rep := desc.repeated()
|
||||
|
||||
// Read the extension structure, and set it in
|
||||
// the value we're constructing.
|
||||
var ext reflect.Value
|
||||
if !rep {
|
||||
ext = reflect.New(typ).Elem()
|
||||
} else {
|
||||
ext = reflect.New(typ.Elem()).Elem()
|
||||
}
|
||||
if err := p.readAny(ext, props); err != nil {
|
||||
if _, ok := err.(*RequiredNotSetError); !ok {
|
||||
return err
|
||||
}
|
||||
reqFieldErr = err
|
||||
}
|
||||
ep := sv.Addr().Interface().(Message)
|
||||
if !rep {
|
||||
SetExtension(ep, desc, ext.Interface())
|
||||
} else {
|
||||
old, err := GetExtension(ep, desc)
|
||||
var sl reflect.Value
|
||||
if err == nil {
|
||||
sl = reflect.ValueOf(old) // existing slice
|
||||
} else {
|
||||
sl = reflect.MakeSlice(typ, 0, 1)
|
||||
}
|
||||
sl = reflect.Append(sl, ext)
|
||||
SetExtension(ep, desc, sl.Interface())
|
||||
}
|
||||
if err := p.consumeOptionalSeparator(); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// This is a normal, non-extension field.
|
||||
name := tok.value
|
||||
var dst reflect.Value
|
||||
fi, props, ok := structFieldByName(sprops, name)
|
||||
if ok {
|
||||
dst = sv.Field(fi)
|
||||
} else if oop, ok := sprops.OneofTypes[name]; ok {
|
||||
// It is a oneof.
|
||||
props = oop.Prop
|
||||
nv := reflect.New(oop.Type.Elem())
|
||||
dst = nv.Elem().Field(0)
|
||||
field := sv.Field(oop.Field)
|
||||
if !field.IsNil() {
|
||||
return p.errorf("field '%s' would overwrite already parsed oneof '%s'", name, sv.Type().Field(oop.Field).Name)
|
||||
}
|
||||
field.Set(nv)
|
||||
}
|
||||
if !dst.IsValid() {
|
||||
return p.errorf("unknown field name %q in %v", name, st)
|
||||
}
|
||||
|
||||
if dst.Kind() == reflect.Map {
|
||||
// Consume any colon.
|
||||
if err := p.checkForColon(props, dst.Type()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Construct the map if it doesn't already exist.
|
||||
if dst.IsNil() {
|
||||
dst.Set(reflect.MakeMap(dst.Type()))
|
||||
}
|
||||
key := reflect.New(dst.Type().Key()).Elem()
|
||||
val := reflect.New(dst.Type().Elem()).Elem()
|
||||
|
||||
// The map entry should be this sequence of tokens:
|
||||
// < key : KEY value : VALUE >
|
||||
// However, implementations may omit key or value, and technically
|
||||
// we should support them in any order. See b/28924776 for a time
|
||||
// this went wrong.
|
||||
|
||||
tok := p.next()
|
||||
var terminator string
|
||||
switch tok.value {
|
||||
case "<":
|
||||
terminator = ">"
|
||||
case "{":
|
||||
terminator = "}"
|
||||
default:
|
||||
return p.errorf("expected '{' or '<', found %q", tok.value)
|
||||
}
|
||||
for {
|
||||
tok := p.next()
|
||||
if tok.err != nil {
|
||||
return tok.err
|
||||
}
|
||||
if tok.value == terminator {
|
||||
break
|
||||
}
|
||||
switch tok.value {
|
||||
case "key":
|
||||
if err := p.consumeToken(":"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.readAny(key, props.mkeyprop); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.consumeOptionalSeparator(); err != nil {
|
||||
return err
|
||||
}
|
||||
case "value":
|
||||
if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.readAny(val, props.mvalprop); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.consumeOptionalSeparator(); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
p.back()
|
||||
return p.errorf(`expected "key", "value", or %q, found %q`, terminator, tok.value)
|
||||
}
|
||||
}
|
||||
|
||||
dst.SetMapIndex(key, val)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that it's not already set if it's not a repeated field.
|
||||
if !props.Repeated && fieldSet[name] {
|
||||
return p.errorf("non-repeated field %q was repeated", name)
|
||||
}
|
||||
|
||||
if err := p.checkForColon(props, dst.Type()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse into the field.
|
||||
fieldSet[name] = true
|
||||
if err := p.readAny(dst, props); err != nil {
|
||||
if _, ok := err.(*RequiredNotSetError); !ok {
|
||||
return err
|
||||
}
|
||||
reqFieldErr = err
|
||||
}
|
||||
if props.Required {
|
||||
reqCount--
|
||||
}
|
||||
|
||||
if err := p.consumeOptionalSeparator(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if reqCount > 0 {
|
||||
return p.missingRequiredFieldError(sv)
|
||||
}
|
||||
return reqFieldErr
|
||||
}
|
||||
|
||||
// consumeExtName consumes extension name or expanded Any type URL and the
|
||||
// following ']'. It returns the name or URL consumed.
|
||||
func (p *textParser) consumeExtName() (string, error) {
|
||||
tok := p.next()
|
||||
if tok.err != nil {
|
||||
return "", tok.err
|
||||
}
|
||||
|
||||
// If extension name or type url is quoted, it's a single token.
|
||||
if len(tok.value) > 2 && isQuote(tok.value[0]) && tok.value[len(tok.value)-1] == tok.value[0] {
|
||||
name, err := unquoteC(tok.value[1:len(tok.value)-1], rune(tok.value[0]))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return name, p.consumeToken("]")
|
||||
}
|
||||
|
||||
// Consume everything up to "]"
|
||||
var parts []string
|
||||
for tok.value != "]" {
|
||||
parts = append(parts, tok.value)
|
||||
tok = p.next()
|
||||
if tok.err != nil {
|
||||
return "", p.errorf("unrecognized type_url or extension name: %s", tok.err)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, ""), nil
|
||||
}
|
||||
|
||||
// consumeOptionalSeparator consumes an optional semicolon or comma.
|
||||
// It is used in readStruct to provide backward compatibility.
|
||||
func (p *textParser) consumeOptionalSeparator() error {
|
||||
tok := p.next()
|
||||
if tok.err != nil {
|
||||
return tok.err
|
||||
}
|
||||
if tok.value != ";" && tok.value != "," {
|
||||
p.back()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *textParser) readAny(v reflect.Value, props *Properties) error {
|
||||
tok := p.next()
|
||||
if tok.err != nil {
|
||||
return tok.err
|
||||
}
|
||||
if tok.value == "" {
|
||||
return p.errorf("unexpected EOF")
|
||||
}
|
||||
|
||||
switch fv := v; fv.Kind() {
|
||||
case reflect.Slice:
|
||||
at := v.Type()
|
||||
if at.Elem().Kind() == reflect.Uint8 {
|
||||
// Special case for []byte
|
||||
if tok.value[0] != '"' && tok.value[0] != '\'' {
|
||||
// Deliberately written out here, as the error after
|
||||
// this switch statement would write "invalid []byte: ...",
|
||||
// which is not as user-friendly.
|
||||
return p.errorf("invalid string: %v", tok.value)
|
||||
}
|
||||
bytes := []byte(tok.unquoted)
|
||||
fv.Set(reflect.ValueOf(bytes))
|
||||
return nil
|
||||
}
|
||||
// Repeated field.
|
||||
if tok.value == "[" {
|
||||
// Repeated field with list notation, like [1,2,3].
|
||||
for {
|
||||
fv.Set(reflect.Append(fv, reflect.New(at.Elem()).Elem()))
|
||||
err := p.readAny(fv.Index(fv.Len()-1), props)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tok := p.next()
|
||||
if tok.err != nil {
|
||||
return tok.err
|
||||
}
|
||||
if tok.value == "]" {
|
||||
break
|
||||
}
|
||||
if tok.value != "," {
|
||||
return p.errorf("Expected ']' or ',' found %q", tok.value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// One value of the repeated field.
|
||||
p.back()
|
||||
fv.Set(reflect.Append(fv, reflect.New(at.Elem()).Elem()))
|
||||
return p.readAny(fv.Index(fv.Len()-1), props)
|
||||
case reflect.Bool:
|
||||
// true/1/t/True or false/f/0/False.
|
||||
switch tok.value {
|
||||
case "true", "1", "t", "True":
|
||||
fv.SetBool(true)
|
||||
return nil
|
||||
case "false", "0", "f", "False":
|
||||
fv.SetBool(false)
|
||||
return nil
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
v := tok.value
|
||||
// Ignore 'f' for compatibility with output generated by C++, but don't
|
||||
// remove 'f' when the value is "-inf" or "inf".
|
||||
if strings.HasSuffix(v, "f") && tok.value != "-inf" && tok.value != "inf" {
|
||||
v = v[:len(v)-1]
|
||||
}
|
||||
if f, err := strconv.ParseFloat(v, fv.Type().Bits()); err == nil {
|
||||
fv.SetFloat(f)
|
||||
return nil
|
||||
}
|
||||
case reflect.Int32:
|
||||
if x, err := strconv.ParseInt(tok.value, 0, 32); err == nil {
|
||||
fv.SetInt(x)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(props.Enum) == 0 {
|
||||
break
|
||||
}
|
||||
m, ok := enumValueMaps[props.Enum]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
x, ok := m[tok.value]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
fv.SetInt(int64(x))
|
||||
return nil
|
||||
case reflect.Int64:
|
||||
if x, err := strconv.ParseInt(tok.value, 0, 64); err == nil {
|
||||
fv.SetInt(x)
|
||||
return nil
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// A basic field (indirected through pointer), or a repeated message/group
|
||||
p.back()
|
||||
fv.Set(reflect.New(fv.Type().Elem()))
|
||||
return p.readAny(fv.Elem(), props)
|
||||
case reflect.String:
|
||||
if tok.value[0] == '"' || tok.value[0] == '\'' {
|
||||
fv.SetString(tok.unquoted)
|
||||
return nil
|
||||
}
|
||||
case reflect.Struct:
|
||||
var terminator string
|
||||
switch tok.value {
|
||||
case "{":
|
||||
terminator = "}"
|
||||
case "<":
|
||||
terminator = ">"
|
||||
default:
|
||||
return p.errorf("expected '{' or '<', found %q", tok.value)
|
||||
}
|
||||
// TODO: Handle nested messages which implement encoding.TextUnmarshaler.
|
||||
return p.readStruct(fv, terminator)
|
||||
case reflect.Uint32:
|
||||
if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil {
|
||||
fv.SetUint(uint64(x))
|
||||
return nil
|
||||
}
|
||||
case reflect.Uint64:
|
||||
if x, err := strconv.ParseUint(tok.value, 0, 64); err == nil {
|
||||
fv.SetUint(x)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return p.errorf("invalid %v: %v", v.Type(), tok.value)
|
||||
}
|
||||
|
||||
// UnmarshalText reads a protocol buffer in Text format. UnmarshalText resets pb
|
||||
// before starting to unmarshal, so any existing data in pb is always removed.
|
||||
// If a required field is not set and no other error occurs,
|
||||
// UnmarshalText returns *RequiredNotSetError.
|
||||
func UnmarshalText(s string, pb Message) error {
|
||||
if um, ok := pb.(encoding.TextUnmarshaler); ok {
|
||||
err := um.UnmarshalText([]byte(s))
|
||||
return err
|
||||
}
|
||||
pb.Reset()
|
||||
v := reflect.ValueOf(pb)
|
||||
if pe := newTextParser(s).readStruct(v.Elem(), ""); pe != nil {
|
||||
return pe
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,220 @@
|
|||
package hostpool
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
type epsilonHostPoolResponse struct {
|
||||
standardHostPoolResponse
|
||||
started time.Time
|
||||
ended time.Time
|
||||
}
|
||||
|
||||
func (r *epsilonHostPoolResponse) Mark(err error) {
|
||||
r.Do(func() {
|
||||
r.ended = time.Now()
|
||||
doMark(err, r)
|
||||
})
|
||||
}
|
||||
|
||||
type epsilonGreedyHostPool struct {
|
||||
standardHostPool // TODO - would be nifty if we could embed HostPool and Locker interfaces
|
||||
epsilon float32 // this is our exploration factor
|
||||
decayDuration time.Duration
|
||||
EpsilonValueCalculator // embed the epsilonValueCalculator
|
||||
timer
|
||||
quit chan bool
|
||||
}
|
||||
|
||||
// Construct an Epsilon Greedy HostPool
|
||||
//
|
||||
// Epsilon Greedy is an algorithm that allows HostPool not only to track failure state,
|
||||
// but also to learn about "better" options in terms of speed, and to pick from available hosts
|
||||
// based on how well they perform. This gives a weighted request rate to better
|
||||
// performing hosts, while still distributing requests to all hosts (proportionate to their performance).
|
||||
// The interface is the same as the standard HostPool, but be sure to mark the HostResponse immediately
|
||||
// after executing the request to the host, as that will stop the implicitly running request timer.
|
||||
//
|
||||
// A good overview of Epsilon Greedy is here http://stevehanov.ca/blog/index.php?id=132
|
||||
//
|
||||
// To compute the weighting scores, we perform a weighted average of recent response times, over the course of
|
||||
// `decayDuration`. decayDuration may be set to 0 to use the default value of 5 minutes
|
||||
// We then use the supplied EpsilonValueCalculator to calculate a score from that weighted average response time.
|
||||
func NewEpsilonGreedy(hosts []string, decayDuration time.Duration, calc EpsilonValueCalculator) HostPool {
|
||||
|
||||
if decayDuration <= 0 {
|
||||
decayDuration = defaultDecayDuration
|
||||
}
|
||||
stdHP := New(hosts).(*standardHostPool)
|
||||
p := &epsilonGreedyHostPool{
|
||||
standardHostPool: *stdHP,
|
||||
epsilon: float32(initialEpsilon),
|
||||
decayDuration: decayDuration,
|
||||
EpsilonValueCalculator: calc,
|
||||
timer: &realTimer{},
|
||||
quit: make(chan bool),
|
||||
}
|
||||
|
||||
// allocate structures
|
||||
for _, h := range p.hostList {
|
||||
h.epsilonCounts = make([]int64, epsilonBuckets)
|
||||
h.epsilonValues = make([]int64, epsilonBuckets)
|
||||
}
|
||||
go p.epsilonGreedyDecay()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *epsilonGreedyHostPool) Close() {
|
||||
// No need to do p.quit <- true as close(p.quit) does the trick.
|
||||
close(p.quit)
|
||||
}
|
||||
|
||||
func (p *epsilonGreedyHostPool) SetEpsilon(newEpsilon float32) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
p.epsilon = newEpsilon
|
||||
}
|
||||
|
||||
func (p *epsilonGreedyHostPool) SetHosts(hosts []string) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
p.standardHostPool.setHosts(hosts)
|
||||
for _, h := range p.hostList {
|
||||
h.epsilonCounts = make([]int64, epsilonBuckets)
|
||||
h.epsilonValues = make([]int64, epsilonBuckets)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *epsilonGreedyHostPool) epsilonGreedyDecay() {
|
||||
durationPerBucket := p.decayDuration / epsilonBuckets
|
||||
ticker := time.NewTicker(durationPerBucket)
|
||||
for {
|
||||
select {
|
||||
case <-p.quit:
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.performEpsilonGreedyDecay()
|
||||
}
|
||||
}
|
||||
}
|
||||
func (p *epsilonGreedyHostPool) performEpsilonGreedyDecay() {
|
||||
p.Lock()
|
||||
for _, h := range p.hostList {
|
||||
h.epsilonIndex += 1
|
||||
h.epsilonIndex = h.epsilonIndex % epsilonBuckets
|
||||
h.epsilonCounts[h.epsilonIndex] = 0
|
||||
h.epsilonValues[h.epsilonIndex] = 0
|
||||
}
|
||||
p.Unlock()
|
||||
}
|
||||
|
||||
func (p *epsilonGreedyHostPool) Get() HostPoolResponse {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
host := p.getEpsilonGreedy()
|
||||
if host == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
started := time.Now()
|
||||
return &epsilonHostPoolResponse{
|
||||
standardHostPoolResponse: standardHostPoolResponse{host: host, pool: p},
|
||||
started: started,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *epsilonGreedyHostPool) getEpsilonGreedy() string {
|
||||
var hostToUse *hostEntry
|
||||
|
||||
// this is our exploration phase
|
||||
if rand.Float32() < p.epsilon {
|
||||
p.epsilon = p.epsilon * epsilonDecay
|
||||
if p.epsilon < minEpsilon {
|
||||
p.epsilon = minEpsilon
|
||||
}
|
||||
return p.getRoundRobin()
|
||||
}
|
||||
|
||||
// calculate values for each host in the 0..1 range (but not ormalized)
|
||||
var possibleHosts []*hostEntry
|
||||
now := time.Now()
|
||||
var sumValues float64
|
||||
for _, h := range p.hostList {
|
||||
if h.canTryHost(now) {
|
||||
v := h.getWeightedAverageResponseTime()
|
||||
if v > 0 {
|
||||
ev := p.CalcValueFromAvgResponseTime(v)
|
||||
h.epsilonValue = ev
|
||||
sumValues += ev
|
||||
possibleHosts = append(possibleHosts, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(possibleHosts) != 0 {
|
||||
// now normalize to the 0..1 range to get a percentage
|
||||
for _, h := range possibleHosts {
|
||||
h.epsilonPercentage = h.epsilonValue / sumValues
|
||||
}
|
||||
|
||||
// do a weighted random choice among hosts
|
||||
ceiling := 0.0
|
||||
pickPercentage := rand.Float64()
|
||||
for _, h := range possibleHosts {
|
||||
ceiling += h.epsilonPercentage
|
||||
if pickPercentage <= ceiling {
|
||||
hostToUse = h
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hostToUse == nil {
|
||||
if len(possibleHosts) != 0 {
|
||||
log.Println("Failed to randomly choose a host, Dan loses")
|
||||
}
|
||||
|
||||
return p.getRoundRobin()
|
||||
}
|
||||
|
||||
if hostToUse.dead {
|
||||
hostToUse.willRetryHost(p.maxRetryInterval)
|
||||
}
|
||||
return hostToUse.host
|
||||
}
|
||||
|
||||
func (p *epsilonGreedyHostPool) markSuccess(hostR HostPoolResponse) {
|
||||
// first do the base markSuccess - a little redundant with host lookup but cleaner than repeating logic
|
||||
p.standardHostPool.markSuccess(hostR)
|
||||
eHostR, ok := hostR.(*epsilonHostPoolResponse)
|
||||
if !ok {
|
||||
log.Printf("Incorrect type in eps markSuccess!") // TODO reflection to print out offending type
|
||||
return
|
||||
}
|
||||
host := eHostR.host
|
||||
duration := p.between(eHostR.started, eHostR.ended)
|
||||
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
h, ok := p.hosts[host]
|
||||
if !ok {
|
||||
log.Fatalf("host %s not in HostPool %v", host, p.Hosts())
|
||||
}
|
||||
h.epsilonCounts[h.epsilonIndex]++
|
||||
h.epsilonValues[h.epsilonIndex] += int64(duration.Seconds() * 1000)
|
||||
}
|
||||
|
||||
// --- timer: this just exists for testing
|
||||
|
||||
type timer interface {
|
||||
between(time.Time, time.Time) time.Duration
|
||||
}
|
||||
|
||||
type realTimer struct{}
|
||||
|
||||
func (rt *realTimer) between(start time.Time, end time.Time) time.Duration {
|
||||
return end.Sub(start)
|
||||
}
|
40
vendor/github.com/hailocab/go-hostpool/epsilon_value_calculators.go
generated
vendored
Normal file
40
vendor/github.com/hailocab/go-hostpool/epsilon_value_calculators.go
generated
vendored
Normal file
|
@ -0,0 +1,40 @@
|
|||
package hostpool
|
||||
|
||||
// --- Value Calculators -----------------
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
// --- Definitions -----------------------
|
||||
|
||||
// Structs implementing this interface are used to convert the average response time for a host
|
||||
// into a score that can be used to weight hosts in the epsilon greedy hostpool. Lower response
|
||||
// times should yield higher scores (we want to select the faster hosts more often) The default
|
||||
// LinearEpsilonValueCalculator just uses the reciprocal of the response time. In practice, any
|
||||
// decreasing function from the positive reals to the positive reals should work.
|
||||
type EpsilonValueCalculator interface {
|
||||
CalcValueFromAvgResponseTime(float64) float64
|
||||
}
|
||||
|
||||
type LinearEpsilonValueCalculator struct{}
|
||||
type LogEpsilonValueCalculator struct{ LinearEpsilonValueCalculator }
|
||||
type PolynomialEpsilonValueCalculator struct {
|
||||
LinearEpsilonValueCalculator
|
||||
Exp float64 // the exponent to which we will raise the value to reweight
|
||||
}
|
||||
|
||||
// -------- Methods -----------------------
|
||||
|
||||
func (c *LinearEpsilonValueCalculator) CalcValueFromAvgResponseTime(v float64) float64 {
|
||||
return 1.0 / v
|
||||
}
|
||||
|
||||
func (c *LogEpsilonValueCalculator) CalcValueFromAvgResponseTime(v float64) float64 {
|
||||
// we need to add 1 to v so that this will be defined on all positive floats
|
||||
return c.LinearEpsilonValueCalculator.CalcValueFromAvgResponseTime(math.Log(v + 1.0))
|
||||
}
|
||||
|
||||
func (c *PolynomialEpsilonValueCalculator) CalcValueFromAvgResponseTime(v float64) float64 {
|
||||
return c.LinearEpsilonValueCalculator.CalcValueFromAvgResponseTime(math.Pow(v, c.Exp))
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
package hostpool
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- hostEntry - this is due to get upgraded
|
||||
|
||||
type hostEntry struct {
|
||||
host string
|
||||
nextRetry time.Time
|
||||
retryCount int16
|
||||
retryDelay time.Duration
|
||||
dead bool
|
||||
epsilonCounts []int64
|
||||
epsilonValues []int64
|
||||
epsilonIndex int
|
||||
epsilonValue float64
|
||||
epsilonPercentage float64
|
||||
}
|
||||
|
||||
func (h *hostEntry) canTryHost(now time.Time) bool {
|
||||
if !h.dead {
|
||||
return true
|
||||
}
|
||||
if h.nextRetry.Before(now) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *hostEntry) willRetryHost(maxRetryInterval time.Duration) {
|
||||
h.retryCount += 1
|
||||
newDelay := h.retryDelay * 2
|
||||
if newDelay < maxRetryInterval {
|
||||
h.retryDelay = newDelay
|
||||
} else {
|
||||
h.retryDelay = maxRetryInterval
|
||||
}
|
||||
h.nextRetry = time.Now().Add(h.retryDelay)
|
||||
}
|
||||
|
||||
func (h *hostEntry) getWeightedAverageResponseTime() float64 {
|
||||
var value float64
|
||||
var lastValue float64
|
||||
|
||||
// start at 1 so we start with the oldest entry
|
||||
for i := 1; i <= epsilonBuckets; i += 1 {
|
||||
pos := (h.epsilonIndex + i) % epsilonBuckets
|
||||
bucketCount := h.epsilonCounts[pos]
|
||||
// Changing the line below to what I think it should be to get the weights right
|
||||
weight := float64(i) / float64(epsilonBuckets)
|
||||
if bucketCount > 0 {
|
||||
currentValue := float64(h.epsilonValues[pos]) / float64(bucketCount)
|
||||
value += currentValue * weight
|
||||
lastValue = currentValue
|
||||
} else {
|
||||
value += lastValue * weight
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
|
@ -0,0 +1,243 @@
|
|||
// A Go package to intelligently and flexibly pool among multiple hosts from your Go application.
|
||||
// Host selection can operate in round robin or epsilon greedy mode, and unresponsive hosts are
|
||||
// avoided. A good overview of Epsilon Greedy is here http://stevehanov.ca/blog/index.php?id=132
|
||||
package hostpool
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Returns current version
|
||||
func Version() string {
|
||||
return "0.1"
|
||||
}
|
||||
|
||||
// --- Response interfaces and structs ----
|
||||
|
||||
// This interface represents the response from HostPool. You can retrieve the
|
||||
// hostname by calling Host(), and after making a request to the host you should
|
||||
// call Mark with any error encountered, which will inform the HostPool issuing
|
||||
// the HostPoolResponse of what happened to the request and allow it to update.
|
||||
type HostPoolResponse interface {
|
||||
Host() string
|
||||
Mark(error)
|
||||
hostPool() HostPool
|
||||
}
|
||||
|
||||
type standardHostPoolResponse struct {
|
||||
host string
|
||||
sync.Once
|
||||
pool HostPool
|
||||
}
|
||||
|
||||
// --- HostPool structs and interfaces ----
|
||||
|
||||
// This is the main HostPool interface. Structs implementing this interface
|
||||
// allow you to Get a HostPoolResponse (which includes a hostname to use),
|
||||
// get the list of all Hosts, and use ResetAll to reset state.
|
||||
type HostPool interface {
|
||||
Get() HostPoolResponse
|
||||
// keep the marks separate so we can override independently
|
||||
markSuccess(HostPoolResponse)
|
||||
markFailed(HostPoolResponse)
|
||||
|
||||
ResetAll()
|
||||
// ReturnUnhealthy when called with true will prevent an unhealthy node from
|
||||
// being returned and will instead return a nil HostPoolResponse. If using
|
||||
// this feature then you should check the result of Get for nil
|
||||
ReturnUnhealthy(v bool)
|
||||
Hosts() []string
|
||||
SetHosts([]string)
|
||||
|
||||
// Close the hostpool and release all resources.
|
||||
Close()
|
||||
}
|
||||
|
||||
type standardHostPool struct {
|
||||
sync.RWMutex
|
||||
hosts map[string]*hostEntry
|
||||
hostList []*hostEntry
|
||||
returnUnhealthy bool
|
||||
initialRetryDelay time.Duration
|
||||
maxRetryInterval time.Duration
|
||||
nextHostIndex int
|
||||
}
|
||||
|
||||
// ------ constants -------------------
|
||||
|
||||
const epsilonBuckets = 120
|
||||
const epsilonDecay = 0.90 // decay the exploration rate
|
||||
const minEpsilon = 0.01 // explore one percent of the time
|
||||
const initialEpsilon = 0.3
|
||||
const defaultDecayDuration = time.Duration(5) * time.Minute
|
||||
|
||||
// Construct a basic HostPool using the hostnames provided
|
||||
func New(hosts []string) HostPool {
|
||||
p := &standardHostPool{
|
||||
returnUnhealthy: true,
|
||||
hosts: make(map[string]*hostEntry, len(hosts)),
|
||||
hostList: make([]*hostEntry, len(hosts)),
|
||||
initialRetryDelay: time.Duration(30) * time.Second,
|
||||
maxRetryInterval: time.Duration(900) * time.Second,
|
||||
}
|
||||
|
||||
for i, h := range hosts {
|
||||
e := &hostEntry{
|
||||
host: h,
|
||||
retryDelay: p.initialRetryDelay,
|
||||
}
|
||||
p.hosts[h] = e
|
||||
p.hostList[i] = e
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (r *standardHostPoolResponse) Host() string {
|
||||
return r.host
|
||||
}
|
||||
|
||||
func (r *standardHostPoolResponse) hostPool() HostPool {
|
||||
return r.pool
|
||||
}
|
||||
|
||||
func (r *standardHostPoolResponse) Mark(err error) {
|
||||
r.Do(func() {
|
||||
doMark(err, r)
|
||||
})
|
||||
}
|
||||
|
||||
func doMark(err error, r HostPoolResponse) {
|
||||
if err == nil {
|
||||
r.hostPool().markSuccess(r)
|
||||
} else {
|
||||
r.hostPool().markFailed(r)
|
||||
}
|
||||
}
|
||||
|
||||
// return an entry from the HostPool
|
||||
func (p *standardHostPool) Get() HostPoolResponse {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
host := p.getRoundRobin()
|
||||
if host == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &standardHostPoolResponse{host: host, pool: p}
|
||||
}
|
||||
|
||||
func (p *standardHostPool) getRoundRobin() string {
|
||||
now := time.Now()
|
||||
hostCount := len(p.hostList)
|
||||
for i := range p.hostList {
|
||||
// iterate via sequenece from where we last iterated
|
||||
currentIndex := (i + p.nextHostIndex) % hostCount
|
||||
|
||||
h := p.hostList[currentIndex]
|
||||
if !h.dead {
|
||||
p.nextHostIndex = currentIndex + 1
|
||||
return h.host
|
||||
}
|
||||
if h.nextRetry.Before(now) {
|
||||
h.willRetryHost(p.maxRetryInterval)
|
||||
p.nextHostIndex = currentIndex + 1
|
||||
return h.host
|
||||
}
|
||||
}
|
||||
|
||||
// all hosts are down and returnUnhealhy is false then return no host
|
||||
if !p.returnUnhealthy {
|
||||
return ""
|
||||
}
|
||||
|
||||
// all hosts are down. re-add them
|
||||
p.doResetAll()
|
||||
p.nextHostIndex = 0
|
||||
return p.hostList[0].host
|
||||
}
|
||||
|
||||
func (p *standardHostPool) ResetAll() {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
p.doResetAll()
|
||||
}
|
||||
|
||||
func (p *standardHostPool) SetHosts(hosts []string) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
p.setHosts(hosts)
|
||||
}
|
||||
|
||||
func (p *standardHostPool) ReturnUnhealthy(v bool) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
p.returnUnhealthy = v
|
||||
}
|
||||
|
||||
func (p *standardHostPool) setHosts(hosts []string) {
|
||||
p.hosts = make(map[string]*hostEntry, len(hosts))
|
||||
p.hostList = make([]*hostEntry, len(hosts))
|
||||
|
||||
for i, h := range hosts {
|
||||
e := &hostEntry{
|
||||
host: h,
|
||||
retryDelay: p.initialRetryDelay,
|
||||
}
|
||||
p.hosts[h] = e
|
||||
p.hostList[i] = e
|
||||
}
|
||||
}
|
||||
|
||||
// this actually performs the logic to reset,
|
||||
// and should only be called when the lock has
|
||||
// already been acquired
|
||||
func (p *standardHostPool) doResetAll() {
|
||||
for _, h := range p.hosts {
|
||||
h.dead = false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *standardHostPool) Close() {
|
||||
for _, h := range p.hosts {
|
||||
h.dead = true
|
||||
}
|
||||
}
|
||||
|
||||
func (p *standardHostPool) markSuccess(hostR HostPoolResponse) {
|
||||
host := hostR.Host()
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
h, ok := p.hosts[host]
|
||||
if !ok {
|
||||
log.Fatalf("host %s not in HostPool %v", host, p.Hosts())
|
||||
}
|
||||
h.dead = false
|
||||
}
|
||||
|
||||
func (p *standardHostPool) markFailed(hostR HostPoolResponse) {
|
||||
host := hostR.Host()
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
h, ok := p.hosts[host]
|
||||
if !ok {
|
||||
log.Fatalf("host %s not in HostPool %v", host, p.Hosts())
|
||||
}
|
||||
if !h.dead {
|
||||
h.dead = true
|
||||
h.retryCount = 0
|
||||
h.retryDelay = p.initialRetryDelay
|
||||
h.nextRetry = time.Now().Add(h.retryDelay)
|
||||
}
|
||||
|
||||
}
|
||||
func (p *standardHostPool) Hosts() []string {
|
||||
hosts := make([]string, 0, len(p.hosts))
|
||||
for host := range p.hosts {
|
||||
hosts = append(hosts, host)
|
||||
}
|
||||
return hosts
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package yamux
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// hasAddr is used to get the address from the underlying connection
|
||||
type hasAddr interface {
|
||||
LocalAddr() net.Addr
|
||||
RemoteAddr() net.Addr
|
||||
}
|
||||
|
||||
// yamuxAddr is used when we cannot get the underlying address
|
||||
type yamuxAddr struct {
|
||||
Addr string
|
||||
}
|
||||
|
||||
func (*yamuxAddr) Network() string {
|
||||
return "yamux"
|
||||
}
|
||||
|
||||
func (y *yamuxAddr) String() string {
|
||||
return fmt.Sprintf("yamux:%s", y.Addr)
|
||||
}
|
||||
|
||||
// Addr is used to get the address of the listener.
|
||||
func (s *Session) Addr() net.Addr {
|
||||
return s.LocalAddr()
|
||||
}
|
||||
|
||||
// LocalAddr is used to get the local address of the
|
||||
// underlying connection.
|
||||
func (s *Session) LocalAddr() net.Addr {
|
||||
addr, ok := s.conn.(hasAddr)
|
||||
if !ok {
|
||||
return &yamuxAddr{"local"}
|
||||
}
|
||||
return addr.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr is used to get the address of remote end
|
||||
// of the underlying connection
|
||||
func (s *Session) RemoteAddr() net.Addr {
|
||||
addr, ok := s.conn.(hasAddr)
|
||||
if !ok {
|
||||
return &yamuxAddr{"remote"}
|
||||
}
|
||||
return addr.RemoteAddr()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address
|
||||
func (s *Stream) LocalAddr() net.Addr {
|
||||
return s.session.LocalAddr()
|
||||
}
|
||||
|
||||
// LocalAddr returns the remote address
|
||||
func (s *Stream) RemoteAddr() net.Addr {
|
||||
return s.session.RemoteAddr()
|
||||
}
|
|
@ -0,0 +1,157 @@
|
|||
package yamux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidVersion means we received a frame with an
|
||||
// invalid version
|
||||
ErrInvalidVersion = fmt.Errorf("invalid protocol version")
|
||||
|
||||
// ErrInvalidMsgType means we received a frame with an
|
||||
// invalid message type
|
||||
ErrInvalidMsgType = fmt.Errorf("invalid msg type")
|
||||
|
||||
// ErrSessionShutdown is used if there is a shutdown during
|
||||
// an operation
|
||||
ErrSessionShutdown = fmt.Errorf("session shutdown")
|
||||
|
||||
// ErrStreamsExhausted is returned if we have no more
|
||||
// stream ids to issue
|
||||
ErrStreamsExhausted = fmt.Errorf("streams exhausted")
|
||||
|
||||
// ErrDuplicateStream is used if a duplicate stream is
|
||||
// opened inbound
|
||||
ErrDuplicateStream = fmt.Errorf("duplicate stream initiated")
|
||||
|
||||
// ErrReceiveWindowExceeded indicates the window was exceeded
|
||||
ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")
|
||||
|
||||
// ErrTimeout is used when we reach an IO deadline
|
||||
ErrTimeout = fmt.Errorf("i/o deadline reached")
|
||||
|
||||
// ErrStreamClosed is returned when using a closed stream
|
||||
ErrStreamClosed = fmt.Errorf("stream closed")
|
||||
|
||||
// ErrUnexpectedFlag is set when we get an unexpected flag
|
||||
ErrUnexpectedFlag = fmt.Errorf("unexpected flag")
|
||||
|
||||
// ErrRemoteGoAway is used when we get a go away from the other side
|
||||
ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections")
|
||||
|
||||
// ErrConnectionReset is sent if a stream is reset. This can happen
|
||||
// if the backlog is exceeded, or if there was a remote GoAway.
|
||||
ErrConnectionReset = fmt.Errorf("connection reset")
|
||||
|
||||
// ErrConnectionWriteTimeout indicates that we hit the "safety valve"
|
||||
// timeout writing to the underlying stream connection.
|
||||
ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout")
|
||||
|
||||
// ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
|
||||
ErrKeepAliveTimeout = fmt.Errorf("keepalive timeout")
|
||||
)
|
||||
|
||||
const (
|
||||
// protoVersion is the only version we support
|
||||
protoVersion uint8 = 0
|
||||
)
|
||||
|
||||
const (
|
||||
// Data is used for data frames. They are followed
|
||||
// by length bytes worth of payload.
|
||||
typeData uint8 = iota
|
||||
|
||||
// WindowUpdate is used to change the window of
|
||||
// a given stream. The length indicates the delta
|
||||
// update to the window.
|
||||
typeWindowUpdate
|
||||
|
||||
// Ping is sent as a keep-alive or to measure
|
||||
// the RTT. The StreamID and Length value are echoed
|
||||
// back in the response.
|
||||
typePing
|
||||
|
||||
// GoAway is sent to terminate a session. The StreamID
|
||||
// should be 0 and the length is an error code.
|
||||
typeGoAway
|
||||
)
|
||||
|
||||
const (
|
||||
// SYN is sent to signal a new stream. May
|
||||
// be sent with a data payload
|
||||
flagSYN uint16 = 1 << iota
|
||||
|
||||
// ACK is sent to acknowledge a new stream. May
|
||||
// be sent with a data payload
|
||||
flagACK
|
||||
|
||||
// FIN is sent to half-close the given stream.
|
||||
// May be sent with a data payload.
|
||||
flagFIN
|
||||
|
||||
// RST is used to hard close a given stream.
|
||||
flagRST
|
||||
)
|
||||
|
||||
const (
|
||||
// initialStreamWindow is the initial stream window size
|
||||
initialStreamWindow uint32 = 256 * 1024
|
||||
)
|
||||
|
||||
const (
|
||||
// goAwayNormal is sent on a normal termination
|
||||
goAwayNormal uint32 = iota
|
||||
|
||||
// goAwayProtoErr sent on a protocol error
|
||||
goAwayProtoErr
|
||||
|
||||
// goAwayInternalErr sent on an internal error
|
||||
goAwayInternalErr
|
||||
)
|
||||
|
||||
const (
|
||||
sizeOfVersion = 1
|
||||
sizeOfType = 1
|
||||
sizeOfFlags = 2
|
||||
sizeOfStreamID = 4
|
||||
sizeOfLength = 4
|
||||
headerSize = sizeOfVersion + sizeOfType + sizeOfFlags +
|
||||
sizeOfStreamID + sizeOfLength
|
||||
)
|
||||
|
||||
type header []byte
|
||||
|
||||
func (h header) Version() uint8 {
|
||||
return h[0]
|
||||
}
|
||||
|
||||
func (h header) MsgType() uint8 {
|
||||
return h[1]
|
||||
}
|
||||
|
||||
func (h header) Flags() uint16 {
|
||||
return binary.BigEndian.Uint16(h[2:4])
|
||||
}
|
||||
|
||||
func (h header) StreamID() uint32 {
|
||||
return binary.BigEndian.Uint32(h[4:8])
|
||||
}
|
||||
|
||||
func (h header) Length() uint32 {
|
||||
return binary.BigEndian.Uint32(h[8:12])
|
||||
}
|
||||
|
||||
func (h header) String() string {
|
||||
return fmt.Sprintf("Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d",
|
||||
h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length())
|
||||
}
|
||||
|
||||
func (h header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) {
|
||||
h[0] = protoVersion
|
||||
h[1] = msgType
|
||||
binary.BigEndian.PutUint16(h[2:4], flags)
|
||||
binary.BigEndian.PutUint32(h[4:8], streamID)
|
||||
binary.BigEndian.PutUint32(h[8:12], length)
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue