From 2d46bec18b3d9cee2ce3466231178d8207ccc8d7 Mon Sep 17 00:00:00 2001 From: Christine Dodrill Date: Tue, 29 Aug 2017 13:30:43 -0700 Subject: [PATCH] add logworker --- box.rb | 17 +- cmd/logworker/main.go | 89 +++ cmd/vyvanse/main.go | 102 ++- docker-compose.yml | 10 + internal/dao/logs.go | 81 +++ internal/dao/users.go | 1 - vendor-log | 8 + vendor/github.com/drone/mq/logger/logger.go | 61 ++ vendor/github.com/drone/mq/stomp/client.go | 259 ++++++++ vendor/github.com/drone/mq/stomp/conn.go | 156 +++++ vendor/github.com/drone/mq/stomp/const.go | 76 +++ vendor/github.com/drone/mq/stomp/context.go | 37 ++ .../drone/mq/stomp/dialer/dialer.go | 51 ++ vendor/github.com/drone/mq/stomp/handler.go | 13 + vendor/github.com/drone/mq/stomp/header.go | 109 ++++ vendor/github.com/drone/mq/stomp/message.go | 146 +++++ vendor/github.com/drone/mq/stomp/option.go | 96 +++ vendor/github.com/drone/mq/stomp/peer.go | 86 +++ vendor/github.com/drone/mq/stomp/reader.go | 139 +++++ vendor/github.com/drone/mq/stomp/writer.go | 173 ++++++ vendor/golang.org/x/net/websocket/client.go | 106 ++++ vendor/golang.org/x/net/websocket/dial.go | 24 + vendor/golang.org/x/net/websocket/hybi.go | 583 ++++++++++++++++++ vendor/golang.org/x/net/websocket/server.go | 113 ++++ .../golang.org/x/net/websocket/websocket.go | 448 ++++++++++++++ .../golang.org/x/sync/semaphore/semaphore.go | 131 ++++ .../api/support/bundler/bundler.go | 258 ++++++++ 27 files changed, 3349 insertions(+), 24 deletions(-) create mode 100644 cmd/logworker/main.go create mode 100644 internal/dao/logs.go create mode 100644 vendor/github.com/drone/mq/logger/logger.go create mode 100644 vendor/github.com/drone/mq/stomp/client.go create mode 100644 vendor/github.com/drone/mq/stomp/conn.go create mode 100644 vendor/github.com/drone/mq/stomp/const.go create mode 100644 vendor/github.com/drone/mq/stomp/context.go create mode 100644 vendor/github.com/drone/mq/stomp/dialer/dialer.go create mode 100644 vendor/github.com/drone/mq/stomp/handler.go create mode 100644 vendor/github.com/drone/mq/stomp/header.go create mode 100644 vendor/github.com/drone/mq/stomp/message.go create mode 100644 vendor/github.com/drone/mq/stomp/option.go create mode 100644 vendor/github.com/drone/mq/stomp/peer.go create mode 100644 vendor/github.com/drone/mq/stomp/reader.go create mode 100644 vendor/github.com/drone/mq/stomp/writer.go create mode 100644 vendor/golang.org/x/net/websocket/client.go create mode 100644 vendor/golang.org/x/net/websocket/dial.go create mode 100644 vendor/golang.org/x/net/websocket/hybi.go create mode 100644 vendor/golang.org/x/net/websocket/server.go create mode 100644 vendor/golang.org/x/net/websocket/websocket.go create mode 100644 vendor/golang.org/x/sync/semaphore/semaphore.go create mode 100644 vendor/google.golang.org/api/support/bundler/bundler.go diff --git a/box.rb b/box.rb index a563c88..9ec12a4 100644 --- a/box.rb +++ b/box.rb @@ -9,25 +9,32 @@ def foldercopy dir end def gobuild pkg - run "mkdir -p /root/go/bin && cd /root/go/bin && go#{$gover} build #{$repo}/#{pkg} && go#{$gover} install #{$repo}/#{pkg}" + run "mkdir -p /root/go/bin && cd /root/go/bin && go#{$gover} build -v #{$repo}/#{pkg}" end [ - "bot", - "cmd", - "internal", "vendor", "vendor-log", ].each { |x| foldercopy x } +[ + "bot", + "cmd", + "internal", +].each { |x| foldercopy x } + [ "cmd/vyvanse", + "cmd/logworker", ].each { |x| gobuild x } cmd "/root/go/bin/vyvanse" run "rm -rf $HOME/sdk /root/go/pkg ||:" run "apk del go#{$gover}" + +tag "xena/vyvanse:thick" + flatten -tag "xena/vyvanse" +tag "xena/vyvanse:latest" diff --git a/cmd/logworker/main.go b/cmd/logworker/main.go new file mode 100644 index 0000000..6c135d7 --- /dev/null +++ b/cmd/logworker/main.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "encoding/json" + + "git.xeserv.us/xena/gorqlite" + "git.xeserv.us/xena/vyvanse/internal/dao" + + "github.com/Xe/ln" + "github.com/bwmarrin/discordgo" + "github.com/drone/mq/stomp" + "github.com/namsral/flag" + opentracing "github.com/opentracing/opentracing-go" + zipkin "github.com/openzipkin/zipkin-go-opentracing" +) + +var ( + token = flag.String("token", "", "discord bot token") + zipkinURL = flag.String("zipkin-url", "", "URL for Zipkin traces") + databaseURL = flag.String("database-url", "http://", "URL for database (rqlite)") + mqURL = flag.String("mq-url", "tcp://mq:9000", "URL for STOMP server") +) + +func main() { + flag.Parse() + + if *zipkinURL != "" { + collector, err := zipkin.NewHTTPCollector(*zipkinURL) + if err != nil { + ln.FatalErr(context.Background(), err) + } + tracer, err := zipkin.NewTracer( + zipkin.NewRecorder(collector, false, "logworker:5000", "logworker")) + if err != nil { + ln.FatalErr(context.Background(), err) + } + + opentracing.SetGlobalTracer(tracer) + } + + ctx := context.Background() + + db, err := gorqlite.Open(*databaseURL) + if err != nil { + ln.FatalErr(ctx, err) + } + + ls := dao.NewLogs(db) + err = ls.Migrate(ctx) + if err != nil { + ln.FatalErr(ctx, err, ln.F{"action": "migrate logs table"}) + } + + mq, err := stomp.Dial(*mqURL) + if err != nil { + ln.FatalErr(ctx, err, ln.F{"url": *mqURL}) + } + + mq.Subscribe("/topic/message_create", stomp.HandlerFunc(func(m *stomp.Message) { + sp, ctx := opentracing.StartSpanFromContext(m.Context(), "logworker.topic.message.create") + defer sp.Finish() + + msg := &discordgo.Message{} + err := json.Unmarshal(m.Msg, msg) + if err != nil { + ln.Error(ctx, err, ln.F{"action": "can't unmarshal message body to a discordgo message"}) + return + } + + f := ln.F{ + "stomp_id": string(m.ID), + "channel_id": msg.ChannelID, + "message_id": msg.ID, + "message_author": msg.Author.ID, + "message_author_name": msg.Author.Username, + "message_author_is_bot": msg.Author.Bot, + } + + err = ls.Add(ctx, msg) + if err != nil { + ln.Error(ctx, err, f, ln.F{"action": "can't add discordgo message to the database"}) + } + })) + + for { + select {} + } +} diff --git a/cmd/vyvanse/main.go b/cmd/vyvanse/main.go index 45ef6fd..40d1ca6 100644 --- a/cmd/vyvanse/main.go +++ b/cmd/vyvanse/main.go @@ -13,6 +13,7 @@ import ( "github.com/Xe/ln" "github.com/bwmarrin/discordgo" + "github.com/drone/mq/stomp" _ "github.com/joho/godotenv/autoload" "github.com/namsral/flag" xkcd "github.com/nishanths/go-xkcd" @@ -29,11 +30,26 @@ var ( token = flag.String("token", "", "discord bot token") zipkinURL = flag.String("zipkin-url", "", "URL for Zipkin traces") databaseURL = flag.String("database-url", "http://", "URL for database (rqlite)") + mqURL = flag.String("mq-url", "tcp://mq:9000", "URL for STOMP server") ) func main() { flag.Parse() + if *zipkinURL != "" { + collector, err := zipkin.NewHTTPCollector(*zipkinURL) + if err != nil { + ln.FatalErr(context.Background(), err) + } + tracer, err := zipkin.NewTracer( + zipkin.NewRecorder(collector, false, "vyvanse:5000", "vyvanse")) + if err != nil { + ln.FatalErr(context.Background(), err) + } + + opentracing.SetGlobalTracer(tracer) + } + xk := xkcd.NewClient() dg, err := discordgo.New("Bot " + *token) if err != nil { @@ -62,6 +78,14 @@ func main() { } sp.Finish() + ctx = context.Background() + + mq, err := stomp.Dial(*mqURL) + if err != nil { + ln.FatalErr(ctx, err, ln.F{"url": *mqURL}) + } + _ = mq + c := cron.New() comic, err := xk.Latest() @@ -111,20 +135,6 @@ func main() { c.Start() - if *zipkinURL != "" { - collector, err := zipkin.NewHTTPCollector(*zipkinURL) - if err != nil { - ln.FatalErr(context.Background(), err) - } - tracer, err := zipkin.NewTracer( - zipkin.NewRecorder(collector, false, "vyvanse:5000", "vyvanse")) - if err != nil { - ln.FatalErr(context.Background(), err) - } - - opentracing.SetGlobalTracer(tracer) - } - cs := bot.NewCommandSet() cs.Prefix = ">" @@ -134,6 +144,61 @@ func main() { cs.AddCmd("splattus", "splatoon 2 map rotation status", bot.NoPermissions, spla2nMaps) cs.AddCmd("top10", "shows the top 10 chatters on this server", bot.NoPermissions, top10(us)) + dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) { + sp, ctx := opentracing.StartSpanFromContext(context.Background(), "message.create.post.stomp") + defer sp.Finish() + + f := ln.F{ + "channel_id": m.ChannelID, + "message_id": m.ID, + "message_author": m.Author.ID, + "message_author_name": m.Author.Username, + "message_author_is_bot": m.Author.Bot, + } + + err := mq.SendJSON("/topic/message_create", m.Message) + if err != nil { + if err.Error() == "EOF" { + mq, err = stomp.Dial(*mqURL) + if err != nil { + ln.Error(ctx, err, f, ln.F{"url": *mqURL, "action": "reconnect to mq"}) + return + } + + err = mq.SendJSON("/topic/message_create", m.Message) + if err != nil { + ln.Error(ctx, err, f, ln.F{"action": "retry message_create post to message queue"}) + return + } + + return + } + + ln.Error(ctx, err, f, ln.F{"action": "send created message to queue"}) + return + } + + ln.Log(ctx, f, ln.F{"action": "message_create"}) + }) + + dg.AddHandler(func(s *discordgo.Session, m *discordgo.GuildMemberAdd) { + sp, ctx := opentracing.StartSpanFromContext(context.Background(), "member.add.post.stomp") + defer sp.Finish() + + f := ln.F{ + "guild_id": m.GuildID, + "user_id": m.User.ID, + "user_name": m.User.Username, + } + + err := mq.SendJSON("/topic/member_add", m.Member) + if err != nil { + ln.Error(ctx, err, f, ln.F{"action": "send added member to queue"}) + } + + ln.Log(ctx, f, ln.F{"action": "member_add"}) + }) + dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) { if m.Author.ID == s.State.User.ID { return @@ -165,8 +230,9 @@ func main() { return } - fmt.Println("Bot is now running. Press CTRL-C to exit.") - // Simple way to keep program running until CTRL-C is pressed. - <-make(chan struct{}) - return + ln.Log(ctx, ln.F{"action": "bot is running"}) + + for { + select {} + } } diff --git a/docker-compose.yml b/docker-compose.yml index 4279ee5..99d2d2a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -33,5 +33,15 @@ services: - mq - rqlite + logworker: + restart: always + image: xena/vyvanse + env_file: ./.env + depends_on: + - zipkin + - mq + - rqlite + command: /root/go/bin/logworker + volumes: rqlite: diff --git a/internal/dao/logs.go b/internal/dao/logs.go new file mode 100644 index 0000000..92b9562 --- /dev/null +++ b/internal/dao/logs.go @@ -0,0 +1,81 @@ +package dao + +import ( + "context" + + "git.xeserv.us/xena/gorqlite" + "github.com/Xe/ln" + "github.com/bwmarrin/discordgo" + opentracing "github.com/opentracing/opentracing-go" + splog "github.com/opentracing/opentracing-go/log" + "google.golang.org/api/support/bundler" +) + +type Logs struct { + conn gorqlite.Connection + bdl *bundler.Bundler +} + +func NewLogs(conn gorqlite.Connection) *Logs { + l := &Logs{conn: conn} + l.bdl = bundler.NewBundler("string", l.writeLines) + + return l +} + +func (l *Logs) writeLines(sqlLines interface{}) { + sp, ctx := opentracing.StartSpanFromContext(context.Background(), "logs.write.lines") + defer sp.Finish() + + _, err := l.conn.Write(sqlLines.([]string)) + if err != nil { + ln.Error(ctx, err, ln.F{"action": "write lines that are batched"}) + } +} + +func (l *Logs) Migrate(ctx context.Context) error { + sp, ctx := opentracing.StartSpanFromContext(ctx, "logs.migrate") + defer sp.Finish() + + migrationDDL := []string{ + `CREATE TABLE IF NOT EXISTS logs(id INTEGER PRIMARY KEY, discord_id TEXT UNIQUE, channel_id TEXT, content TEXT, timestamp INTEGER, mention_everyone INTEGER, author_id TEXT, author_username TEXT)`, + } + + res, err := l.conn.Write(migrationDDL) + if err != nil { + sp.LogFields(splog.Error(err)) + } + + for i, re := range res { + if re.Err != nil { + sp.LogFields(splog.Error(re.Err)) + return re.Err + } + + sp.LogFields(splog.Int("migration.step", i), splog.Float64("timing", re.Timing), splog.Int64("rows.affected", re.RowsAffected)) + } + + return nil +} + +func (l *Logs) Add(ctx context.Context, m *discordgo.Message) error { + sp, ctx := opentracing.StartSpanFromContext(ctx, "logs.add") + defer sp.Finish() + + stmt := gorqlite.NewPreparedStatement(`INSERT INTO logs (discord_id, channel_id, content, timestamp, mention_everyone, author_id, author_username) VALUES (%s, %s, %s, %d, %d, %s, %s)`) + ts, err := m.Timestamp.Parse() + if err != nil { + return err + } + + var me int + if m.MentionEveryone { + me = 1 + } + + bd := stmt.Bind(m.ID, m.ChannelID, m.Content, ts, me, m.Author.ID, m.Author.Username) + + l.bdl.Add(bd, len(bd)) + + return nil +} diff --git a/internal/dao/users.go b/internal/dao/users.go index 9d3fbc9..f573854 100644 --- a/internal/dao/users.go +++ b/internal/dao/users.go @@ -28,7 +28,6 @@ func (u *Users) Migrate(ctx context.Context) error { res, err := u.conn.Write(migrationDDL) if err != nil { sp.LogFields(splog.Error(err)) - return err } for i, re := range res { diff --git a/vendor-log b/vendor-log index 38f2061..45e643b 100644 --- a/vendor-log +++ b/vendor-log @@ -42,3 +42,11 @@ ae77be60afb1dcacde03767a8c37337fad28ac14 github.com/kardianos/osext 40a5e952d22c3ef520c6ab7bdb9b1a010ec9a524 git.xeserv.us/xena/gorqlite 97311d9f7767e3d6f422ea06661bc2c7a19e8a5d github.com/mattn/go-runewidth be5337e7b39e64e5f91445ce7e721888dbab7387 github.com/olekukonko/tablewriter +280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/logger +280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/stomp +280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/stomp/dialer +f5079bd7f6f74e23c4d65efa0f4ce14cbd6a3c0f golang.org/x/net/context +f5079bd7f6f74e23c4d65efa0f4ce14cbd6a3c0f golang.org/x/net/websocket +66aacef3dd8a676686c7ae3716979581e8b03c47 golang.org/x/net/context +f52d1811a62927559de87708c8913c1650ce4f26 golang.org/x/sync/semaphore +e0e0e6e500066ff47335c7717e2a090ad127adec google.golang.org/api/support/bundler diff --git a/vendor/github.com/drone/mq/logger/logger.go b/vendor/github.com/drone/mq/logger/logger.go new file mode 100644 index 0000000..1677864 --- /dev/null +++ b/vendor/github.com/drone/mq/logger/logger.go @@ -0,0 +1,61 @@ +package logger + +var std Logger = new(none) + +// Debugf writes a debug message to the standard logger. +func Debugf(format string, args ...interface{}) { + std.Debugf(format, args...) +} + +// Verbosef writes a verbose message to the standard logger. +func Verbosef(format string, args ...interface{}) { + std.Verbosef(format, args...) +} + +// Noticef writes a notice message to the standard logger. +func Noticef(format string, args ...interface{}) { + std.Noticef(format, args...) +} + +// Warningf writes a warning message to the standard logger. +func Warningf(format string, args ...interface{}) { + std.Warningf(format, args...) +} + +// Printf writes a default message to the standard logger. +func Printf(format string, args ...interface{}) { + std.Printf(format, args...) +} + +// SetLogger sets the standard logger. +func SetLogger(logger Logger) { + std = logger +} + +// Logger represents a logger. +type Logger interface { + + // Debugf writes a debug message. + Debugf(string, ...interface{}) + + // Verbosef writes a verbose message. + Verbosef(string, ...interface{}) + + // Noticef writes a notice message. + Noticef(string, ...interface{}) + + // Warningf writes a warning message. + Warningf(string, ...interface{}) + + // Printf writes a default message. + Printf(string, ...interface{}) +} + +// none is a logger that silently ignores all writes. +type none struct{} + +func (*none) Debugf(string, ...interface{}) {} +func (*none) Verbosef(string, ...interface{}) {} +func (*none) Noticef(string, ...interface{}) {} +func (*none) Warningf(string, ...interface{}) {} +func (*none) Printf(string, ...interface{}) {} diff --git a/vendor/github.com/drone/mq/stomp/client.go b/vendor/github.com/drone/mq/stomp/client.go new file mode 100644 index 0000000..8794d3e --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/client.go @@ -0,0 +1,259 @@ +package stomp + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "runtime/debug" + "strconv" + "sync" + "time" + + "github.com/drone/mq/logger" + "github.com/drone/mq/stomp/dialer" +) + +// Client defines a client connection to a STOMP server. +type Client struct { + mu sync.Mutex + + peer Peer + subs map[string]Handler + wait map[string]chan struct{} + done chan error + + seq int64 + + skipVerify bool + readBufferSize int + writeBufferSize int + timeout time.Duration +} + +// New returns a new STOMP client using the given connection. +func New(peer Peer) *Client { + return &Client{ + peer: peer, + subs: make(map[string]Handler), + wait: make(map[string]chan struct{}), + done: make(chan error, 1), + } +} + +// Dial creates a client connection to the given target. +func Dial(target string) (*Client, error) { + conn, err := dialer.Dial(target) + if err != nil { + return nil, err + } + return New(Conn(conn)), nil +} + +// Send sends the data to the given destination. +func (c *Client) Send(dest string, data []byte, opts ...MessageOption) error { + m := NewMessage() + m.Method = MethodSend + m.Dest = []byte(dest) + m.Body = data + m.Apply(opts...) + return c.sendMessage(m) +} + +// SendJSON sends the JSON encoding of v to the given destination. +func (c *Client) SendJSON(dest string, v interface{}, opts ...MessageOption) error { + data, err := json.Marshal(v) + if err != nil { + return err + } + opts = append(opts, + WithHeader("content-type", "application/json"), + ) + return c.Send(dest, data, opts...) +} + +// Subscribe subscribes to the given destination. +func (c *Client) Subscribe(dest string, handler Handler, opts ...MessageOption) (id []byte, err error) { + id = c.incr() + + m := NewMessage() + m.Method = MethodSubscribe + m.ID = id + m.Dest = []byte(dest) + m.Apply(opts...) + + c.mu.Lock() + c.subs[string(id)] = handler + c.mu.Unlock() + + err = c.sendMessage(m) + if err != nil { + c.mu.Lock() + delete(c.subs, string(id)) + c.mu.Unlock() + return + } + return +} + +// Unsubscribe unsubscribes to the destination. +func (c *Client) Unsubscribe(id []byte, opts ...MessageOption) error { + c.mu.Lock() + delete(c.subs, string(id)) + c.mu.Unlock() + + m := NewMessage() + m.Method = MethodUnsubscribe + m.ID = id + m.Apply(opts...) + + return c.sendMessage(m) +} + +// Ack acknowledges the messages with the given id. +func (c *Client) Ack(id []byte, opts ...MessageOption) error { + m := NewMessage() + m.Method = MethodAck + m.ID = id + m.Apply(opts...) + + return c.sendMessage(m) +} + +// Nack negative-acknowledges the messages with the given id. +func (c *Client) Nack(id []byte, opts ...MessageOption) error { + m := NewMessage() + m.Method = MethodNack + m.ID = id + m.Apply(opts...) + + return c.peer.Send(m) +} + +// Connect opens the connection and establishes the session. +func (c *Client) Connect(opts ...MessageOption) error { + m := NewMessage() + m.Proto = STOMP + m.Method = MethodStomp + m.Apply(opts...) + if err := c.sendMessage(m); err != nil { + return err + } + + m, ok := <-c.peer.Receive() + if !ok { + return io.EOF + } + defer m.Release() + + if !bytes.Equal(m.Method, MethodConnected) { + return fmt.Errorf("stomp: inbound message: unexpected method, want connected") + } + go c.listen() + return nil +} + +// Disconnect terminates the session and closes the connection. +func (c *Client) Disconnect() error { + m := NewMessage() + m.Method = MethodDisconnect + c.sendMessage(m) + return c.peer.Close() +} + +// Done returns a channel +func (c *Client) Done() <-chan error { + return c.done +} + +func (c *Client) incr() []byte { + c.mu.Lock() + i := c.seq + c.seq++ + c.mu.Unlock() + return strconv.AppendInt(nil, i, 10) +} + +func (c *Client) listen() { + defer func() { + if r := recover(); r != nil { + logger.Warningf("stomp client: recover panic: %s", r) + err, ok := r.(error) + if !ok { + logger.Warningf("%v: %s", r, debug.Stack()) + c.done <- fmt.Errorf("%v", r) + } else { + logger.Warningf("%s", err) + c.done <- err + } + } + }() + + for { + m, ok := <-c.peer.Receive() + if !ok { + c.done <- io.EOF + return + } + + switch { + case bytes.Equal(m.Method, MethodMessage): + c.handleMessage(m) + case bytes.Equal(m.Method, MethodRecipet): + c.handleReceipt(m) + default: + logger.Noticef("stomp client: unknown message type: %s", + string(m.Method), + ) + } + } +} + +func (c *Client) handleReceipt(m *Message) { + c.mu.Lock() + receiptc, ok := c.wait[string(m.Receipt)] + c.mu.Unlock() + if !ok { + logger.Noticef("stomp client: unknown read receipt: %s", + string(m.Receipt), + ) + return + } + receiptc <- struct{}{} +} + +func (c *Client) handleMessage(m *Message) { + c.mu.Lock() + handler, ok := c.subs[string(m.Subs)] + c.mu.Unlock() + if !ok { + logger.Noticef("stomp client: subscription not found: %s", + string(m.Subs), + ) + return + } + handler.Handle(m) +} + +func (c *Client) sendMessage(m *Message) error { + if len(m.Receipt) == 0 { + return c.peer.Send(m) + } + + receiptc := make(chan struct{}, 1) + c.wait[string(m.Receipt)] = receiptc + + defer func() { + delete(c.wait, string(m.Receipt)) + }() + + err := c.peer.Send(m) + if err != nil { + return err + } + + select { + case <-receiptc: + return nil + } +} diff --git a/vendor/github.com/drone/mq/stomp/conn.go b/vendor/github.com/drone/mq/stomp/conn.go new file mode 100644 index 0000000..a7ed92c --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/conn.go @@ -0,0 +1,156 @@ +package stomp + +import ( + "bufio" + "io" + "net" + "time" + + "github.com/drone/mq/logger" +) + +const ( + bufferSize = 32 << 10 // default buffer size 32KB + bufferLimit = 32 << 15 // default buffer limit 1MB +) + +var ( + never time.Time + deadline = time.Second * 5 + + heartbeatTime = time.Second * 30 + heartbeatWait = time.Second * 60 +) + +type connPeer struct { + conn net.Conn + done chan bool + + reader *bufio.Reader + writer *bufio.Writer + incoming chan *Message + outgoing chan *Message +} + +// Conn creates a network-connected peer that reads and writes +// messages using net.Conn c. +func Conn(c net.Conn) Peer { + p := &connPeer{ + reader: bufio.NewReaderSize(c, bufferSize), + writer: bufio.NewWriterSize(c, bufferSize), + incoming: make(chan *Message), + outgoing: make(chan *Message), + done: make(chan bool), + conn: c, + } + + go p.readInto(p.incoming) + go p.writeFrom(p.outgoing) + return p +} + +func (c *connPeer) Receive() <-chan *Message { + return c.incoming +} + +func (c *connPeer) Send(message *Message) error { + select { + case <-c.done: + return io.EOF + default: + c.outgoing <- message + return nil + } +} + +func (c *connPeer) Addr() string { + return c.conn.RemoteAddr().String() +} + +func (c *connPeer) Close() error { + return c.close() +} + +func (c *connPeer) close() error { + select { + case <-c.done: + return io.EOF + default: + close(c.done) + close(c.incoming) + close(c.outgoing) + return nil + } +} + +func (c *connPeer) readInto(messages chan<- *Message) { + defer c.close() + + for { + // lim := io.LimitReader(c.conn, bufferLimit) + // buf := bufio.NewReaderSize(lim, bufferSize) + + buf, err := c.reader.ReadBytes(0) + if err != nil { + break + } + if len(buf) == 1 { + c.conn.SetReadDeadline(time.Now().Add(heartbeatWait)) + logger.Verbosef("stomp: received heart-beat") + continue + } + + msg := NewMessage() + msg.Parse(buf[:len(buf)-1]) + + select { + case <-c.done: + break + default: + messages <- msg + } + } +} + +func (c *connPeer) writeFrom(messages <-chan *Message) { + tick := time.NewTicker(time.Millisecond * 100).C + heartbeat := time.NewTicker(heartbeatTime).C + +loop: + for { + select { + case <-c.done: + break loop + case <-heartbeat: + logger.Verbosef("stomp: send heart-beat.") + c.writer.WriteByte(0) + case <-tick: + c.conn.SetWriteDeadline(time.Now().Add(deadline)) + if err := c.writer.Flush(); err != nil { + break loop + } + c.conn.SetWriteDeadline(never) + case msg, ok := <-messages: + if !ok { + break loop + } + writeTo(c.writer, msg) + c.writer.WriteByte(0) + msg.Release() + } + } + + c.drain() +} + +func (c *connPeer) drain() error { + c.conn.SetWriteDeadline(time.Now().Add(deadline)) + for msg := range c.outgoing { + writeTo(c.writer, msg) + c.writer.WriteByte(0) + msg.Release() + } + c.conn.SetWriteDeadline(never) + c.writer.Flush() + return c.conn.Close() +} diff --git a/vendor/github.com/drone/mq/stomp/const.go b/vendor/github.com/drone/mq/stomp/const.go new file mode 100644 index 0000000..55748ad --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/const.go @@ -0,0 +1,76 @@ +package stomp + +// STOMP protocol version. +var STOMP = []byte("1.2") + +// STOMP protocol methods. +var ( + MethodStomp = []byte("STOMP") + MethodConnect = []byte("CONNECT") + MethodConnected = []byte("CONNECTED") + MethodSend = []byte("SEND") + MethodSubscribe = []byte("SUBSCRIBE") + MethodUnsubscribe = []byte("UNSUBSCRIBE") + MethodAck = []byte("ACK") + MethodNack = []byte("NACK") + MethodDisconnect = []byte("DISCONNECT") + MethodMessage = []byte("MESSAGE") + MethodRecipet = []byte("RECEIPT") + MethodError = []byte("ERROR") +) + +// STOMP protocol headers. +var ( + HeaderAccept = []byte("accept-version") + HeaderAck = []byte("ack") + HeaderExpires = []byte("expires") + HeaderDest = []byte("destination") + HeaderHost = []byte("host") + HeaderLogin = []byte("login") + HeaderPass = []byte("passcode") + HeaderID = []byte("id") + HeaderMessageID = []byte("message-id") + HeaderPersist = []byte("persist") + HeaderPrefetch = []byte("prefetch-count") + HeaderReceipt = []byte("receipt") + HeaderReceiptID = []byte("receipt-id") + HeaderRetain = []byte("retain") + HeaderSelector = []byte("selector") + HeaderServer = []byte("server") + HeaderSession = []byte("session") + HeaderSubscription = []byte("subscription") + HeaderVersion = []byte("version") +) + +// Common STOMP header values. +var ( + AckAuto = []byte("auto") + AckClient = []byte("client") + PersistTrue = []byte("true") + RetainTrue = []byte("true") + RetainLast = []byte("last") + RetainAll = []byte("all") + RetainRemove = []byte("remove") +) + +var headerLookup = map[string]struct{}{ + "accept-version": struct{}{}, + "ack": struct{}{}, + "expires": struct{}{}, + "destination": struct{}{}, + "host": struct{}{}, + "login": struct{}{}, + "passcode": struct{}{}, + "id": struct{}{}, + "message-id": struct{}{}, + "persist": struct{}{}, + "prefetch-count": struct{}{}, + "receipt": struct{}{}, + "receipt-id": struct{}{}, + "retain": struct{}{}, + "selector": struct{}{}, + "server": struct{}{}, + "session": struct{}{}, + "subscription": struct{}{}, + "version": struct{}{}, +} diff --git a/vendor/github.com/drone/mq/stomp/context.go b/vendor/github.com/drone/mq/stomp/context.go new file mode 100644 index 0000000..d2218db --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/context.go @@ -0,0 +1,37 @@ +package stomp + +import "golang.org/x/net/context" + +const clientKey = "stomp.client" + +// NewContext adds the client to the context. +func (c *Client) NewContext(ctx context.Context, client *Client) context.Context { + // HACK for use with gin and echo + if s, ok := ctx.(setter); ok { + s.Set(clientKey, clientKey) + return ctx + } + return context.WithValue(ctx, clientKey, client) +} + +// FromContext retrieves the client from context +func FromContext(ctx context.Context) (*Client, bool) { + client, ok := ctx.Value(clientKey).(*Client) + return client, ok +} + +// MustFromContext retrieves the client from context. Panics if not found +func MustFromContext(ctx context.Context) *Client { + client, ok := FromContext(ctx) + if !ok { + panic("stomp.Client not found in context") + } + return client +} + +// HACK setter defines a context that enables setting values. This is a +// temporary workaround for use with gin and echo and will eventually +// be removed. DO NOT depend on this. +type setter interface { + Set(string, interface{}) +} diff --git a/vendor/github.com/drone/mq/stomp/dialer/dialer.go b/vendor/github.com/drone/mq/stomp/dialer/dialer.go new file mode 100644 index 0000000..899ad09 --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/dialer/dialer.go @@ -0,0 +1,51 @@ +package dialer + +import ( + "net" + "net/url" + + "golang.org/x/net/websocket" +) + +const ( + protoHTTP = "http" + protoHTTPS = "https" + protoWS = "ws" + protoWSS = "wss" + protoTCP = "tcp" +) + +// Dial creates a client connection to the given target. +func Dial(target string) (net.Conn, error) { + u, err := url.Parse(target) + if err != nil { + return nil, err + } + + switch u.Scheme { + case protoHTTP, protoHTTPS, protoWS, protoWSS: + return dialWebsocket(u) + case protoTCP: + return dialSocket(u) + default: + panic("stomp: invalid protocol") + } +} + +func dialWebsocket(target *url.URL) (net.Conn, error) { + origin, err := target.Parse("/") + if err != nil { + return nil, err + } + switch origin.Scheme { + case protoWS: + origin.Scheme = protoHTTP + case protoWSS: + origin.Scheme = protoHTTPS + } + return websocket.Dial(target.String(), "", origin.String()) +} + +func dialSocket(target *url.URL) (net.Conn, error) { + return net.Dial(protoTCP, target.Host) +} diff --git a/vendor/github.com/drone/mq/stomp/handler.go b/vendor/github.com/drone/mq/stomp/handler.go new file mode 100644 index 0000000..44c5041 --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/handler.go @@ -0,0 +1,13 @@ +package stomp + +// Handler handles a STOMP message. +type Handler interface { + Handle(*Message) +} + +// The HandlerFunc type is an adapter to allow the use of an ordinary +// function as a STOMP message handler. +type HandlerFunc func(*Message) + +// Handle calls f(m). +func (f HandlerFunc) Handle(m *Message) { f(m) } diff --git a/vendor/github.com/drone/mq/stomp/header.go b/vendor/github.com/drone/mq/stomp/header.go new file mode 100644 index 0000000..ac83a04 --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/header.go @@ -0,0 +1,109 @@ +package stomp + +import ( + "bytes" + "strconv" +) + +const defaultHeaderLen = 5 + +type item struct { + name []byte + data []byte +} + +// Header represents the header section of the STOMP message. +type Header struct { + items []item + itemc int +} + +func newHeader() *Header { + return &Header{ + items: make([]item, defaultHeaderLen), + } +} + +// Get returns the named header value. +func (h *Header) Get(name []byte) (b []byte) { + for i := 0; i < h.itemc; i++ { + if v := h.items[i]; bytes.Equal(v.name, name) { + return v.data + } + } + return +} + +// GetString returns the named header value. +func (h *Header) GetString(name string) string { + k := []byte(name) + v := h.Get(k) + return string(v) +} + +// GetBool returns the named header value. +func (h *Header) GetBool(name string) bool { + s := h.GetString(name) + b, _ := strconv.ParseBool(s) + return b +} + +// GetInt returns the named header value. +func (h *Header) GetInt(name string) int { + s := h.GetString(name) + i, _ := strconv.Atoi(s) + return i +} + +// GetInt64 returns the named header value. +func (h *Header) GetInt64(name string) int64 { + s := h.GetString(name) + i, _ := strconv.ParseInt(s, 10, 64) + return i +} + +// Field returns the named header value in string format. This is used to +// provide compatibility with the SQL expression evaluation package. +func (h *Header) Field(name []byte) []byte { + return h.Get(name) +} + +// Add appens the key value pair to the header. +func (h *Header) Add(name, data []byte) { + h.grow() + h.items[h.itemc].name = name + h.items[h.itemc].data = data + h.itemc++ +} + +// Index returns the keypair at index i. +func (h *Header) Index(i int) (k, v []byte) { + if i > h.itemc { + return + } + k = h.items[i].name + v = h.items[i].data + return +} + +// Len returns the header length. +func (h *Header) Len() int { + return h.itemc +} + +func (h *Header) grow() { + if h.itemc > defaultHeaderLen-1 { + h.items = append(h.items, item{}) + } +} + +func (h *Header) reset() { + h.itemc = 0 + h.items = h.items[:defaultHeaderLen] + for i := range h.items { + h.items[i].name = zeroBytes + h.items[i].data = zeroBytes + } +} + +var zeroBytes []byte diff --git a/vendor/github.com/drone/mq/stomp/message.go b/vendor/github.com/drone/mq/stomp/message.go new file mode 100644 index 0000000..68118e7 --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/message.go @@ -0,0 +1,146 @@ +package stomp + +import ( + "bytes" + "encoding/json" + "math/rand" + "strconv" + "sync" + + "golang.org/x/net/context" +) + +// Message represents a parsed STOMP message. +type Message struct { + ID []byte // id header + Proto []byte // stomp version + Method []byte // stomp method + User []byte // username header + Pass []byte // password header + Dest []byte // destination header + Subs []byte // subscription id + Ack []byte // ack id + Msg []byte // message-id header + Persist []byte // persist header + Retain []byte // retain header + Prefetch []byte // prefetch count + Expires []byte // expires header + Receipt []byte // receipt header + Selector []byte // selector header + Body []byte + Header *Header // custom headers + + ctx context.Context +} + +// Copy returns a copy of the Message. +func (m *Message) Copy() *Message { + c := NewMessage() + c.ID = m.ID + c.Proto = m.Proto + c.Method = m.Method + c.User = m.User + c.Pass = m.Pass + c.Dest = m.Dest + c.Subs = m.Subs + c.Ack = m.Ack + c.Prefetch = m.Prefetch + c.Selector = m.Selector + c.Persist = m.Persist + c.Retain = m.Retain + c.Receipt = m.Receipt + c.Expires = m.Expires + c.Body = m.Body + c.ctx = m.ctx + c.Header.itemc = m.Header.itemc + copy(c.Header.items, m.Header.items) + return c +} + +// Apply applies the options to the message. +func (m *Message) Apply(opts ...MessageOption) { + for _, opt := range opts { + opt(m) + } +} + +// Parse parses the raw bytes into the message. +func (m *Message) Parse(b []byte) error { + return read(b, m) +} + +// Bytes returns the Message in raw byte format. +func (m *Message) Bytes() []byte { + var buf bytes.Buffer + writeTo(&buf, m) + return buf.Bytes() +} + +// String returns the Message in string format. +func (m *Message) String() string { + return string(m.Bytes()) +} + +// Release releases the message back to the message pool. +func (m *Message) Release() { + m.Reset() + pool.Put(m) +} + +// Reset resets the meesage fields to their zero values. +func (m *Message) Reset() { + m.ID = m.ID[:0] + m.Proto = m.Proto[:0] + m.Method = m.Method[:0] + m.User = m.User[:0] + m.Pass = m.Pass[:0] + m.Dest = m.Dest[:0] + m.Subs = m.Subs[:0] + m.Ack = m.Ack[:0] + m.Prefetch = m.Prefetch[:0] + m.Selector = m.Selector[:0] + m.Persist = m.Persist[:0] + m.Retain = m.Retain[:0] + m.Receipt = m.Receipt[:0] + m.Expires = m.Expires[:0] + m.Body = m.Body[:0] + m.ctx = nil + m.Header.reset() +} + +// Context returns the request's context. +func (m *Message) Context() context.Context { + if m.ctx != nil { + return m.ctx + } + return context.Background() +} + +// WithContext returns a shallow copy of m with its context changed +// to ctx. The provided ctx must be non-nil. +func (m *Message) WithContext(ctx context.Context) *Message { + c := m.Copy() + c.ctx = ctx + return c +} + +// Unmarshal parses the JSON-encoded body of the message and +// stores the result in the value pointed to by v. +func (m *Message) Unmarshal(v interface{}) error { + return json.Unmarshal(m.Body, v) +} + +// NewMessage returns an empty message from the message pool. +func NewMessage() *Message { + return pool.Get().(*Message) +} + +var pool = sync.Pool{New: func() interface{} { + return &Message{Header: newHeader()} +}} + +// Rand returns a random int64 number as a []byte of +// ascii characters. +func Rand() []byte { + return strconv.AppendInt(nil, rand.Int63(), 10) +} diff --git a/vendor/github.com/drone/mq/stomp/option.go b/vendor/github.com/drone/mq/stomp/option.go new file mode 100644 index 0000000..a82c903 --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/option.go @@ -0,0 +1,96 @@ +package stomp + +import ( + "math/rand" + "strconv" + "strings" +) + +// MessageOption configures message options. +type MessageOption func(*Message) + +// WithCredentials returns a MessageOption which sets credentials. +func WithCredentials(username, password string) MessageOption { + return func(m *Message) { + m.User = []byte(username) + m.Pass = []byte(password) + } +} + +// WithHeader returns a MessageOption which sets a header. +func WithHeader(key, value string) MessageOption { + return func(m *Message) { + _, ok := headerLookup[strings.ToLower(key)] + if !ok { + m.Header.Add( + []byte(key), + []byte(value), + ) + } + } +} + +// WithHeaders returns a MessageOption which sets headers. +func WithHeaders(headers map[string]string) MessageOption { + return func(m *Message) { + for key, value := range headers { + _, ok := headerLookup[strings.ToLower(key)] + if !ok { + m.Header.Add( + []byte(key), + []byte(value), + ) + } + } + } +} + +// WithExpires returns a MessageOption configured with an expiration. +func WithExpires(exp int64) MessageOption { + return func(m *Message) { + m.Expires = strconv.AppendInt(nil, exp, 10) + } +} + +// WithPrefetch returns a MessageOption configured with a prefetch count. +func WithPrefetch(prefetch int) MessageOption { + return func(m *Message) { + m.Prefetch = strconv.AppendInt(nil, int64(prefetch), 10) + } +} + +// WithReceipt returns a MessageOption configured with a receipt request. +func WithReceipt() MessageOption { + return func(m *Message) { + m.Receipt = strconv.AppendInt(nil, rand.Int63(), 10) + } +} + +// WithPersistence returns a MessageOption configured to persist. +func WithPersistence() MessageOption { + return func(m *Message) { + m.Persist = PersistTrue + } +} + +// WithRetain returns a MessageOption configured to retain the message. +func WithRetain(retain string) MessageOption { + return func(m *Message) { + m.Retain = []byte(retain) + } +} + +// WithSelector returns a MessageOption configured to filter messages +// using a sql-like evaluation string. +func WithSelector(selector string) MessageOption { + return func(m *Message) { + m.Selector = []byte(selector) + } +} + +// WithAck returns a MessageOption configured with an ack policy. +func WithAck(ack string) MessageOption { + return func(m *Message) { + m.Ack = []byte(ack) + } +} diff --git a/vendor/github.com/drone/mq/stomp/peer.go b/vendor/github.com/drone/mq/stomp/peer.go new file mode 100644 index 0000000..fd07a1f --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/peer.go @@ -0,0 +1,86 @@ +package stomp + +import ( + "io" + "net" + "sync" +) + +// Peer defines a peer-to-peer connection. +type Peer interface { + // Send sends a message. + Send(*Message) error + + // Receive returns a channel of inbound messages. + Receive() <-chan *Message + + // Close closes the connection. + Close() error + + // Addr returns the peer address. + Addr() string +} + +// Pipe creates a synchronous in-memory pipe, where reads on one end are +// matched with writes on the other. This is useful for direct, in-memory +// client-server communication. +func Pipe() (Peer, Peer) { + atob := make(chan *Message, 10) + btoa := make(chan *Message, 10) + + a := &localPeer{ + incoming: btoa, + outgoing: atob, + finished: make(chan bool), + } + b := &localPeer{ + incoming: atob, + outgoing: btoa, + finished: make(chan bool), + } + + return a, b +} + +type localPeer struct { + finished chan bool + outgoing chan<- *Message + incoming <-chan *Message +} + +func (p *localPeer) Receive() <-chan *Message { + return p.incoming +} + +func (p *localPeer) Send(m *Message) error { + select { + case <-p.finished: + return io.EOF + default: + p.outgoing <- m + return nil + } +} + +func (p *localPeer) Close() error { + close(p.finished) + close(p.outgoing) + return nil +} + +func (p *localPeer) Addr() string { + peerAddrOnce.Do(func() { + // get the local address list + addr, _ := net.InterfaceAddrs() + if len(addr) != 0 { + // use the last address in the list + peerAddr = addr[len(addr)-1].String() + } + }) + return peerAddr +} + +var peerAddrOnce sync.Once + +// default address displayed for local pipes +var peerAddr = "127.0.0.1/8" diff --git a/vendor/github.com/drone/mq/stomp/reader.go b/vendor/github.com/drone/mq/stomp/reader.go new file mode 100644 index 0000000..ec88d55 --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/reader.go @@ -0,0 +1,139 @@ +package stomp + +import ( + "bytes" + "fmt" +) + +func read(input []byte, m *Message) (err error) { + var ( + pos int + off int + tot = len(input) + ) + + // parse the stomp message + for ; ; off++ { + if off == tot { + return fmt.Errorf("stomp: invalid method") + } + if input[off] == '\n' { + m.Method = input[pos:off] + off++ + pos = off + break + } + } + + // parse the stomp headers + for { + if off == tot { + return fmt.Errorf("stomp: unexpected eof") + } + if input[off] == '\n' { + off++ + pos = off + break + } + + var ( + name []byte + value []byte + ) + + loop: + // parse each individual header + for ; ; off++ { + if off >= tot { + return fmt.Errorf("stomp: unexpected eof") + } + + switch input[off] { + case '\n': + value = input[pos:off] + off++ + pos = off + break loop + case ':': + name = input[pos:off] + off++ + pos = off + } + } + + switch { + case bytes.Equal(name, HeaderAccept): + m.Proto = value + case bytes.Equal(name, HeaderAck): + m.Ack = value + case bytes.Equal(name, HeaderDest): + m.Dest = value + case bytes.Equal(name, HeaderExpires): + m.Expires = value + case bytes.Equal(name, HeaderLogin): + m.User = value + case bytes.Equal(name, HeaderPass): + m.Pass = value + case bytes.Equal(name, HeaderID): + m.ID = value + case bytes.Equal(name, HeaderMessageID): + m.ID = value + case bytes.Equal(name, HeaderPersist): + m.Persist = value + case bytes.Equal(name, HeaderPrefetch): + m.Prefetch = value + case bytes.Equal(name, HeaderReceipt): + m.Receipt = value + case bytes.Equal(name, HeaderReceiptID): + m.Receipt = value + case bytes.Equal(name, HeaderRetain): + m.Retain = value + case bytes.Equal(name, HeaderSelector): + m.Selector = value + case bytes.Equal(name, HeaderSubscription): + m.Subs = value + case bytes.Equal(name, HeaderVersion): + m.Proto = value + default: + m.Header.Add(name, value) + } + } + + if tot > pos { + m.Body = input[pos:] + } + return +} + +const ( + asciiZero = 48 + asciiNine = 57 +) + +// ParseInt returns the ascii integer value. +func ParseInt(d []byte) (n int) { + if len(d) == 0 { + return 0 + } + for _, dec := range d { + if dec < asciiZero || dec > asciiNine { + return 0 + } + n = n*10 + (int(dec) - asciiZero) + } + return n +} + +// ParseInt64 returns the ascii integer value. +func ParseInt64(d []byte) (n int64) { + if len(d) == 0 { + return 0 + } + for _, dec := range d { + if dec < asciiZero || dec > asciiNine { + return 0 + } + n = n*10 + (int64(dec) - asciiZero) + } + return n +} diff --git a/vendor/github.com/drone/mq/stomp/writer.go b/vendor/github.com/drone/mq/stomp/writer.go new file mode 100644 index 0000000..ca09324 --- /dev/null +++ b/vendor/github.com/drone/mq/stomp/writer.go @@ -0,0 +1,173 @@ +package stomp + +import ( + "bytes" + "io" +) + +var ( + crlf = []byte{'\r', '\n'} + newline = []byte{'\n'} + separator = []byte{':'} + terminator = []byte{0} +) + +func writeTo(w io.Writer, m *Message) { + w.Write(m.Method) + w.Write(newline) + + switch { + case bytes.Equal(m.Method, MethodStomp): + // version + w.Write(HeaderAccept) + w.Write(separator) + w.Write(m.Proto) + w.Write(newline) + // login + if len(m.User) != 0 { + w.Write(HeaderLogin) + w.Write(separator) + w.Write(m.User) + w.Write(newline) + } + // passcode + if len(m.Pass) != 0 { + w.Write(HeaderPass) + w.Write(separator) + w.Write(m.Pass) + w.Write(newline) + } + case bytes.Equal(m.Method, MethodConnected): + // version + w.Write(HeaderVersion) + w.Write(separator) + w.Write(m.Proto) + w.Write(newline) + case bytes.Equal(m.Method, MethodSend): + // dest + w.Write(HeaderDest) + w.Write(separator) + w.Write(m.Dest) + w.Write(newline) + if len(m.Expires) != 0 { + w.Write(HeaderExpires) + w.Write(separator) + w.Write(m.Expires) + w.Write(newline) + } + if len(m.Retain) != 0 { + w.Write(HeaderRetain) + w.Write(separator) + w.Write(m.Retain) + w.Write(newline) + } + if len(m.Persist) != 0 { + w.Write(HeaderPersist) + w.Write(separator) + w.Write(m.Persist) + w.Write(newline) + } + case bytes.Equal(m.Method, MethodSubscribe): + // id + w.Write(HeaderID) + w.Write(separator) + w.Write(m.ID) + w.Write(newline) + // destination + w.Write(HeaderDest) + w.Write(separator) + w.Write(m.Dest) + w.Write(newline) + // selector + if len(m.Selector) != 0 { + w.Write(HeaderSelector) + w.Write(separator) + w.Write(m.Selector) + w.Write(newline) + } + // prefetch + if len(m.Prefetch) != 0 { + w.Write(HeaderPrefetch) + w.Write(separator) + w.Write(m.Prefetch) + w.Write(newline) + } + if len(m.Ack) != 0 { + w.Write(HeaderAck) + w.Write(separator) + w.Write(m.Ack) + w.Write(newline) + } + case bytes.Equal(m.Method, MethodUnsubscribe): + // id + w.Write(HeaderID) + w.Write(separator) + w.Write(m.ID) + w.Write(newline) + case bytes.Equal(m.Method, MethodAck): + // id + w.Write(HeaderID) + w.Write(separator) + w.Write(m.ID) + w.Write(newline) + case bytes.Equal(m.Method, MethodNack): + // id + w.Write(HeaderID) + w.Write(separator) + w.Write(m.ID) + w.Write(newline) + case bytes.Equal(m.Method, MethodMessage): + // message-id + w.Write(HeaderMessageID) + w.Write(separator) + w.Write(m.ID) + w.Write(newline) + // destination + w.Write(HeaderDest) + w.Write(separator) + w.Write(m.Dest) + w.Write(newline) + // subscription + w.Write(HeaderSubscription) + w.Write(separator) + w.Write(m.Subs) + w.Write(newline) + // ack + if len(m.Ack) != 0 { + w.Write(HeaderAck) + w.Write(separator) + w.Write(m.Ack) + w.Write(newline) + } + case bytes.Equal(m.Method, MethodRecipet): + // receipt-id + w.Write(HeaderReceiptID) + w.Write(separator) + w.Write(m.Receipt) + w.Write(newline) + } + + // receipt header + if includeReceiptHeader(m) { + w.Write(HeaderReceipt) + w.Write(separator) + w.Write(m.Receipt) + w.Write(newline) + } + + for i, item := range m.Header.items { + if m.Header.itemc == i { + break + } + w.Write(item.name) + w.Write(separator) + w.Write(item.data) + w.Write(newline) + } + w.Write(newline) + w.Write(m.Body) +} + +func includeReceiptHeader(m *Message) bool { + return len(m.Receipt) != 0 && !bytes.Equal(m.Method, MethodRecipet) +} diff --git a/vendor/golang.org/x/net/websocket/client.go b/vendor/golang.org/x/net/websocket/client.go new file mode 100644 index 0000000..69a4ac7 --- /dev/null +++ b/vendor/golang.org/x/net/websocket/client.go @@ -0,0 +1,106 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "io" + "net" + "net/http" + "net/url" +) + +// DialError is an error that occurs while dialling a websocket server. +type DialError struct { + *Config + Err error +} + +func (e *DialError) Error() string { + return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error() +} + +// NewConfig creates a new WebSocket config for client connection. +func NewConfig(server, origin string) (config *Config, err error) { + config = new(Config) + config.Version = ProtocolVersionHybi13 + config.Location, err = url.ParseRequestURI(server) + if err != nil { + return + } + config.Origin, err = url.ParseRequestURI(origin) + if err != nil { + return + } + config.Header = http.Header(make(map[string][]string)) + return +} + +// NewClient creates a new WebSocket client connection over rwc. +func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) { + br := bufio.NewReader(rwc) + bw := bufio.NewWriter(rwc) + err = hybiClientHandshake(config, br, bw) + if err != nil { + return + } + buf := bufio.NewReadWriter(br, bw) + ws = newHybiClientConn(config, buf, rwc) + return +} + +// Dial opens a new client connection to a WebSocket. +func Dial(url_, protocol, origin string) (ws *Conn, err error) { + config, err := NewConfig(url_, origin) + if err != nil { + return nil, err + } + if protocol != "" { + config.Protocol = []string{protocol} + } + return DialConfig(config) +} + +var portMap = map[string]string{ + "ws": "80", + "wss": "443", +} + +func parseAuthority(location *url.URL) string { + if _, ok := portMap[location.Scheme]; ok { + if _, _, err := net.SplitHostPort(location.Host); err != nil { + return net.JoinHostPort(location.Host, portMap[location.Scheme]) + } + } + return location.Host +} + +// DialConfig opens a new client connection to a WebSocket with a config. +func DialConfig(config *Config) (ws *Conn, err error) { + var client net.Conn + if config.Location == nil { + return nil, &DialError{config, ErrBadWebSocketLocation} + } + if config.Origin == nil { + return nil, &DialError{config, ErrBadWebSocketOrigin} + } + dialer := config.Dialer + if dialer == nil { + dialer = &net.Dialer{} + } + client, err = dialWithDialer(dialer, config) + if err != nil { + goto Error + } + ws, err = NewClient(config, client) + if err != nil { + client.Close() + goto Error + } + return + +Error: + return nil, &DialError{config, err} +} diff --git a/vendor/golang.org/x/net/websocket/dial.go b/vendor/golang.org/x/net/websocket/dial.go new file mode 100644 index 0000000..2dab943 --- /dev/null +++ b/vendor/golang.org/x/net/websocket/dial.go @@ -0,0 +1,24 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "crypto/tls" + "net" +) + +func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) { + switch config.Location.Scheme { + case "ws": + conn, err = dialer.Dial("tcp", parseAuthority(config.Location)) + + case "wss": + conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig) + + default: + err = ErrBadScheme + } + return +} diff --git a/vendor/golang.org/x/net/websocket/hybi.go b/vendor/golang.org/x/net/websocket/hybi.go new file mode 100644 index 0000000..8cffdd1 --- /dev/null +++ b/vendor/golang.org/x/net/websocket/hybi.go @@ -0,0 +1,583 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +// This file implements a protocol of hybi draft. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" +) + +const ( + websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + closeStatusNormal = 1000 + closeStatusGoingAway = 1001 + closeStatusProtocolError = 1002 + closeStatusUnsupportedData = 1003 + closeStatusFrameTooLarge = 1004 + closeStatusNoStatusRcvd = 1005 + closeStatusAbnormalClosure = 1006 + closeStatusBadMessageData = 1007 + closeStatusPolicyViolation = 1008 + closeStatusTooBigData = 1009 + closeStatusExtensionMismatch = 1010 + + maxControlFramePayloadLength = 125 +) + +var ( + ErrBadMaskingKey = &ProtocolError{"bad masking key"} + ErrBadPongMessage = &ProtocolError{"bad pong message"} + ErrBadClosingStatus = &ProtocolError{"bad closing status"} + ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"} + ErrNotImplemented = &ProtocolError{"not implemented"} + + handshakeHeader = map[string]bool{ + "Host": true, + "Upgrade": true, + "Connection": true, + "Sec-Websocket-Key": true, + "Sec-Websocket-Origin": true, + "Sec-Websocket-Version": true, + "Sec-Websocket-Protocol": true, + "Sec-Websocket-Accept": true, + } +) + +// A hybiFrameHeader is a frame header as defined in hybi draft. +type hybiFrameHeader struct { + Fin bool + Rsv [3]bool + OpCode byte + Length int64 + MaskingKey []byte + + data *bytes.Buffer +} + +// A hybiFrameReader is a reader for hybi frame. +type hybiFrameReader struct { + reader io.Reader + + header hybiFrameHeader + pos int64 + length int +} + +func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) { + n, err = frame.reader.Read(msg) + if frame.header.MaskingKey != nil { + for i := 0; i < n; i++ { + msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4] + frame.pos++ + } + } + return n, err +} + +func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode } + +func (frame *hybiFrameReader) HeaderReader() io.Reader { + if frame.header.data == nil { + return nil + } + if frame.header.data.Len() == 0 { + return nil + } + return frame.header.data +} + +func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil } + +func (frame *hybiFrameReader) Len() (n int) { return frame.length } + +// A hybiFrameReaderFactory creates new frame reader based on its frame type. +type hybiFrameReaderFactory struct { + *bufio.Reader +} + +// NewFrameReader reads a frame header from the connection, and creates new reader for the frame. +// See Section 5.2 Base Framing protocol for detail. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2 +func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) { + hybiFrame := new(hybiFrameReader) + frame = hybiFrame + var header []byte + var b byte + // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits) + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0 + for i := 0; i < 3; i++ { + j := uint(6 - i) + hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0 + } + hybiFrame.header.OpCode = header[0] & 0x0f + + // Second byte. Mask/Payload len(7bits) + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + mask := (b & 0x80) != 0 + b &= 0x7f + lengthFields := 0 + switch { + case b <= 125: // Payload length 7bits. + hybiFrame.header.Length = int64(b) + case b == 126: // Payload length 7+16bits + lengthFields = 2 + case b == 127: // Payload length 7+64bits + lengthFields = 8 + } + for i := 0; i < lengthFields; i++ { + b, err = buf.ReadByte() + if err != nil { + return + } + if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits + b &= 0x7f + } + header = append(header, b) + hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b) + } + if mask { + // Masking key. 4 bytes. + for i := 0; i < 4; i++ { + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b) + } + } + hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length) + hybiFrame.header.data = bytes.NewBuffer(header) + hybiFrame.length = len(header) + int(hybiFrame.header.Length) + return +} + +// A HybiFrameWriter is a writer for hybi frame. +type hybiFrameWriter struct { + writer *bufio.Writer + + header *hybiFrameHeader +} + +func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) { + var header []byte + var b byte + if frame.header.Fin { + b |= 0x80 + } + for i := 0; i < 3; i++ { + if frame.header.Rsv[i] { + j := uint(6 - i) + b |= 1 << j + } + } + b |= frame.header.OpCode + header = append(header, b) + if frame.header.MaskingKey != nil { + b = 0x80 + } else { + b = 0 + } + lengthFields := 0 + length := len(msg) + switch { + case length <= 125: + b |= byte(length) + case length < 65536: + b |= 126 + lengthFields = 2 + default: + b |= 127 + lengthFields = 8 + } + header = append(header, b) + for i := 0; i < lengthFields; i++ { + j := uint((lengthFields - i - 1) * 8) + b = byte((length >> j) & 0xff) + header = append(header, b) + } + if frame.header.MaskingKey != nil { + if len(frame.header.MaskingKey) != 4 { + return 0, ErrBadMaskingKey + } + header = append(header, frame.header.MaskingKey...) + frame.writer.Write(header) + data := make([]byte, length) + for i := range data { + data[i] = msg[i] ^ frame.header.MaskingKey[i%4] + } + frame.writer.Write(data) + err = frame.writer.Flush() + return length, err + } + frame.writer.Write(header) + frame.writer.Write(msg) + err = frame.writer.Flush() + return length, err +} + +func (frame *hybiFrameWriter) Close() error { return nil } + +type hybiFrameWriterFactory struct { + *bufio.Writer + needMaskingKey bool +} + +func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) { + frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType} + if buf.needMaskingKey { + frameHeader.MaskingKey, err = generateMaskingKey() + if err != nil { + return nil, err + } + } + return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil +} + +type hybiFrameHandler struct { + conn *Conn + payloadType byte +} + +func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) { + if handler.conn.IsServerConn() { + // The client MUST mask all frames sent to the server. + if frame.(*hybiFrameReader).header.MaskingKey == nil { + handler.WriteClose(closeStatusProtocolError) + return nil, io.EOF + } + } else { + // The server MUST NOT mask all frames. + if frame.(*hybiFrameReader).header.MaskingKey != nil { + handler.WriteClose(closeStatusProtocolError) + return nil, io.EOF + } + } + if header := frame.HeaderReader(); header != nil { + io.Copy(ioutil.Discard, header) + } + switch frame.PayloadType() { + case ContinuationFrame: + frame.(*hybiFrameReader).header.OpCode = handler.payloadType + case TextFrame, BinaryFrame: + handler.payloadType = frame.PayloadType() + case CloseFrame: + return nil, io.EOF + case PingFrame, PongFrame: + b := make([]byte, maxControlFramePayloadLength) + n, err := io.ReadFull(frame, b) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return nil, err + } + io.Copy(ioutil.Discard, frame) + if frame.PayloadType() == PingFrame { + if _, err := handler.WritePong(b[:n]); err != nil { + return nil, err + } + } + return nil, nil + } + return frame, nil +} + +func (handler *hybiFrameHandler) WriteClose(status int) (err error) { + handler.conn.wio.Lock() + defer handler.conn.wio.Unlock() + w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame) + if err != nil { + return err + } + msg := make([]byte, 2) + binary.BigEndian.PutUint16(msg, uint16(status)) + _, err = w.Write(msg) + w.Close() + return err +} + +func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) { + handler.conn.wio.Lock() + defer handler.conn.wio.Unlock() + w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame) + if err != nil { + return 0, err + } + n, err = w.Write(msg) + w.Close() + return n, err +} + +// newHybiConn creates a new WebSocket connection speaking hybi draft protocol. +func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + if buf == nil { + br := bufio.NewReader(rwc) + bw := bufio.NewWriter(rwc) + buf = bufio.NewReadWriter(br, bw) + } + ws := &Conn{config: config, request: request, buf: buf, rwc: rwc, + frameReaderFactory: hybiFrameReaderFactory{buf.Reader}, + frameWriterFactory: hybiFrameWriterFactory{ + buf.Writer, request == nil}, + PayloadType: TextFrame, + defaultCloseStatus: closeStatusNormal} + ws.frameHandler = &hybiFrameHandler{conn: ws} + return ws +} + +// generateMaskingKey generates a masking key for a frame. +func generateMaskingKey() (maskingKey []byte, err error) { + maskingKey = make([]byte, 4) + if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil { + return + } + return +} + +// generateNonce generates a nonce consisting of a randomly selected 16-byte +// value that has been base64-encoded. +func generateNonce() (nonce []byte) { + key := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + panic(err) + } + nonce = make([]byte, 24) + base64.StdEncoding.Encode(nonce, key) + return +} + +// removeZone removes IPv6 zone identifer from host. +// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" +func removeZone(host string) string { + if !strings.HasPrefix(host, "[") { + return host + } + i := strings.LastIndex(host, "]") + if i < 0 { + return host + } + j := strings.LastIndex(host[:i], "%") + if j < 0 { + return host + } + return host[:j] + host[i:] +} + +// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of +// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string. +func getNonceAccept(nonce []byte) (expected []byte, err error) { + h := sha1.New() + if _, err = h.Write(nonce); err != nil { + return + } + if _, err = h.Write([]byte(websocketGUID)); err != nil { + return + } + expected = make([]byte, 28) + base64.StdEncoding.Encode(expected, h.Sum(nil)) + return +} + +// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17 +func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) { + bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n") + + // According to RFC 6874, an HTTP client, proxy, or other + // intermediary must remove any IPv6 zone identifier attached + // to an outgoing URI. + bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n") + bw.WriteString("Upgrade: websocket\r\n") + bw.WriteString("Connection: Upgrade\r\n") + nonce := generateNonce() + if config.handshakeData != nil { + nonce = []byte(config.handshakeData["key"]) + } + bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n") + bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") + + if config.Version != ProtocolVersionHybi13 { + return ErrBadProtocolVersion + } + + bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n") + if len(config.Protocol) > 0 { + bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n") + } + // TODO(ukai): send Sec-WebSocket-Extensions. + err = config.Header.WriteSubset(bw, handshakeHeader) + if err != nil { + return err + } + + bw.WriteString("\r\n") + if err = bw.Flush(); err != nil { + return err + } + + resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) + if err != nil { + return err + } + if resp.StatusCode != 101 { + return ErrBadStatus + } + if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || + strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { + return ErrBadUpgrade + } + expectedAccept, err := getNonceAccept(nonce) + if err != nil { + return err + } + if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) { + return ErrChallengeResponse + } + if resp.Header.Get("Sec-WebSocket-Extensions") != "" { + return ErrUnsupportedExtensions + } + offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol") + if offeredProtocol != "" { + protocolMatched := false + for i := 0; i < len(config.Protocol); i++ { + if config.Protocol[i] == offeredProtocol { + protocolMatched = true + break + } + } + if !protocolMatched { + return ErrBadWebSocketProtocol + } + config.Protocol = []string{offeredProtocol} + } + + return nil +} + +// newHybiClientConn creates a client WebSocket connection after handshake. +func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn { + return newHybiConn(config, buf, rwc, nil) +} + +// A HybiServerHandshaker performs a server handshake using hybi draft protocol. +type hybiServerHandshaker struct { + *Config + accept []byte +} + +func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) { + c.Version = ProtocolVersionHybi13 + if req.Method != "GET" { + return http.StatusMethodNotAllowed, ErrBadRequestMethod + } + // HTTP version can be safely ignored. + + if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || + !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { + return http.StatusBadRequest, ErrNotWebSocket + } + + key := req.Header.Get("Sec-Websocket-Key") + if key == "" { + return http.StatusBadRequest, ErrChallengeResponse + } + version := req.Header.Get("Sec-Websocket-Version") + switch version { + case "13": + c.Version = ProtocolVersionHybi13 + default: + return http.StatusBadRequest, ErrBadWebSocketVersion + } + var scheme string + if req.TLS != nil { + scheme = "wss" + } else { + scheme = "ws" + } + c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI()) + if err != nil { + return http.StatusBadRequest, err + } + protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) + if protocol != "" { + protocols := strings.Split(protocol, ",") + for i := 0; i < len(protocols); i++ { + c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i])) + } + } + c.accept, err = getNonceAccept([]byte(key)) + if err != nil { + return http.StatusInternalServerError, err + } + return http.StatusSwitchingProtocols, nil +} + +// Origin parses the Origin header in req. +// If the Origin header is not set, it returns nil and nil. +func Origin(config *Config, req *http.Request) (*url.URL, error) { + var origin string + switch config.Version { + case ProtocolVersionHybi13: + origin = req.Header.Get("Origin") + } + if origin == "" { + return nil, nil + } + return url.ParseRequestURI(origin) +} + +func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) { + if len(c.Protocol) > 0 { + if len(c.Protocol) != 1 { + // You need choose a Protocol in Handshake func in Server. + return ErrBadWebSocketProtocol + } + } + buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n") + buf.WriteString("Upgrade: websocket\r\n") + buf.WriteString("Connection: Upgrade\r\n") + buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n") + if len(c.Protocol) > 0 { + buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n") + } + // TODO(ukai): send Sec-WebSocket-Extensions. + if c.Header != nil { + err := c.Header.WriteSubset(buf, handshakeHeader) + if err != nil { + return err + } + } + buf.WriteString("\r\n") + return buf.Flush() +} + +func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + return newHybiServerConn(c.Config, buf, rwc, request) +} + +// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol. +func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + return newHybiConn(config, buf, rwc, request) +} diff --git a/vendor/golang.org/x/net/websocket/server.go b/vendor/golang.org/x/net/websocket/server.go new file mode 100644 index 0000000..0895dea --- /dev/null +++ b/vendor/golang.org/x/net/websocket/server.go @@ -0,0 +1,113 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "fmt" + "io" + "net/http" +) + +func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) { + var hs serverHandshaker = &hybiServerHandshaker{Config: config} + code, err := hs.ReadHandshake(buf.Reader, req) + if err == ErrBadWebSocketVersion { + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion) + buf.WriteString("\r\n") + buf.WriteString(err.Error()) + buf.Flush() + return + } + if err != nil { + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.WriteString(err.Error()) + buf.Flush() + return + } + if handshake != nil { + err = handshake(config, req) + if err != nil { + code = http.StatusForbidden + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.Flush() + return + } + } + err = hs.AcceptHandshake(buf.Writer) + if err != nil { + code = http.StatusBadRequest + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.Flush() + return + } + conn = hs.NewServerConn(buf, rwc, req) + return +} + +// Server represents a server of a WebSocket. +type Server struct { + // Config is a WebSocket configuration for new WebSocket connection. + Config + + // Handshake is an optional function in WebSocket handshake. + // For example, you can check, or don't check Origin header. + // Another example, you can select config.Protocol. + Handshake func(*Config, *http.Request) error + + // Handler handles a WebSocket connection. + Handler +} + +// ServeHTTP implements the http.Handler interface for a WebSocket +func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s.serveWebSocket(w, req) +} + +func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) { + rwc, buf, err := w.(http.Hijacker).Hijack() + if err != nil { + panic("Hijack failed: " + err.Error()) + } + // The server should abort the WebSocket connection if it finds + // the client did not send a handshake that matches with protocol + // specification. + defer rwc.Close() + conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake) + if err != nil { + return + } + if conn == nil { + panic("unexpected nil conn") + } + s.Handler(conn) +} + +// Handler is a simple interface to a WebSocket browser client. +// It checks if Origin header is valid URL by default. +// You might want to verify websocket.Conn.Config().Origin in the func. +// If you use Server instead of Handler, you could call websocket.Origin and +// check the origin in your Handshake func. So, if you want to accept +// non-browser clients, which do not send an Origin header, set a +// Server.Handshake that does not check the origin. +type Handler func(*Conn) + +func checkOrigin(config *Config, req *http.Request) (err error) { + config.Origin, err = Origin(config, req) + if err == nil && config.Origin == nil { + return fmt.Errorf("null origin") + } + return err +} + +// ServeHTTP implements the http.Handler interface for a WebSocket +func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s := Server{Handler: h, Handshake: checkOrigin} + s.serveWebSocket(w, req) +} diff --git a/vendor/golang.org/x/net/websocket/websocket.go b/vendor/golang.org/x/net/websocket/websocket.go new file mode 100644 index 0000000..e242c89 --- /dev/null +++ b/vendor/golang.org/x/net/websocket/websocket.go @@ -0,0 +1,448 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package websocket implements a client and server for the WebSocket protocol +// as specified in RFC 6455. +// +// This package currently lacks some features found in an alternative +// and more actively maintained WebSocket package: +// +// https://godoc.org/github.com/gorilla/websocket +// +package websocket // import "golang.org/x/net/websocket" + +import ( + "bufio" + "crypto/tls" + "encoding/json" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +const ( + ProtocolVersionHybi13 = 13 + ProtocolVersionHybi = ProtocolVersionHybi13 + SupportedProtocolVersion = "13" + + ContinuationFrame = 0 + TextFrame = 1 + BinaryFrame = 2 + CloseFrame = 8 + PingFrame = 9 + PongFrame = 10 + UnknownFrame = 255 + + DefaultMaxPayloadBytes = 32 << 20 // 32MB +) + +// ProtocolError represents WebSocket protocol errors. +type ProtocolError struct { + ErrorString string +} + +func (err *ProtocolError) Error() string { return err.ErrorString } + +var ( + ErrBadProtocolVersion = &ProtocolError{"bad protocol version"} + ErrBadScheme = &ProtocolError{"bad scheme"} + ErrBadStatus = &ProtocolError{"bad status"} + ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"} + ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"} + ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"} + ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"} + ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"} + ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"} + ErrBadFrame = &ProtocolError{"bad frame"} + ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"} + ErrNotWebSocket = &ProtocolError{"not websocket protocol"} + ErrBadRequestMethod = &ProtocolError{"bad method"} + ErrNotSupported = &ProtocolError{"not supported"} +) + +// ErrFrameTooLarge is returned by Codec's Receive method if payload size +// exceeds limit set by Conn.MaxPayloadBytes +var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit") + +// Addr is an implementation of net.Addr for WebSocket. +type Addr struct { + *url.URL +} + +// Network returns the network type for a WebSocket, "websocket". +func (addr *Addr) Network() string { return "websocket" } + +// Config is a WebSocket configuration +type Config struct { + // A WebSocket server address. + Location *url.URL + + // A Websocket client origin. + Origin *url.URL + + // WebSocket subprotocols. + Protocol []string + + // WebSocket protocol version. + Version int + + // TLS config for secure WebSocket (wss). + TlsConfig *tls.Config + + // Additional header fields to be sent in WebSocket opening handshake. + Header http.Header + + // Dialer used when opening websocket connections. + Dialer *net.Dialer + + handshakeData map[string]string +} + +// serverHandshaker is an interface to handle WebSocket server side handshake. +type serverHandshaker interface { + // ReadHandshake reads handshake request message from client. + // Returns http response code and error if any. + ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) + + // AcceptHandshake accepts the client handshake request and sends + // handshake response back to client. + AcceptHandshake(buf *bufio.Writer) (err error) + + // NewServerConn creates a new WebSocket connection. + NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn) +} + +// frameReader is an interface to read a WebSocket frame. +type frameReader interface { + // Reader is to read payload of the frame. + io.Reader + + // PayloadType returns payload type. + PayloadType() byte + + // HeaderReader returns a reader to read header of the frame. + HeaderReader() io.Reader + + // TrailerReader returns a reader to read trailer of the frame. + // If it returns nil, there is no trailer in the frame. + TrailerReader() io.Reader + + // Len returns total length of the frame, including header and trailer. + Len() int +} + +// frameReaderFactory is an interface to creates new frame reader. +type frameReaderFactory interface { + NewFrameReader() (r frameReader, err error) +} + +// frameWriter is an interface to write a WebSocket frame. +type frameWriter interface { + // Writer is to write payload of the frame. + io.WriteCloser +} + +// frameWriterFactory is an interface to create new frame writer. +type frameWriterFactory interface { + NewFrameWriter(payloadType byte) (w frameWriter, err error) +} + +type frameHandler interface { + HandleFrame(frame frameReader) (r frameReader, err error) + WriteClose(status int) (err error) +} + +// Conn represents a WebSocket connection. +// +// Multiple goroutines may invoke methods on a Conn simultaneously. +type Conn struct { + config *Config + request *http.Request + + buf *bufio.ReadWriter + rwc io.ReadWriteCloser + + rio sync.Mutex + frameReaderFactory + frameReader + + wio sync.Mutex + frameWriterFactory + + frameHandler + PayloadType byte + defaultCloseStatus int + + // MaxPayloadBytes limits the size of frame payload received over Conn + // by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used. + MaxPayloadBytes int +} + +// Read implements the io.Reader interface: +// it reads data of a frame from the WebSocket connection. +// if msg is not large enough for the frame data, it fills the msg and next Read +// will read the rest of the frame data. +// it reads Text frame or Binary frame. +func (ws *Conn) Read(msg []byte) (n int, err error) { + ws.rio.Lock() + defer ws.rio.Unlock() +again: + if ws.frameReader == nil { + frame, err := ws.frameReaderFactory.NewFrameReader() + if err != nil { + return 0, err + } + ws.frameReader, err = ws.frameHandler.HandleFrame(frame) + if err != nil { + return 0, err + } + if ws.frameReader == nil { + goto again + } + } + n, err = ws.frameReader.Read(msg) + if err == io.EOF { + if trailer := ws.frameReader.TrailerReader(); trailer != nil { + io.Copy(ioutil.Discard, trailer) + } + ws.frameReader = nil + goto again + } + return n, err +} + +// Write implements the io.Writer interface: +// it writes data as a frame to the WebSocket connection. +func (ws *Conn) Write(msg []byte) (n int, err error) { + ws.wio.Lock() + defer ws.wio.Unlock() + w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType) + if err != nil { + return 0, err + } + n, err = w.Write(msg) + w.Close() + return n, err +} + +// Close implements the io.Closer interface. +func (ws *Conn) Close() error { + err := ws.frameHandler.WriteClose(ws.defaultCloseStatus) + err1 := ws.rwc.Close() + if err != nil { + return err + } + return err1 +} + +func (ws *Conn) IsClientConn() bool { return ws.request == nil } +func (ws *Conn) IsServerConn() bool { return ws.request != nil } + +// LocalAddr returns the WebSocket Origin for the connection for client, or +// the WebSocket location for server. +func (ws *Conn) LocalAddr() net.Addr { + if ws.IsClientConn() { + return &Addr{ws.config.Origin} + } + return &Addr{ws.config.Location} +} + +// RemoteAddr returns the WebSocket location for the connection for client, or +// the Websocket Origin for server. +func (ws *Conn) RemoteAddr() net.Addr { + if ws.IsClientConn() { + return &Addr{ws.config.Location} + } + return &Addr{ws.config.Origin} +} + +var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn") + +// SetDeadline sets the connection's network read & write deadlines. +func (ws *Conn) SetDeadline(t time.Time) error { + if conn, ok := ws.rwc.(net.Conn); ok { + return conn.SetDeadline(t) + } + return errSetDeadline +} + +// SetReadDeadline sets the connection's network read deadline. +func (ws *Conn) SetReadDeadline(t time.Time) error { + if conn, ok := ws.rwc.(net.Conn); ok { + return conn.SetReadDeadline(t) + } + return errSetDeadline +} + +// SetWriteDeadline sets the connection's network write deadline. +func (ws *Conn) SetWriteDeadline(t time.Time) error { + if conn, ok := ws.rwc.(net.Conn); ok { + return conn.SetWriteDeadline(t) + } + return errSetDeadline +} + +// Config returns the WebSocket config. +func (ws *Conn) Config() *Config { return ws.config } + +// Request returns the http request upgraded to the WebSocket. +// It is nil for client side. +func (ws *Conn) Request() *http.Request { return ws.request } + +// Codec represents a symmetric pair of functions that implement a codec. +type Codec struct { + Marshal func(v interface{}) (data []byte, payloadType byte, err error) + Unmarshal func(data []byte, payloadType byte, v interface{}) (err error) +} + +// Send sends v marshaled by cd.Marshal as single frame to ws. +func (cd Codec) Send(ws *Conn, v interface{}) (err error) { + data, payloadType, err := cd.Marshal(v) + if err != nil { + return err + } + ws.wio.Lock() + defer ws.wio.Unlock() + w, err := ws.frameWriterFactory.NewFrameWriter(payloadType) + if err != nil { + return err + } + _, err = w.Write(data) + w.Close() + return err +} + +// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores +// in v. The whole frame payload is read to an in-memory buffer; max size of +// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds +// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire +// completely. The next call to Receive would read and discard leftover data of +// previous oversized frame before processing next frame. +func (cd Codec) Receive(ws *Conn, v interface{}) (err error) { + ws.rio.Lock() + defer ws.rio.Unlock() + if ws.frameReader != nil { + _, err = io.Copy(ioutil.Discard, ws.frameReader) + if err != nil { + return err + } + ws.frameReader = nil + } +again: + frame, err := ws.frameReaderFactory.NewFrameReader() + if err != nil { + return err + } + frame, err = ws.frameHandler.HandleFrame(frame) + if err != nil { + return err + } + if frame == nil { + goto again + } + maxPayloadBytes := ws.MaxPayloadBytes + if maxPayloadBytes == 0 { + maxPayloadBytes = DefaultMaxPayloadBytes + } + if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) { + // payload size exceeds limit, no need to call Unmarshal + // + // set frameReader to current oversized frame so that + // the next call to this function can drain leftover + // data before processing the next frame + ws.frameReader = frame + return ErrFrameTooLarge + } + payloadType := frame.PayloadType() + data, err := ioutil.ReadAll(frame) + if err != nil { + return err + } + return cd.Unmarshal(data, payloadType, v) +} + +func marshal(v interface{}) (msg []byte, payloadType byte, err error) { + switch data := v.(type) { + case string: + return []byte(data), TextFrame, nil + case []byte: + return data, BinaryFrame, nil + } + return nil, UnknownFrame, ErrNotSupported +} + +func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) { + switch data := v.(type) { + case *string: + *data = string(msg) + return nil + case *[]byte: + *data = msg + return nil + } + return ErrNotSupported +} + +/* +Message is a codec to send/receive text/binary data in a frame on WebSocket connection. +To send/receive text frame, use string type. +To send/receive binary frame, use []byte type. + +Trivial usage: + + import "websocket" + + // receive text frame + var message string + websocket.Message.Receive(ws, &message) + + // send text frame + message = "hello" + websocket.Message.Send(ws, message) + + // receive binary frame + var data []byte + websocket.Message.Receive(ws, &data) + + // send binary frame + data = []byte{0, 1, 2} + websocket.Message.Send(ws, data) + +*/ +var Message = Codec{marshal, unmarshal} + +func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) { + msg, err = json.Marshal(v) + return msg, TextFrame, err +} + +func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) { + return json.Unmarshal(msg, v) +} + +/* +JSON is a codec to send/receive JSON data in a frame from a WebSocket connection. + +Trivial usage: + + import "websocket" + + type T struct { + Msg string + Count int + } + + // receive JSON type T + var data T + websocket.JSON.Receive(ws, &data) + + // send JSON type T + websocket.JSON.Send(ws, data) +*/ +var JSON = Codec{jsonMarshal, jsonUnmarshal} diff --git a/vendor/golang.org/x/sync/semaphore/semaphore.go b/vendor/golang.org/x/sync/semaphore/semaphore.go new file mode 100644 index 0000000..e9d2d79 --- /dev/null +++ b/vendor/golang.org/x/sync/semaphore/semaphore.go @@ -0,0 +1,131 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package semaphore provides a weighted semaphore implementation. +package semaphore // import "golang.org/x/sync/semaphore" + +import ( + "container/list" + "sync" + + // Use the old context because packages that depend on this one + // (e.g. cloud.google.com/go/...) must run on Go 1.6. + // TODO(jba): update to "context" when possible. + "golang.org/x/net/context" +) + +type waiter struct { + n int64 + ready chan<- struct{} // Closed when semaphore acquired. +} + +// NewWeighted creates a new weighted semaphore with the given +// maximum combined weight for concurrent access. +func NewWeighted(n int64) *Weighted { + w := &Weighted{size: n} + return w +} + +// Weighted provides a way to bound concurrent access to a resource. +// The callers can request access with a given weight. +type Weighted struct { + size int64 + cur int64 + mu sync.Mutex + waiters list.List +} + +// Acquire acquires the semaphore with a weight of n, blocking only until ctx +// is done. On success, returns nil. On failure, returns ctx.Err() and leaves +// the semaphore unchanged. +// +// If ctx is already done, Acquire may still succeed without blocking. +func (s *Weighted) Acquire(ctx context.Context, n int64) error { + s.mu.Lock() + if s.size-s.cur >= n && s.waiters.Len() == 0 { + s.cur += n + s.mu.Unlock() + return nil + } + + if n > s.size { + // Don't make other Acquire calls block on one that's doomed to fail. + s.mu.Unlock() + <-ctx.Done() + return ctx.Err() + } + + ready := make(chan struct{}) + w := waiter{n: n, ready: ready} + elem := s.waiters.PushBack(w) + s.mu.Unlock() + + select { + case <-ctx.Done(): + err := ctx.Err() + s.mu.Lock() + select { + case <-ready: + // Acquired the semaphore after we were canceled. Rather than trying to + // fix up the queue, just pretend we didn't notice the cancelation. + err = nil + default: + s.waiters.Remove(elem) + } + s.mu.Unlock() + return err + + case <-ready: + return nil + } +} + +// TryAcquire acquires the semaphore with a weight of n without blocking. +// On success, returns true. On failure, returns false and leaves the semaphore unchanged. +func (s *Weighted) TryAcquire(n int64) bool { + s.mu.Lock() + success := s.size-s.cur >= n && s.waiters.Len() == 0 + if success { + s.cur += n + } + s.mu.Unlock() + return success +} + +// Release releases the semaphore with a weight of n. +func (s *Weighted) Release(n int64) { + s.mu.Lock() + s.cur -= n + if s.cur < 0 { + s.mu.Unlock() + panic("semaphore: bad release") + } + for { + next := s.waiters.Front() + if next == nil { + break // No more waiters blocked. + } + + w := next.Value.(waiter) + if s.size-s.cur < w.n { + // Not enough tokens for the next waiter. We could keep going (to try to + // find a waiter with a smaller request), but under load that could cause + // starvation for large requests; instead, we leave all remaining waiters + // blocked. + // + // Consider a semaphore used as a read-write lock, with N tokens, N + // readers, and one writer. Each reader can Acquire(1) to obtain a read + // lock. The writer can Acquire(N) to obtain a write lock, excluding all + // of the readers. If we allow the readers to jump ahead in the queue, + // the writer will starve — there is always one token available for every + // reader. + break + } + + s.cur += w.n + s.waiters.Remove(next) + close(w.ready) + } + s.mu.Unlock() +} diff --git a/vendor/google.golang.org/api/support/bundler/bundler.go b/vendor/google.golang.org/api/support/bundler/bundler.go new file mode 100644 index 0000000..c4e4c9a --- /dev/null +++ b/vendor/google.golang.org/api/support/bundler/bundler.go @@ -0,0 +1,258 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package bundler supports bundling (batching) of items. Bundling amortizes an +// action with fixed costs over multiple items. For example, if an API provides +// an RPC that accepts a list of items as input, but clients would prefer +// adding items one at a time, then a Bundler can accept individual items from +// the client and bundle many of them into a single RPC. +// +// This package is experimental and subject to change without notice. +package bundler + +import ( + "errors" + "reflect" + "sync" + "time" + + "golang.org/x/net/context" + "golang.org/x/sync/semaphore" +) + +const ( + DefaultDelayThreshold = time.Second + DefaultBundleCountThreshold = 10 + DefaultBundleByteThreshold = 1e6 // 1M + DefaultBufferedByteLimit = 1e9 // 1G +) + +var ( + // ErrOverflow indicates that Bundler's stored bytes exceeds its BufferedByteLimit. + ErrOverflow = errors.New("bundler reached buffered byte limit") + + // ErrOversizedItem indicates that an item's size exceeds the maximum bundle size. + ErrOversizedItem = errors.New("item size exceeds bundle byte limit") +) + +// A Bundler collects items added to it into a bundle until the bundle +// exceeds a given size, then calls a user-provided function to handle the bundle. +type Bundler struct { + // Starting from the time that the first message is added to a bundle, once + // this delay has passed, handle the bundle. The default is DefaultDelayThreshold. + DelayThreshold time.Duration + + // Once a bundle has this many items, handle the bundle. Since only one + // item at a time is added to a bundle, no bundle will exceed this + // threshold, so it also serves as a limit. The default is + // DefaultBundleCountThreshold. + BundleCountThreshold int + + // Once the number of bytes in current bundle reaches this threshold, handle + // the bundle. The default is DefaultBundleByteThreshold. This triggers handling, + // but does not cap the total size of a bundle. + BundleByteThreshold int + + // The maximum size of a bundle, in bytes. Zero means unlimited. + BundleByteLimit int + + // The maximum number of bytes that the Bundler will keep in memory before + // returning ErrOverflow. The default is DefaultBufferedByteLimit. + BufferedByteLimit int + + handler func(interface{}) // called to handle a bundle + itemSliceZero reflect.Value // nil (zero value) for slice of items + flushTimer *time.Timer // implements DelayThreshold + + mu sync.Mutex + sem *semaphore.Weighted // enforces BufferedByteLimit + semOnce sync.Once + curBundle bundle // incoming items added to this bundle + handlingc <-chan struct{} // set to non-nil while a handler is running; closed when it returns +} + +type bundle struct { + items reflect.Value // slice of item type + size int // size in bytes of all items +} + +// NewBundler creates a new Bundler. +// +// itemExample is a value of the type that will be bundled. For example, if you +// want to create bundles of *Entry, you could pass &Entry{} for itemExample. +// +// handler is a function that will be called on each bundle. If itemExample is +// of type T, the argument to handler is of type []T. handler is always called +// sequentially for each bundle, and never in parallel. +// +// Configure the Bundler by setting its thresholds and limits before calling +// any of its methods. +func NewBundler(itemExample interface{}, handler func(interface{})) *Bundler { + b := &Bundler{ + DelayThreshold: DefaultDelayThreshold, + BundleCountThreshold: DefaultBundleCountThreshold, + BundleByteThreshold: DefaultBundleByteThreshold, + BufferedByteLimit: DefaultBufferedByteLimit, + + handler: handler, + itemSliceZero: reflect.Zero(reflect.SliceOf(reflect.TypeOf(itemExample))), + } + b.curBundle.items = b.itemSliceZero + return b +} + +func (b *Bundler) sema() *semaphore.Weighted { + // Create the semaphore lazily, because the user may set BufferedByteLimit + // after NewBundler. + b.semOnce.Do(func() { + b.sem = semaphore.NewWeighted(int64(b.BufferedByteLimit)) + }) + return b.sem +} + +// Add adds item to the current bundle. It marks the bundle for handling and +// starts a new one if any of the thresholds or limits are exceeded. +// +// If the item's size exceeds the maximum bundle size (Bundler.BundleByteLimit), then +// the item can never be handled. Add returns ErrOversizedItem in this case. +// +// If adding the item would exceed the maximum memory allowed +// (Bundler.BufferedByteLimit) or an AddWait call is blocked waiting for +// memory, Add returns ErrOverflow. +// +// Add never blocks. +func (b *Bundler) Add(item interface{}, size int) error { + // If this item exceeds the maximum size of a bundle, + // we can never send it. + if b.BundleByteLimit > 0 && size > b.BundleByteLimit { + return ErrOversizedItem + } + // If adding this item would exceed our allotted memory + // footprint, we can't accept it. + // (TryAcquire also returns false if anything is waiting on the semaphore, + // so calls to Add and AddWait shouldn't be mixed.) + if !b.sema().TryAcquire(int64(size)) { + return ErrOverflow + } + b.add(item, size) + return nil +} + +// add adds item to the current bundle. It marks the bundle for handling and +// starts a new one if any of the thresholds or limits are exceeded. +func (b *Bundler) add(item interface{}, size int) { + b.mu.Lock() + defer b.mu.Unlock() + + // If adding this item to the current bundle would cause it to exceed the + // maximum bundle size, close the current bundle and start a new one. + if b.BundleByteLimit > 0 && b.curBundle.size+size > b.BundleByteLimit { + b.startFlushLocked() + } + // Add the item. + b.curBundle.items = reflect.Append(b.curBundle.items, reflect.ValueOf(item)) + b.curBundle.size += size + + // Start a timer to flush the item if one isn't already running. + // startFlushLocked clears the timer and closes the bundle at the same time, + // so we only allocate a new timer for the first item in each bundle. + // (We could try to call Reset on the timer instead, but that would add a lot + // of complexity to the code just to save one small allocation.) + if b.flushTimer == nil { + b.flushTimer = time.AfterFunc(b.DelayThreshold, b.Flush) + } + + // If the current bundle equals the count threshold, close it. + if b.curBundle.items.Len() == b.BundleCountThreshold { + b.startFlushLocked() + } + // If the current bundle equals or exceeds the byte threshold, close it. + if b.curBundle.size >= b.BundleByteThreshold { + b.startFlushLocked() + } +} + +// AddWait adds item to the current bundle. It marks the bundle for handling and +// starts a new one if any of the thresholds or limits are exceeded. +// +// If the item's size exceeds the maximum bundle size (Bundler.BundleByteLimit), then +// the item can never be handled. AddWait returns ErrOversizedItem in this case. +// +// If adding the item would exceed the maximum memory allowed (Bundler.BufferedByteLimit), +// AddWait blocks until space is available or ctx is done. +// +// Calls to Add and AddWait should not be mixed on the same Bundler. +func (b *Bundler) AddWait(ctx context.Context, item interface{}, size int) error { + // If this item exceeds the maximum size of a bundle, + // we can never send it. + if b.BundleByteLimit > 0 && size > b.BundleByteLimit { + return ErrOversizedItem + } + // If adding this item would exceed our allotted memory footprint, block + // until space is available. The semaphore is FIFO, so there will be no + // starvation. + if err := b.sema().Acquire(ctx, int64(size)); err != nil { + return err + } + // Here, we've reserved space for item. Other goroutines can call AddWait + // and even acquire space, but no one can take away our reservation + // (assuming sem.Release is used correctly). So there is no race condition + // resulting from locking the mutex after sem.Acquire returns. + b.add(item, size) + return nil +} + +// Flush invokes the handler for all remaining items in the Bundler and waits +// for it to return. +func (b *Bundler) Flush() { + b.mu.Lock() + b.startFlushLocked() + done := b.handlingc + b.mu.Unlock() + + if done != nil { + <-done + } +} + +func (b *Bundler) startFlushLocked() { + if b.flushTimer != nil { + b.flushTimer.Stop() + b.flushTimer = nil + } + + if b.curBundle.items.Len() == 0 { + return + } + bun := b.curBundle + b.curBundle = bundle{items: b.itemSliceZero} + + done := make(chan struct{}) + var running <-chan struct{} + running, b.handlingc = b.handlingc, done + + go func() { + defer func() { + b.sem.Release(int64(bun.size)) + close(done) + }() + + if running != nil { + // Wait for our turn to call the handler. + <-running + } + + b.handler(bun.items.Interface()) + }() +}