add logworker
This commit is contained in:
parent
b7e4d5c99f
commit
2d46bec18b
17
box.rb
17
box.rb
|
@ -9,25 +9,32 @@ def foldercopy dir
|
|||
end
|
||||
|
||||
def gobuild pkg
|
||||
run "mkdir -p /root/go/bin && cd /root/go/bin && go#{$gover} build #{$repo}/#{pkg} && go#{$gover} install #{$repo}/#{pkg}"
|
||||
run "mkdir -p /root/go/bin && cd /root/go/bin && go#{$gover} build -v #{$repo}/#{pkg}"
|
||||
end
|
||||
|
||||
[
|
||||
"bot",
|
||||
"cmd",
|
||||
"internal",
|
||||
"vendor",
|
||||
"vendor-log",
|
||||
].each { |x| foldercopy x }
|
||||
|
||||
[
|
||||
"bot",
|
||||
"cmd",
|
||||
"internal",
|
||||
].each { |x| foldercopy x }
|
||||
|
||||
[
|
||||
"cmd/vyvanse",
|
||||
"cmd/logworker",
|
||||
].each { |x| gobuild x }
|
||||
|
||||
cmd "/root/go/bin/vyvanse"
|
||||
|
||||
run "rm -rf $HOME/sdk /root/go/pkg ||:"
|
||||
run "apk del go#{$gover}"
|
||||
|
||||
tag "xena/vyvanse:thick"
|
||||
|
||||
flatten
|
||||
|
||||
tag "xena/vyvanse"
|
||||
tag "xena/vyvanse:latest"
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"git.xeserv.us/xena/gorqlite"
|
||||
"git.xeserv.us/xena/vyvanse/internal/dao"
|
||||
|
||||
"github.com/Xe/ln"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/drone/mq/stomp"
|
||||
"github.com/namsral/flag"
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
zipkin "github.com/openzipkin/zipkin-go-opentracing"
|
||||
)
|
||||
|
||||
var (
|
||||
token = flag.String("token", "", "discord bot token")
|
||||
zipkinURL = flag.String("zipkin-url", "", "URL for Zipkin traces")
|
||||
databaseURL = flag.String("database-url", "http://", "URL for database (rqlite)")
|
||||
mqURL = flag.String("mq-url", "tcp://mq:9000", "URL for STOMP server")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if *zipkinURL != "" {
|
||||
collector, err := zipkin.NewHTTPCollector(*zipkinURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
tracer, err := zipkin.NewTracer(
|
||||
zipkin.NewRecorder(collector, false, "logworker:5000", "logworker"))
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
db, err := gorqlite.Open(*databaseURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(ctx, err)
|
||||
}
|
||||
|
||||
ls := dao.NewLogs(db)
|
||||
err = ls.Migrate(ctx)
|
||||
if err != nil {
|
||||
ln.FatalErr(ctx, err, ln.F{"action": "migrate logs table"})
|
||||
}
|
||||
|
||||
mq, err := stomp.Dial(*mqURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(ctx, err, ln.F{"url": *mqURL})
|
||||
}
|
||||
|
||||
mq.Subscribe("/topic/message_create", stomp.HandlerFunc(func(m *stomp.Message) {
|
||||
sp, ctx := opentracing.StartSpanFromContext(m.Context(), "logworker.topic.message.create")
|
||||
defer sp.Finish()
|
||||
|
||||
msg := &discordgo.Message{}
|
||||
err := json.Unmarshal(m.Msg, msg)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{"action": "can't unmarshal message body to a discordgo message"})
|
||||
return
|
||||
}
|
||||
|
||||
f := ln.F{
|
||||
"stomp_id": string(m.ID),
|
||||
"channel_id": msg.ChannelID,
|
||||
"message_id": msg.ID,
|
||||
"message_author": msg.Author.ID,
|
||||
"message_author_name": msg.Author.Username,
|
||||
"message_author_is_bot": msg.Author.Bot,
|
||||
}
|
||||
|
||||
err = ls.Add(ctx, msg)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.F{"action": "can't add discordgo message to the database"})
|
||||
}
|
||||
}))
|
||||
|
||||
for {
|
||||
select {}
|
||||
}
|
||||
}
|
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"github.com/Xe/ln"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/drone/mq/stomp"
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"github.com/namsral/flag"
|
||||
xkcd "github.com/nishanths/go-xkcd"
|
||||
|
@ -29,11 +30,26 @@ var (
|
|||
token = flag.String("token", "", "discord bot token")
|
||||
zipkinURL = flag.String("zipkin-url", "", "URL for Zipkin traces")
|
||||
databaseURL = flag.String("database-url", "http://", "URL for database (rqlite)")
|
||||
mqURL = flag.String("mq-url", "tcp://mq:9000", "URL for STOMP server")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if *zipkinURL != "" {
|
||||
collector, err := zipkin.NewHTTPCollector(*zipkinURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
tracer, err := zipkin.NewTracer(
|
||||
zipkin.NewRecorder(collector, false, "vyvanse:5000", "vyvanse"))
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
}
|
||||
|
||||
xk := xkcd.NewClient()
|
||||
dg, err := discordgo.New("Bot " + *token)
|
||||
if err != nil {
|
||||
|
@ -62,6 +78,14 @@ func main() {
|
|||
}
|
||||
sp.Finish()
|
||||
|
||||
ctx = context.Background()
|
||||
|
||||
mq, err := stomp.Dial(*mqURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(ctx, err, ln.F{"url": *mqURL})
|
||||
}
|
||||
_ = mq
|
||||
|
||||
c := cron.New()
|
||||
|
||||
comic, err := xk.Latest()
|
||||
|
@ -111,20 +135,6 @@ func main() {
|
|||
|
||||
c.Start()
|
||||
|
||||
if *zipkinURL != "" {
|
||||
collector, err := zipkin.NewHTTPCollector(*zipkinURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
tracer, err := zipkin.NewTracer(
|
||||
zipkin.NewRecorder(collector, false, "vyvanse:5000", "vyvanse"))
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
}
|
||||
|
||||
cs := bot.NewCommandSet()
|
||||
cs.Prefix = ">"
|
||||
|
||||
|
@ -134,6 +144,61 @@ func main() {
|
|||
cs.AddCmd("splattus", "splatoon 2 map rotation status", bot.NoPermissions, spla2nMaps)
|
||||
cs.AddCmd("top10", "shows the top 10 chatters on this server", bot.NoPermissions, top10(us))
|
||||
|
||||
dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
sp, ctx := opentracing.StartSpanFromContext(context.Background(), "message.create.post.stomp")
|
||||
defer sp.Finish()
|
||||
|
||||
f := ln.F{
|
||||
"channel_id": m.ChannelID,
|
||||
"message_id": m.ID,
|
||||
"message_author": m.Author.ID,
|
||||
"message_author_name": m.Author.Username,
|
||||
"message_author_is_bot": m.Author.Bot,
|
||||
}
|
||||
|
||||
err := mq.SendJSON("/topic/message_create", m.Message)
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
mq, err = stomp.Dial(*mqURL)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.F{"url": *mqURL, "action": "reconnect to mq"})
|
||||
return
|
||||
}
|
||||
|
||||
err = mq.SendJSON("/topic/message_create", m.Message)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.F{"action": "retry message_create post to message queue"})
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ln.Error(ctx, err, f, ln.F{"action": "send created message to queue"})
|
||||
return
|
||||
}
|
||||
|
||||
ln.Log(ctx, f, ln.F{"action": "message_create"})
|
||||
})
|
||||
|
||||
dg.AddHandler(func(s *discordgo.Session, m *discordgo.GuildMemberAdd) {
|
||||
sp, ctx := opentracing.StartSpanFromContext(context.Background(), "member.add.post.stomp")
|
||||
defer sp.Finish()
|
||||
|
||||
f := ln.F{
|
||||
"guild_id": m.GuildID,
|
||||
"user_id": m.User.ID,
|
||||
"user_name": m.User.Username,
|
||||
}
|
||||
|
||||
err := mq.SendJSON("/topic/member_add", m.Member)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.F{"action": "send added member to queue"})
|
||||
}
|
||||
|
||||
ln.Log(ctx, f, ln.F{"action": "member_add"})
|
||||
})
|
||||
|
||||
dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
if m.Author.ID == s.State.User.ID {
|
||||
return
|
||||
|
@ -165,8 +230,9 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
fmt.Println("Bot is now running. Press CTRL-C to exit.")
|
||||
// Simple way to keep program running until CTRL-C is pressed.
|
||||
<-make(chan struct{})
|
||||
return
|
||||
ln.Log(ctx, ln.F{"action": "bot is running"})
|
||||
|
||||
for {
|
||||
select {}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,5 +33,15 @@ services:
|
|||
- mq
|
||||
- rqlite
|
||||
|
||||
logworker:
|
||||
restart: always
|
||||
image: xena/vyvanse
|
||||
env_file: ./.env
|
||||
depends_on:
|
||||
- zipkin
|
||||
- mq
|
||||
- rqlite
|
||||
command: /root/go/bin/logworker
|
||||
|
||||
volumes:
|
||||
rqlite:
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.xeserv.us/xena/gorqlite"
|
||||
"github.com/Xe/ln"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
splog "github.com/opentracing/opentracing-go/log"
|
||||
"google.golang.org/api/support/bundler"
|
||||
)
|
||||
|
||||
type Logs struct {
|
||||
conn gorqlite.Connection
|
||||
bdl *bundler.Bundler
|
||||
}
|
||||
|
||||
func NewLogs(conn gorqlite.Connection) *Logs {
|
||||
l := &Logs{conn: conn}
|
||||
l.bdl = bundler.NewBundler("string", l.writeLines)
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Logs) writeLines(sqlLines interface{}) {
|
||||
sp, ctx := opentracing.StartSpanFromContext(context.Background(), "logs.write.lines")
|
||||
defer sp.Finish()
|
||||
|
||||
_, err := l.conn.Write(sqlLines.([]string))
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{"action": "write lines that are batched"})
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logs) Migrate(ctx context.Context) error {
|
||||
sp, ctx := opentracing.StartSpanFromContext(ctx, "logs.migrate")
|
||||
defer sp.Finish()
|
||||
|
||||
migrationDDL := []string{
|
||||
`CREATE TABLE IF NOT EXISTS logs(id INTEGER PRIMARY KEY, discord_id TEXT UNIQUE, channel_id TEXT, content TEXT, timestamp INTEGER, mention_everyone INTEGER, author_id TEXT, author_username TEXT)`,
|
||||
}
|
||||
|
||||
res, err := l.conn.Write(migrationDDL)
|
||||
if err != nil {
|
||||
sp.LogFields(splog.Error(err))
|
||||
}
|
||||
|
||||
for i, re := range res {
|
||||
if re.Err != nil {
|
||||
sp.LogFields(splog.Error(re.Err))
|
||||
return re.Err
|
||||
}
|
||||
|
||||
sp.LogFields(splog.Int("migration.step", i), splog.Float64("timing", re.Timing), splog.Int64("rows.affected", re.RowsAffected))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Logs) Add(ctx context.Context, m *discordgo.Message) error {
|
||||
sp, ctx := opentracing.StartSpanFromContext(ctx, "logs.add")
|
||||
defer sp.Finish()
|
||||
|
||||
stmt := gorqlite.NewPreparedStatement(`INSERT INTO logs (discord_id, channel_id, content, timestamp, mention_everyone, author_id, author_username) VALUES (%s, %s, %s, %d, %d, %s, %s)`)
|
||||
ts, err := m.Timestamp.Parse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var me int
|
||||
if m.MentionEveryone {
|
||||
me = 1
|
||||
}
|
||||
|
||||
bd := stmt.Bind(m.ID, m.ChannelID, m.Content, ts, me, m.Author.ID, m.Author.Username)
|
||||
|
||||
l.bdl.Add(bd, len(bd))
|
||||
|
||||
return nil
|
||||
}
|
|
@ -28,7 +28,6 @@ func (u *Users) Migrate(ctx context.Context) error {
|
|||
res, err := u.conn.Write(migrationDDL)
|
||||
if err != nil {
|
||||
sp.LogFields(splog.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
for i, re := range res {
|
||||
|
|
|
@ -42,3 +42,11 @@ ae77be60afb1dcacde03767a8c37337fad28ac14 github.com/kardianos/osext
|
|||
40a5e952d22c3ef520c6ab7bdb9b1a010ec9a524 git.xeserv.us/xena/gorqlite
|
||||
97311d9f7767e3d6f422ea06661bc2c7a19e8a5d github.com/mattn/go-runewidth
|
||||
be5337e7b39e64e5f91445ce7e721888dbab7387 github.com/olekukonko/tablewriter
|
||||
280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/logger
|
||||
280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/stomp
|
||||
280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/stomp/dialer
|
||||
f5079bd7f6f74e23c4d65efa0f4ce14cbd6a3c0f golang.org/x/net/context
|
||||
f5079bd7f6f74e23c4d65efa0f4ce14cbd6a3c0f golang.org/x/net/websocket
|
||||
66aacef3dd8a676686c7ae3716979581e8b03c47 golang.org/x/net/context
|
||||
f52d1811a62927559de87708c8913c1650ce4f26 golang.org/x/sync/semaphore
|
||||
e0e0e6e500066ff47335c7717e2a090ad127adec google.golang.org/api/support/bundler
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
package logger
|
||||
|
||||
var std Logger = new(none)
|
||||
|
||||
// Debugf writes a debug message to the standard logger.
|
||||
func Debugf(format string, args ...interface{}) {
|
||||
std.Debugf(format, args...)
|
||||
}
|
||||
|
||||
// Verbosef writes a verbose message to the standard logger.
|
||||
func Verbosef(format string, args ...interface{}) {
|
||||
std.Verbosef(format, args...)
|
||||
}
|
||||
|
||||
// Noticef writes a notice message to the standard logger.
|
||||
func Noticef(format string, args ...interface{}) {
|
||||
std.Noticef(format, args...)
|
||||
}
|
||||
|
||||
// Warningf writes a warning message to the standard logger.
|
||||
func Warningf(format string, args ...interface{}) {
|
||||
std.Warningf(format, args...)
|
||||
}
|
||||
|
||||
// Printf writes a default message to the standard logger.
|
||||
func Printf(format string, args ...interface{}) {
|
||||
std.Printf(format, args...)
|
||||
}
|
||||
|
||||
// SetLogger sets the standard logger.
|
||||
func SetLogger(logger Logger) {
|
||||
std = logger
|
||||
}
|
||||
|
||||
// Logger represents a logger.
|
||||
type Logger interface {
|
||||
|
||||
// Debugf writes a debug message.
|
||||
Debugf(string, ...interface{})
|
||||
|
||||
// Verbosef writes a verbose message.
|
||||
Verbosef(string, ...interface{})
|
||||
|
||||
// Noticef writes a notice message.
|
||||
Noticef(string, ...interface{})
|
||||
|
||||
// Warningf writes a warning message.
|
||||
Warningf(string, ...interface{})
|
||||
|
||||
// Printf writes a default message.
|
||||
Printf(string, ...interface{})
|
||||
}
|
||||
|
||||
// none is a logger that silently ignores all writes.
|
||||
type none struct{}
|
||||
|
||||
func (*none) Debugf(string, ...interface{}) {}
|
||||
func (*none) Verbosef(string, ...interface{}) {}
|
||||
func (*none) Noticef(string, ...interface{}) {}
|
||||
func (*none) Warningf(string, ...interface{}) {}
|
||||
func (*none) Printf(string, ...interface{}) {}
|
|
@ -0,0 +1,259 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/drone/mq/logger"
|
||||
"github.com/drone/mq/stomp/dialer"
|
||||
)
|
||||
|
||||
// Client defines a client connection to a STOMP server.
|
||||
type Client struct {
|
||||
mu sync.Mutex
|
||||
|
||||
peer Peer
|
||||
subs map[string]Handler
|
||||
wait map[string]chan struct{}
|
||||
done chan error
|
||||
|
||||
seq int64
|
||||
|
||||
skipVerify bool
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// New returns a new STOMP client using the given connection.
|
||||
func New(peer Peer) *Client {
|
||||
return &Client{
|
||||
peer: peer,
|
||||
subs: make(map[string]Handler),
|
||||
wait: make(map[string]chan struct{}),
|
||||
done: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Dial creates a client connection to the given target.
|
||||
func Dial(target string) (*Client, error) {
|
||||
conn, err := dialer.Dial(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return New(Conn(conn)), nil
|
||||
}
|
||||
|
||||
// Send sends the data to the given destination.
|
||||
func (c *Client) Send(dest string, data []byte, opts ...MessageOption) error {
|
||||
m := NewMessage()
|
||||
m.Method = MethodSend
|
||||
m.Dest = []byte(dest)
|
||||
m.Body = data
|
||||
m.Apply(opts...)
|
||||
return c.sendMessage(m)
|
||||
}
|
||||
|
||||
// SendJSON sends the JSON encoding of v to the given destination.
|
||||
func (c *Client) SendJSON(dest string, v interface{}, opts ...MessageOption) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts = append(opts,
|
||||
WithHeader("content-type", "application/json"),
|
||||
)
|
||||
return c.Send(dest, data, opts...)
|
||||
}
|
||||
|
||||
// Subscribe subscribes to the given destination.
|
||||
func (c *Client) Subscribe(dest string, handler Handler, opts ...MessageOption) (id []byte, err error) {
|
||||
id = c.incr()
|
||||
|
||||
m := NewMessage()
|
||||
m.Method = MethodSubscribe
|
||||
m.ID = id
|
||||
m.Dest = []byte(dest)
|
||||
m.Apply(opts...)
|
||||
|
||||
c.mu.Lock()
|
||||
c.subs[string(id)] = handler
|
||||
c.mu.Unlock()
|
||||
|
||||
err = c.sendMessage(m)
|
||||
if err != nil {
|
||||
c.mu.Lock()
|
||||
delete(c.subs, string(id))
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes to the destination.
|
||||
func (c *Client) Unsubscribe(id []byte, opts ...MessageOption) error {
|
||||
c.mu.Lock()
|
||||
delete(c.subs, string(id))
|
||||
c.mu.Unlock()
|
||||
|
||||
m := NewMessage()
|
||||
m.Method = MethodUnsubscribe
|
||||
m.ID = id
|
||||
m.Apply(opts...)
|
||||
|
||||
return c.sendMessage(m)
|
||||
}
|
||||
|
||||
// Ack acknowledges the messages with the given id.
|
||||
func (c *Client) Ack(id []byte, opts ...MessageOption) error {
|
||||
m := NewMessage()
|
||||
m.Method = MethodAck
|
||||
m.ID = id
|
||||
m.Apply(opts...)
|
||||
|
||||
return c.sendMessage(m)
|
||||
}
|
||||
|
||||
// Nack negative-acknowledges the messages with the given id.
|
||||
func (c *Client) Nack(id []byte, opts ...MessageOption) error {
|
||||
m := NewMessage()
|
||||
m.Method = MethodNack
|
||||
m.ID = id
|
||||
m.Apply(opts...)
|
||||
|
||||
return c.peer.Send(m)
|
||||
}
|
||||
|
||||
// Connect opens the connection and establishes the session.
|
||||
func (c *Client) Connect(opts ...MessageOption) error {
|
||||
m := NewMessage()
|
||||
m.Proto = STOMP
|
||||
m.Method = MethodStomp
|
||||
m.Apply(opts...)
|
||||
if err := c.sendMessage(m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, ok := <-c.peer.Receive()
|
||||
if !ok {
|
||||
return io.EOF
|
||||
}
|
||||
defer m.Release()
|
||||
|
||||
if !bytes.Equal(m.Method, MethodConnected) {
|
||||
return fmt.Errorf("stomp: inbound message: unexpected method, want connected")
|
||||
}
|
||||
go c.listen()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect terminates the session and closes the connection.
|
||||
func (c *Client) Disconnect() error {
|
||||
m := NewMessage()
|
||||
m.Method = MethodDisconnect
|
||||
c.sendMessage(m)
|
||||
return c.peer.Close()
|
||||
}
|
||||
|
||||
// Done returns a channel
|
||||
func (c *Client) Done() <-chan error {
|
||||
return c.done
|
||||
}
|
||||
|
||||
func (c *Client) incr() []byte {
|
||||
c.mu.Lock()
|
||||
i := c.seq
|
||||
c.seq++
|
||||
c.mu.Unlock()
|
||||
return strconv.AppendInt(nil, i, 10)
|
||||
}
|
||||
|
||||
func (c *Client) listen() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Warningf("stomp client: recover panic: %s", r)
|
||||
err, ok := r.(error)
|
||||
if !ok {
|
||||
logger.Warningf("%v: %s", r, debug.Stack())
|
||||
c.done <- fmt.Errorf("%v", r)
|
||||
} else {
|
||||
logger.Warningf("%s", err)
|
||||
c.done <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
m, ok := <-c.peer.Receive()
|
||||
if !ok {
|
||||
c.done <- io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case bytes.Equal(m.Method, MethodMessage):
|
||||
c.handleMessage(m)
|
||||
case bytes.Equal(m.Method, MethodRecipet):
|
||||
c.handleReceipt(m)
|
||||
default:
|
||||
logger.Noticef("stomp client: unknown message type: %s",
|
||||
string(m.Method),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleReceipt(m *Message) {
|
||||
c.mu.Lock()
|
||||
receiptc, ok := c.wait[string(m.Receipt)]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
logger.Noticef("stomp client: unknown read receipt: %s",
|
||||
string(m.Receipt),
|
||||
)
|
||||
return
|
||||
}
|
||||
receiptc <- struct{}{}
|
||||
}
|
||||
|
||||
func (c *Client) handleMessage(m *Message) {
|
||||
c.mu.Lock()
|
||||
handler, ok := c.subs[string(m.Subs)]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
logger.Noticef("stomp client: subscription not found: %s",
|
||||
string(m.Subs),
|
||||
)
|
||||
return
|
||||
}
|
||||
handler.Handle(m)
|
||||
}
|
||||
|
||||
func (c *Client) sendMessage(m *Message) error {
|
||||
if len(m.Receipt) == 0 {
|
||||
return c.peer.Send(m)
|
||||
}
|
||||
|
||||
receiptc := make(chan struct{}, 1)
|
||||
c.wait[string(m.Receipt)] = receiptc
|
||||
|
||||
defer func() {
|
||||
delete(c.wait, string(m.Receipt))
|
||||
}()
|
||||
|
||||
err := c.peer.Send(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiptc:
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/drone/mq/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 32 << 10 // default buffer size 32KB
|
||||
bufferLimit = 32 << 15 // default buffer limit 1MB
|
||||
)
|
||||
|
||||
var (
|
||||
never time.Time
|
||||
deadline = time.Second * 5
|
||||
|
||||
heartbeatTime = time.Second * 30
|
||||
heartbeatWait = time.Second * 60
|
||||
)
|
||||
|
||||
type connPeer struct {
|
||||
conn net.Conn
|
||||
done chan bool
|
||||
|
||||
reader *bufio.Reader
|
||||
writer *bufio.Writer
|
||||
incoming chan *Message
|
||||
outgoing chan *Message
|
||||
}
|
||||
|
||||
// Conn creates a network-connected peer that reads and writes
|
||||
// messages using net.Conn c.
|
||||
func Conn(c net.Conn) Peer {
|
||||
p := &connPeer{
|
||||
reader: bufio.NewReaderSize(c, bufferSize),
|
||||
writer: bufio.NewWriterSize(c, bufferSize),
|
||||
incoming: make(chan *Message),
|
||||
outgoing: make(chan *Message),
|
||||
done: make(chan bool),
|
||||
conn: c,
|
||||
}
|
||||
|
||||
go p.readInto(p.incoming)
|
||||
go p.writeFrom(p.outgoing)
|
||||
return p
|
||||
}
|
||||
|
||||
func (c *connPeer) Receive() <-chan *Message {
|
||||
return c.incoming
|
||||
}
|
||||
|
||||
func (c *connPeer) Send(message *Message) error {
|
||||
select {
|
||||
case <-c.done:
|
||||
return io.EOF
|
||||
default:
|
||||
c.outgoing <- message
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPeer) Addr() string {
|
||||
return c.conn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
func (c *connPeer) Close() error {
|
||||
return c.close()
|
||||
}
|
||||
|
||||
func (c *connPeer) close() error {
|
||||
select {
|
||||
case <-c.done:
|
||||
return io.EOF
|
||||
default:
|
||||
close(c.done)
|
||||
close(c.incoming)
|
||||
close(c.outgoing)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPeer) readInto(messages chan<- *Message) {
|
||||
defer c.close()
|
||||
|
||||
for {
|
||||
// lim := io.LimitReader(c.conn, bufferLimit)
|
||||
// buf := bufio.NewReaderSize(lim, bufferSize)
|
||||
|
||||
buf, err := c.reader.ReadBytes(0)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if len(buf) == 1 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(heartbeatWait))
|
||||
logger.Verbosef("stomp: received heart-beat")
|
||||
continue
|
||||
}
|
||||
|
||||
msg := NewMessage()
|
||||
msg.Parse(buf[:len(buf)-1])
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
break
|
||||
default:
|
||||
messages <- msg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPeer) writeFrom(messages <-chan *Message) {
|
||||
tick := time.NewTicker(time.Millisecond * 100).C
|
||||
heartbeat := time.NewTicker(heartbeatTime).C
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
break loop
|
||||
case <-heartbeat:
|
||||
logger.Verbosef("stomp: send heart-beat.")
|
||||
c.writer.WriteByte(0)
|
||||
case <-tick:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(deadline))
|
||||
if err := c.writer.Flush(); err != nil {
|
||||
break loop
|
||||
}
|
||||
c.conn.SetWriteDeadline(never)
|
||||
case msg, ok := <-messages:
|
||||
if !ok {
|
||||
break loop
|
||||
}
|
||||
writeTo(c.writer, msg)
|
||||
c.writer.WriteByte(0)
|
||||
msg.Release()
|
||||
}
|
||||
}
|
||||
|
||||
c.drain()
|
||||
}
|
||||
|
||||
func (c *connPeer) drain() error {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(deadline))
|
||||
for msg := range c.outgoing {
|
||||
writeTo(c.writer, msg)
|
||||
c.writer.WriteByte(0)
|
||||
msg.Release()
|
||||
}
|
||||
c.conn.SetWriteDeadline(never)
|
||||
c.writer.Flush()
|
||||
return c.conn.Close()
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package stomp
|
||||
|
||||
// STOMP protocol version.
|
||||
var STOMP = []byte("1.2")
|
||||
|
||||
// STOMP protocol methods.
|
||||
var (
|
||||
MethodStomp = []byte("STOMP")
|
||||
MethodConnect = []byte("CONNECT")
|
||||
MethodConnected = []byte("CONNECTED")
|
||||
MethodSend = []byte("SEND")
|
||||
MethodSubscribe = []byte("SUBSCRIBE")
|
||||
MethodUnsubscribe = []byte("UNSUBSCRIBE")
|
||||
MethodAck = []byte("ACK")
|
||||
MethodNack = []byte("NACK")
|
||||
MethodDisconnect = []byte("DISCONNECT")
|
||||
MethodMessage = []byte("MESSAGE")
|
||||
MethodRecipet = []byte("RECEIPT")
|
||||
MethodError = []byte("ERROR")
|
||||
)
|
||||
|
||||
// STOMP protocol headers.
|
||||
var (
|
||||
HeaderAccept = []byte("accept-version")
|
||||
HeaderAck = []byte("ack")
|
||||
HeaderExpires = []byte("expires")
|
||||
HeaderDest = []byte("destination")
|
||||
HeaderHost = []byte("host")
|
||||
HeaderLogin = []byte("login")
|
||||
HeaderPass = []byte("passcode")
|
||||
HeaderID = []byte("id")
|
||||
HeaderMessageID = []byte("message-id")
|
||||
HeaderPersist = []byte("persist")
|
||||
HeaderPrefetch = []byte("prefetch-count")
|
||||
HeaderReceipt = []byte("receipt")
|
||||
HeaderReceiptID = []byte("receipt-id")
|
||||
HeaderRetain = []byte("retain")
|
||||
HeaderSelector = []byte("selector")
|
||||
HeaderServer = []byte("server")
|
||||
HeaderSession = []byte("session")
|
||||
HeaderSubscription = []byte("subscription")
|
||||
HeaderVersion = []byte("version")
|
||||
)
|
||||
|
||||
// Common STOMP header values.
|
||||
var (
|
||||
AckAuto = []byte("auto")
|
||||
AckClient = []byte("client")
|
||||
PersistTrue = []byte("true")
|
||||
RetainTrue = []byte("true")
|
||||
RetainLast = []byte("last")
|
||||
RetainAll = []byte("all")
|
||||
RetainRemove = []byte("remove")
|
||||
)
|
||||
|
||||
var headerLookup = map[string]struct{}{
|
||||
"accept-version": struct{}{},
|
||||
"ack": struct{}{},
|
||||
"expires": struct{}{},
|
||||
"destination": struct{}{},
|
||||
"host": struct{}{},
|
||||
"login": struct{}{},
|
||||
"passcode": struct{}{},
|
||||
"id": struct{}{},
|
||||
"message-id": struct{}{},
|
||||
"persist": struct{}{},
|
||||
"prefetch-count": struct{}{},
|
||||
"receipt": struct{}{},
|
||||
"receipt-id": struct{}{},
|
||||
"retain": struct{}{},
|
||||
"selector": struct{}{},
|
||||
"server": struct{}{},
|
||||
"session": struct{}{},
|
||||
"subscription": struct{}{},
|
||||
"version": struct{}{},
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package stomp
|
||||
|
||||
import "golang.org/x/net/context"
|
||||
|
||||
const clientKey = "stomp.client"
|
||||
|
||||
// NewContext adds the client to the context.
|
||||
func (c *Client) NewContext(ctx context.Context, client *Client) context.Context {
|
||||
// HACK for use with gin and echo
|
||||
if s, ok := ctx.(setter); ok {
|
||||
s.Set(clientKey, clientKey)
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, clientKey, client)
|
||||
}
|
||||
|
||||
// FromContext retrieves the client from context
|
||||
func FromContext(ctx context.Context) (*Client, bool) {
|
||||
client, ok := ctx.Value(clientKey).(*Client)
|
||||
return client, ok
|
||||
}
|
||||
|
||||
// MustFromContext retrieves the client from context. Panics if not found
|
||||
func MustFromContext(ctx context.Context) *Client {
|
||||
client, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
panic("stomp.Client not found in context")
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// HACK setter defines a context that enables setting values. This is a
|
||||
// temporary workaround for use with gin and echo and will eventually
|
||||
// be removed. DO NOT depend on this.
|
||||
type setter interface {
|
||||
Set(string, interface{})
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package dialer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
protoHTTP = "http"
|
||||
protoHTTPS = "https"
|
||||
protoWS = "ws"
|
||||
protoWSS = "wss"
|
||||
protoTCP = "tcp"
|
||||
)
|
||||
|
||||
// Dial creates a client connection to the given target.
|
||||
func Dial(target string) (net.Conn, error) {
|
||||
u, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case protoHTTP, protoHTTPS, protoWS, protoWSS:
|
||||
return dialWebsocket(u)
|
||||
case protoTCP:
|
||||
return dialSocket(u)
|
||||
default:
|
||||
panic("stomp: invalid protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func dialWebsocket(target *url.URL) (net.Conn, error) {
|
||||
origin, err := target.Parse("/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch origin.Scheme {
|
||||
case protoWS:
|
||||
origin.Scheme = protoHTTP
|
||||
case protoWSS:
|
||||
origin.Scheme = protoHTTPS
|
||||
}
|
||||
return websocket.Dial(target.String(), "", origin.String())
|
||||
}
|
||||
|
||||
func dialSocket(target *url.URL) (net.Conn, error) {
|
||||
return net.Dial(protoTCP, target.Host)
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package stomp
|
||||
|
||||
// Handler handles a STOMP message.
|
||||
type Handler interface {
|
||||
Handle(*Message)
|
||||
}
|
||||
|
||||
// The HandlerFunc type is an adapter to allow the use of an ordinary
|
||||
// function as a STOMP message handler.
|
||||
type HandlerFunc func(*Message)
|
||||
|
||||
// Handle calls f(m).
|
||||
func (f HandlerFunc) Handle(m *Message) { f(m) }
|
|
@ -0,0 +1,109 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const defaultHeaderLen = 5
|
||||
|
||||
type item struct {
|
||||
name []byte
|
||||
data []byte
|
||||
}
|
||||
|
||||
// Header represents the header section of the STOMP message.
|
||||
type Header struct {
|
||||
items []item
|
||||
itemc int
|
||||
}
|
||||
|
||||
func newHeader() *Header {
|
||||
return &Header{
|
||||
items: make([]item, defaultHeaderLen),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the named header value.
|
||||
func (h *Header) Get(name []byte) (b []byte) {
|
||||
for i := 0; i < h.itemc; i++ {
|
||||
if v := h.items[i]; bytes.Equal(v.name, name) {
|
||||
return v.data
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetString returns the named header value.
|
||||
func (h *Header) GetString(name string) string {
|
||||
k := []byte(name)
|
||||
v := h.Get(k)
|
||||
return string(v)
|
||||
}
|
||||
|
||||
// GetBool returns the named header value.
|
||||
func (h *Header) GetBool(name string) bool {
|
||||
s := h.GetString(name)
|
||||
b, _ := strconv.ParseBool(s)
|
||||
return b
|
||||
}
|
||||
|
||||
// GetInt returns the named header value.
|
||||
func (h *Header) GetInt(name string) int {
|
||||
s := h.GetString(name)
|
||||
i, _ := strconv.Atoi(s)
|
||||
return i
|
||||
}
|
||||
|
||||
// GetInt64 returns the named header value.
|
||||
func (h *Header) GetInt64(name string) int64 {
|
||||
s := h.GetString(name)
|
||||
i, _ := strconv.ParseInt(s, 10, 64)
|
||||
return i
|
||||
}
|
||||
|
||||
// Field returns the named header value in string format. This is used to
|
||||
// provide compatibility with the SQL expression evaluation package.
|
||||
func (h *Header) Field(name []byte) []byte {
|
||||
return h.Get(name)
|
||||
}
|
||||
|
||||
// Add appens the key value pair to the header.
|
||||
func (h *Header) Add(name, data []byte) {
|
||||
h.grow()
|
||||
h.items[h.itemc].name = name
|
||||
h.items[h.itemc].data = data
|
||||
h.itemc++
|
||||
}
|
||||
|
||||
// Index returns the keypair at index i.
|
||||
func (h *Header) Index(i int) (k, v []byte) {
|
||||
if i > h.itemc {
|
||||
return
|
||||
}
|
||||
k = h.items[i].name
|
||||
v = h.items[i].data
|
||||
return
|
||||
}
|
||||
|
||||
// Len returns the header length.
|
||||
func (h *Header) Len() int {
|
||||
return h.itemc
|
||||
}
|
||||
|
||||
func (h *Header) grow() {
|
||||
if h.itemc > defaultHeaderLen-1 {
|
||||
h.items = append(h.items, item{})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Header) reset() {
|
||||
h.itemc = 0
|
||||
h.items = h.items[:defaultHeaderLen]
|
||||
for i := range h.items {
|
||||
h.items[i].name = zeroBytes
|
||||
h.items[i].data = zeroBytes
|
||||
}
|
||||
}
|
||||
|
||||
var zeroBytes []byte
|
|
@ -0,0 +1,146 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Message represents a parsed STOMP message.
|
||||
type Message struct {
|
||||
ID []byte // id header
|
||||
Proto []byte // stomp version
|
||||
Method []byte // stomp method
|
||||
User []byte // username header
|
||||
Pass []byte // password header
|
||||
Dest []byte // destination header
|
||||
Subs []byte // subscription id
|
||||
Ack []byte // ack id
|
||||
Msg []byte // message-id header
|
||||
Persist []byte // persist header
|
||||
Retain []byte // retain header
|
||||
Prefetch []byte // prefetch count
|
||||
Expires []byte // expires header
|
||||
Receipt []byte // receipt header
|
||||
Selector []byte // selector header
|
||||
Body []byte
|
||||
Header *Header // custom headers
|
||||
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Copy returns a copy of the Message.
|
||||
func (m *Message) Copy() *Message {
|
||||
c := NewMessage()
|
||||
c.ID = m.ID
|
||||
c.Proto = m.Proto
|
||||
c.Method = m.Method
|
||||
c.User = m.User
|
||||
c.Pass = m.Pass
|
||||
c.Dest = m.Dest
|
||||
c.Subs = m.Subs
|
||||
c.Ack = m.Ack
|
||||
c.Prefetch = m.Prefetch
|
||||
c.Selector = m.Selector
|
||||
c.Persist = m.Persist
|
||||
c.Retain = m.Retain
|
||||
c.Receipt = m.Receipt
|
||||
c.Expires = m.Expires
|
||||
c.Body = m.Body
|
||||
c.ctx = m.ctx
|
||||
c.Header.itemc = m.Header.itemc
|
||||
copy(c.Header.items, m.Header.items)
|
||||
return c
|
||||
}
|
||||
|
||||
// Apply applies the options to the message.
|
||||
func (m *Message) Apply(opts ...MessageOption) {
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse parses the raw bytes into the message.
|
||||
func (m *Message) Parse(b []byte) error {
|
||||
return read(b, m)
|
||||
}
|
||||
|
||||
// Bytes returns the Message in raw byte format.
|
||||
func (m *Message) Bytes() []byte {
|
||||
var buf bytes.Buffer
|
||||
writeTo(&buf, m)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// String returns the Message in string format.
|
||||
func (m *Message) String() string {
|
||||
return string(m.Bytes())
|
||||
}
|
||||
|
||||
// Release releases the message back to the message pool.
|
||||
func (m *Message) Release() {
|
||||
m.Reset()
|
||||
pool.Put(m)
|
||||
}
|
||||
|
||||
// Reset resets the meesage fields to their zero values.
|
||||
func (m *Message) Reset() {
|
||||
m.ID = m.ID[:0]
|
||||
m.Proto = m.Proto[:0]
|
||||
m.Method = m.Method[:0]
|
||||
m.User = m.User[:0]
|
||||
m.Pass = m.Pass[:0]
|
||||
m.Dest = m.Dest[:0]
|
||||
m.Subs = m.Subs[:0]
|
||||
m.Ack = m.Ack[:0]
|
||||
m.Prefetch = m.Prefetch[:0]
|
||||
m.Selector = m.Selector[:0]
|
||||
m.Persist = m.Persist[:0]
|
||||
m.Retain = m.Retain[:0]
|
||||
m.Receipt = m.Receipt[:0]
|
||||
m.Expires = m.Expires[:0]
|
||||
m.Body = m.Body[:0]
|
||||
m.ctx = nil
|
||||
m.Header.reset()
|
||||
}
|
||||
|
||||
// Context returns the request's context.
|
||||
func (m *Message) Context() context.Context {
|
||||
if m.ctx != nil {
|
||||
return m.ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// WithContext returns a shallow copy of m with its context changed
|
||||
// to ctx. The provided ctx must be non-nil.
|
||||
func (m *Message) WithContext(ctx context.Context) *Message {
|
||||
c := m.Copy()
|
||||
c.ctx = ctx
|
||||
return c
|
||||
}
|
||||
|
||||
// Unmarshal parses the JSON-encoded body of the message and
|
||||
// stores the result in the value pointed to by v.
|
||||
func (m *Message) Unmarshal(v interface{}) error {
|
||||
return json.Unmarshal(m.Body, v)
|
||||
}
|
||||
|
||||
// NewMessage returns an empty message from the message pool.
|
||||
func NewMessage() *Message {
|
||||
return pool.Get().(*Message)
|
||||
}
|
||||
|
||||
var pool = sync.Pool{New: func() interface{} {
|
||||
return &Message{Header: newHeader()}
|
||||
}}
|
||||
|
||||
// Rand returns a random int64 number as a []byte of
|
||||
// ascii characters.
|
||||
func Rand() []byte {
|
||||
return strconv.AppendInt(nil, rand.Int63(), 10)
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MessageOption configures message options.
|
||||
type MessageOption func(*Message)
|
||||
|
||||
// WithCredentials returns a MessageOption which sets credentials.
|
||||
func WithCredentials(username, password string) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.User = []byte(username)
|
||||
m.Pass = []byte(password)
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeader returns a MessageOption which sets a header.
|
||||
func WithHeader(key, value string) MessageOption {
|
||||
return func(m *Message) {
|
||||
_, ok := headerLookup[strings.ToLower(key)]
|
||||
if !ok {
|
||||
m.Header.Add(
|
||||
[]byte(key),
|
||||
[]byte(value),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeaders returns a MessageOption which sets headers.
|
||||
func WithHeaders(headers map[string]string) MessageOption {
|
||||
return func(m *Message) {
|
||||
for key, value := range headers {
|
||||
_, ok := headerLookup[strings.ToLower(key)]
|
||||
if !ok {
|
||||
m.Header.Add(
|
||||
[]byte(key),
|
||||
[]byte(value),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithExpires returns a MessageOption configured with an expiration.
|
||||
func WithExpires(exp int64) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Expires = strconv.AppendInt(nil, exp, 10)
|
||||
}
|
||||
}
|
||||
|
||||
// WithPrefetch returns a MessageOption configured with a prefetch count.
|
||||
func WithPrefetch(prefetch int) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Prefetch = strconv.AppendInt(nil, int64(prefetch), 10)
|
||||
}
|
||||
}
|
||||
|
||||
// WithReceipt returns a MessageOption configured with a receipt request.
|
||||
func WithReceipt() MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Receipt = strconv.AppendInt(nil, rand.Int63(), 10)
|
||||
}
|
||||
}
|
||||
|
||||
// WithPersistence returns a MessageOption configured to persist.
|
||||
func WithPersistence() MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Persist = PersistTrue
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetain returns a MessageOption configured to retain the message.
|
||||
func WithRetain(retain string) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Retain = []byte(retain)
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelector returns a MessageOption configured to filter messages
|
||||
// using a sql-like evaluation string.
|
||||
func WithSelector(selector string) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Selector = []byte(selector)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAck returns a MessageOption configured with an ack policy.
|
||||
func WithAck(ack string) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Ack = []byte(ack)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Peer defines a peer-to-peer connection.
|
||||
type Peer interface {
|
||||
// Send sends a message.
|
||||
Send(*Message) error
|
||||
|
||||
// Receive returns a channel of inbound messages.
|
||||
Receive() <-chan *Message
|
||||
|
||||
// Close closes the connection.
|
||||
Close() error
|
||||
|
||||
// Addr returns the peer address.
|
||||
Addr() string
|
||||
}
|
||||
|
||||
// Pipe creates a synchronous in-memory pipe, where reads on one end are
|
||||
// matched with writes on the other. This is useful for direct, in-memory
|
||||
// client-server communication.
|
||||
func Pipe() (Peer, Peer) {
|
||||
atob := make(chan *Message, 10)
|
||||
btoa := make(chan *Message, 10)
|
||||
|
||||
a := &localPeer{
|
||||
incoming: btoa,
|
||||
outgoing: atob,
|
||||
finished: make(chan bool),
|
||||
}
|
||||
b := &localPeer{
|
||||
incoming: atob,
|
||||
outgoing: btoa,
|
||||
finished: make(chan bool),
|
||||
}
|
||||
|
||||
return a, b
|
||||
}
|
||||
|
||||
type localPeer struct {
|
||||
finished chan bool
|
||||
outgoing chan<- *Message
|
||||
incoming <-chan *Message
|
||||
}
|
||||
|
||||
func (p *localPeer) Receive() <-chan *Message {
|
||||
return p.incoming
|
||||
}
|
||||
|
||||
func (p *localPeer) Send(m *Message) error {
|
||||
select {
|
||||
case <-p.finished:
|
||||
return io.EOF
|
||||
default:
|
||||
p.outgoing <- m
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *localPeer) Close() error {
|
||||
close(p.finished)
|
||||
close(p.outgoing)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *localPeer) Addr() string {
|
||||
peerAddrOnce.Do(func() {
|
||||
// get the local address list
|
||||
addr, _ := net.InterfaceAddrs()
|
||||
if len(addr) != 0 {
|
||||
// use the last address in the list
|
||||
peerAddr = addr[len(addr)-1].String()
|
||||
}
|
||||
})
|
||||
return peerAddr
|
||||
}
|
||||
|
||||
var peerAddrOnce sync.Once
|
||||
|
||||
// default address displayed for local pipes
|
||||
var peerAddr = "127.0.0.1/8"
|
|
@ -0,0 +1,139 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func read(input []byte, m *Message) (err error) {
|
||||
var (
|
||||
pos int
|
||||
off int
|
||||
tot = len(input)
|
||||
)
|
||||
|
||||
// parse the stomp message
|
||||
for ; ; off++ {
|
||||
if off == tot {
|
||||
return fmt.Errorf("stomp: invalid method")
|
||||
}
|
||||
if input[off] == '\n' {
|
||||
m.Method = input[pos:off]
|
||||
off++
|
||||
pos = off
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// parse the stomp headers
|
||||
for {
|
||||
if off == tot {
|
||||
return fmt.Errorf("stomp: unexpected eof")
|
||||
}
|
||||
if input[off] == '\n' {
|
||||
off++
|
||||
pos = off
|
||||
break
|
||||
}
|
||||
|
||||
var (
|
||||
name []byte
|
||||
value []byte
|
||||
)
|
||||
|
||||
loop:
|
||||
// parse each individual header
|
||||
for ; ; off++ {
|
||||
if off >= tot {
|
||||
return fmt.Errorf("stomp: unexpected eof")
|
||||
}
|
||||
|
||||
switch input[off] {
|
||||
case '\n':
|
||||
value = input[pos:off]
|
||||
off++
|
||||
pos = off
|
||||
break loop
|
||||
case ':':
|
||||
name = input[pos:off]
|
||||
off++
|
||||
pos = off
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case bytes.Equal(name, HeaderAccept):
|
||||
m.Proto = value
|
||||
case bytes.Equal(name, HeaderAck):
|
||||
m.Ack = value
|
||||
case bytes.Equal(name, HeaderDest):
|
||||
m.Dest = value
|
||||
case bytes.Equal(name, HeaderExpires):
|
||||
m.Expires = value
|
||||
case bytes.Equal(name, HeaderLogin):
|
||||
m.User = value
|
||||
case bytes.Equal(name, HeaderPass):
|
||||
m.Pass = value
|
||||
case bytes.Equal(name, HeaderID):
|
||||
m.ID = value
|
||||
case bytes.Equal(name, HeaderMessageID):
|
||||
m.ID = value
|
||||
case bytes.Equal(name, HeaderPersist):
|
||||
m.Persist = value
|
||||
case bytes.Equal(name, HeaderPrefetch):
|
||||
m.Prefetch = value
|
||||
case bytes.Equal(name, HeaderReceipt):
|
||||
m.Receipt = value
|
||||
case bytes.Equal(name, HeaderReceiptID):
|
||||
m.Receipt = value
|
||||
case bytes.Equal(name, HeaderRetain):
|
||||
m.Retain = value
|
||||
case bytes.Equal(name, HeaderSelector):
|
||||
m.Selector = value
|
||||
case bytes.Equal(name, HeaderSubscription):
|
||||
m.Subs = value
|
||||
case bytes.Equal(name, HeaderVersion):
|
||||
m.Proto = value
|
||||
default:
|
||||
m.Header.Add(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
if tot > pos {
|
||||
m.Body = input[pos:]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
asciiZero = 48
|
||||
asciiNine = 57
|
||||
)
|
||||
|
||||
// ParseInt returns the ascii integer value.
|
||||
func ParseInt(d []byte) (n int) {
|
||||
if len(d) == 0 {
|
||||
return 0
|
||||
}
|
||||
for _, dec := range d {
|
||||
if dec < asciiZero || dec > asciiNine {
|
||||
return 0
|
||||
}
|
||||
n = n*10 + (int(dec) - asciiZero)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// ParseInt64 returns the ascii integer value.
|
||||
func ParseInt64(d []byte) (n int64) {
|
||||
if len(d) == 0 {
|
||||
return 0
|
||||
}
|
||||
for _, dec := range d {
|
||||
if dec < asciiZero || dec > asciiNine {
|
||||
return 0
|
||||
}
|
||||
n = n*10 + (int64(dec) - asciiZero)
|
||||
}
|
||||
return n
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
crlf = []byte{'\r', '\n'}
|
||||
newline = []byte{'\n'}
|
||||
separator = []byte{':'}
|
||||
terminator = []byte{0}
|
||||
)
|
||||
|
||||
func writeTo(w io.Writer, m *Message) {
|
||||
w.Write(m.Method)
|
||||
w.Write(newline)
|
||||
|
||||
switch {
|
||||
case bytes.Equal(m.Method, MethodStomp):
|
||||
// version
|
||||
w.Write(HeaderAccept)
|
||||
w.Write(separator)
|
||||
w.Write(m.Proto)
|
||||
w.Write(newline)
|
||||
// login
|
||||
if len(m.User) != 0 {
|
||||
w.Write(HeaderLogin)
|
||||
w.Write(separator)
|
||||
w.Write(m.User)
|
||||
w.Write(newline)
|
||||
}
|
||||
// passcode
|
||||
if len(m.Pass) != 0 {
|
||||
w.Write(HeaderPass)
|
||||
w.Write(separator)
|
||||
w.Write(m.Pass)
|
||||
w.Write(newline)
|
||||
}
|
||||
case bytes.Equal(m.Method, MethodConnected):
|
||||
// version
|
||||
w.Write(HeaderVersion)
|
||||
w.Write(separator)
|
||||
w.Write(m.Proto)
|
||||
w.Write(newline)
|
||||
case bytes.Equal(m.Method, MethodSend):
|
||||
// dest
|
||||
w.Write(HeaderDest)
|
||||
w.Write(separator)
|
||||
w.Write(m.Dest)
|
||||
w.Write(newline)
|
||||
if len(m.Expires) != 0 {
|
||||
w.Write(HeaderExpires)
|
||||
w.Write(separator)
|
||||
w.Write(m.Expires)
|
||||
w.Write(newline)
|
||||
}
|
||||
if len(m.Retain) != 0 {
|
||||
w.Write(HeaderRetain)
|
||||
w.Write(separator)
|
||||
w.Write(m.Retain)
|
||||
w.Write(newline)
|
||||
}
|
||||
if len(m.Persist) != 0 {
|
||||
w.Write(HeaderPersist)
|
||||
w.Write(separator)
|
||||
w.Write(m.Persist)
|
||||
w.Write(newline)
|
||||
}
|
||||
case bytes.Equal(m.Method, MethodSubscribe):
|
||||
// id
|
||||
w.Write(HeaderID)
|
||||
w.Write(separator)
|
||||
w.Write(m.ID)
|
||||
w.Write(newline)
|
||||
// destination
|
||||
w.Write(HeaderDest)
|
||||
w.Write(separator)
|
||||
w.Write(m.Dest)
|
||||
w.Write(newline)
|
||||
// selector
|
||||
if len(m.Selector) != 0 {
|
||||
w.Write(HeaderSelector)
|
||||
w.Write(separator)
|
||||
w.Write(m.Selector)
|
||||
w.Write(newline)
|
||||
}
|
||||
// prefetch
|
||||
if len(m.Prefetch) != 0 {
|
||||
w.Write(HeaderPrefetch)
|
||||
w.Write(separator)
|
||||
w.Write(m.Prefetch)
|
||||
w.Write(newline)
|
||||
}
|
||||
if len(m.Ack) != 0 {
|
||||
w.Write(HeaderAck)
|
||||
w.Write(separator)
|
||||
w.Write(m.Ack)
|
||||
w.Write(newline)
|
||||
}
|
||||
case bytes.Equal(m.Method, MethodUnsubscribe):
|
||||
// id
|
||||
w.Write(HeaderID)
|
||||
w.Write(separator)
|
||||
w.Write(m.ID)
|
||||
w.Write(newline)
|
||||
case bytes.Equal(m.Method, MethodAck):
|
||||
// id
|
||||
w.Write(HeaderID)
|
||||
w.Write(separator)
|
||||
w.Write(m.ID)
|
||||
w.Write(newline)
|
||||
case bytes.Equal(m.Method, MethodNack):
|
||||
// id
|
||||
w.Write(HeaderID)
|
||||
w.Write(separator)
|
||||
w.Write(m.ID)
|
||||
w.Write(newline)
|
||||
case bytes.Equal(m.Method, MethodMessage):
|
||||
// message-id
|
||||
w.Write(HeaderMessageID)
|
||||
w.Write(separator)
|
||||
w.Write(m.ID)
|
||||
w.Write(newline)
|
||||
// destination
|
||||
w.Write(HeaderDest)
|
||||
w.Write(separator)
|
||||
w.Write(m.Dest)
|
||||
w.Write(newline)
|
||||
// subscription
|
||||
w.Write(HeaderSubscription)
|
||||
w.Write(separator)
|
||||
w.Write(m.Subs)
|
||||
w.Write(newline)
|
||||
// ack
|
||||
if len(m.Ack) != 0 {
|
||||
w.Write(HeaderAck)
|
||||
w.Write(separator)
|
||||
w.Write(m.Ack)
|
||||
w.Write(newline)
|
||||
}
|
||||
case bytes.Equal(m.Method, MethodRecipet):
|
||||
// receipt-id
|
||||
w.Write(HeaderReceiptID)
|
||||
w.Write(separator)
|
||||
w.Write(m.Receipt)
|
||||
w.Write(newline)
|
||||
}
|
||||
|
||||
// receipt header
|
||||
if includeReceiptHeader(m) {
|
||||
w.Write(HeaderReceipt)
|
||||
w.Write(separator)
|
||||
w.Write(m.Receipt)
|
||||
w.Write(newline)
|
||||
}
|
||||
|
||||
for i, item := range m.Header.items {
|
||||
if m.Header.itemc == i {
|
||||
break
|
||||
}
|
||||
w.Write(item.name)
|
||||
w.Write(separator)
|
||||
w.Write(item.data)
|
||||
w.Write(newline)
|
||||
}
|
||||
w.Write(newline)
|
||||
w.Write(m.Body)
|
||||
}
|
||||
|
||||
func includeReceiptHeader(m *Message) bool {
|
||||
return len(m.Receipt) != 0 && !bytes.Equal(m.Method, MethodRecipet)
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// DialError is an error that occurs while dialling a websocket server.
|
||||
type DialError struct {
|
||||
*Config
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *DialError) Error() string {
|
||||
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
|
||||
}
|
||||
|
||||
// NewConfig creates a new WebSocket config for client connection.
|
||||
func NewConfig(server, origin string) (config *Config, err error) {
|
||||
config = new(Config)
|
||||
config.Version = ProtocolVersionHybi13
|
||||
config.Location, err = url.ParseRequestURI(server)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
config.Origin, err = url.ParseRequestURI(origin)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
config.Header = http.Header(make(map[string][]string))
|
||||
return
|
||||
}
|
||||
|
||||
// NewClient creates a new WebSocket client connection over rwc.
|
||||
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
|
||||
br := bufio.NewReader(rwc)
|
||||
bw := bufio.NewWriter(rwc)
|
||||
err = hybiClientHandshake(config, br, bw)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buf := bufio.NewReadWriter(br, bw)
|
||||
ws = newHybiClientConn(config, buf, rwc)
|
||||
return
|
||||
}
|
||||
|
||||
// Dial opens a new client connection to a WebSocket.
|
||||
func Dial(url_, protocol, origin string) (ws *Conn, err error) {
|
||||
config, err := NewConfig(url_, origin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if protocol != "" {
|
||||
config.Protocol = []string{protocol}
|
||||
}
|
||||
return DialConfig(config)
|
||||
}
|
||||
|
||||
var portMap = map[string]string{
|
||||
"ws": "80",
|
||||
"wss": "443",
|
||||
}
|
||||
|
||||
func parseAuthority(location *url.URL) string {
|
||||
if _, ok := portMap[location.Scheme]; ok {
|
||||
if _, _, err := net.SplitHostPort(location.Host); err != nil {
|
||||
return net.JoinHostPort(location.Host, portMap[location.Scheme])
|
||||
}
|
||||
}
|
||||
return location.Host
|
||||
}
|
||||
|
||||
// DialConfig opens a new client connection to a WebSocket with a config.
|
||||
func DialConfig(config *Config) (ws *Conn, err error) {
|
||||
var client net.Conn
|
||||
if config.Location == nil {
|
||||
return nil, &DialError{config, ErrBadWebSocketLocation}
|
||||
}
|
||||
if config.Origin == nil {
|
||||
return nil, &DialError{config, ErrBadWebSocketOrigin}
|
||||
}
|
||||
dialer := config.Dialer
|
||||
if dialer == nil {
|
||||
dialer = &net.Dialer{}
|
||||
}
|
||||
client, err = dialWithDialer(dialer, config)
|
||||
if err != nil {
|
||||
goto Error
|
||||
}
|
||||
ws, err = NewClient(config, client)
|
||||
if err != nil {
|
||||
client.Close()
|
||||
goto Error
|
||||
}
|
||||
return
|
||||
|
||||
Error:
|
||||
return nil, &DialError{config, err}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
|
||||
switch config.Location.Scheme {
|
||||
case "ws":
|
||||
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
|
||||
|
||||
case "wss":
|
||||
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
|
||||
|
||||
default:
|
||||
err = ErrBadScheme
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,583 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
// This file implements a protocol of hybi draft.
|
||||
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
closeStatusNormal = 1000
|
||||
closeStatusGoingAway = 1001
|
||||
closeStatusProtocolError = 1002
|
||||
closeStatusUnsupportedData = 1003
|
||||
closeStatusFrameTooLarge = 1004
|
||||
closeStatusNoStatusRcvd = 1005
|
||||
closeStatusAbnormalClosure = 1006
|
||||
closeStatusBadMessageData = 1007
|
||||
closeStatusPolicyViolation = 1008
|
||||
closeStatusTooBigData = 1009
|
||||
closeStatusExtensionMismatch = 1010
|
||||
|
||||
maxControlFramePayloadLength = 125
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadMaskingKey = &ProtocolError{"bad masking key"}
|
||||
ErrBadPongMessage = &ProtocolError{"bad pong message"}
|
||||
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
|
||||
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
|
||||
ErrNotImplemented = &ProtocolError{"not implemented"}
|
||||
|
||||
handshakeHeader = map[string]bool{
|
||||
"Host": true,
|
||||
"Upgrade": true,
|
||||
"Connection": true,
|
||||
"Sec-Websocket-Key": true,
|
||||
"Sec-Websocket-Origin": true,
|
||||
"Sec-Websocket-Version": true,
|
||||
"Sec-Websocket-Protocol": true,
|
||||
"Sec-Websocket-Accept": true,
|
||||
}
|
||||
)
|
||||
|
||||
// A hybiFrameHeader is a frame header as defined in hybi draft.
|
||||
type hybiFrameHeader struct {
|
||||
Fin bool
|
||||
Rsv [3]bool
|
||||
OpCode byte
|
||||
Length int64
|
||||
MaskingKey []byte
|
||||
|
||||
data *bytes.Buffer
|
||||
}
|
||||
|
||||
// A hybiFrameReader is a reader for hybi frame.
|
||||
type hybiFrameReader struct {
|
||||
reader io.Reader
|
||||
|
||||
header hybiFrameHeader
|
||||
pos int64
|
||||
length int
|
||||
}
|
||||
|
||||
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
|
||||
n, err = frame.reader.Read(msg)
|
||||
if frame.header.MaskingKey != nil {
|
||||
for i := 0; i < n; i++ {
|
||||
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
|
||||
frame.pos++
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
|
||||
|
||||
func (frame *hybiFrameReader) HeaderReader() io.Reader {
|
||||
if frame.header.data == nil {
|
||||
return nil
|
||||
}
|
||||
if frame.header.data.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
return frame.header.data
|
||||
}
|
||||
|
||||
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
|
||||
|
||||
func (frame *hybiFrameReader) Len() (n int) { return frame.length }
|
||||
|
||||
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
|
||||
type hybiFrameReaderFactory struct {
|
||||
*bufio.Reader
|
||||
}
|
||||
|
||||
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
|
||||
// See Section 5.2 Base Framing protocol for detail.
|
||||
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
|
||||
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
|
||||
hybiFrame := new(hybiFrameReader)
|
||||
frame = hybiFrame
|
||||
var header []byte
|
||||
var b byte
|
||||
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
|
||||
b, err = buf.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
header = append(header, b)
|
||||
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
|
||||
for i := 0; i < 3; i++ {
|
||||
j := uint(6 - i)
|
||||
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
|
||||
}
|
||||
hybiFrame.header.OpCode = header[0] & 0x0f
|
||||
|
||||
// Second byte. Mask/Payload len(7bits)
|
||||
b, err = buf.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
header = append(header, b)
|
||||
mask := (b & 0x80) != 0
|
||||
b &= 0x7f
|
||||
lengthFields := 0
|
||||
switch {
|
||||
case b <= 125: // Payload length 7bits.
|
||||
hybiFrame.header.Length = int64(b)
|
||||
case b == 126: // Payload length 7+16bits
|
||||
lengthFields = 2
|
||||
case b == 127: // Payload length 7+64bits
|
||||
lengthFields = 8
|
||||
}
|
||||
for i := 0; i < lengthFields; i++ {
|
||||
b, err = buf.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
|
||||
b &= 0x7f
|
||||
}
|
||||
header = append(header, b)
|
||||
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
|
||||
}
|
||||
if mask {
|
||||
// Masking key. 4 bytes.
|
||||
for i := 0; i < 4; i++ {
|
||||
b, err = buf.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
header = append(header, b)
|
||||
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
|
||||
}
|
||||
}
|
||||
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
|
||||
hybiFrame.header.data = bytes.NewBuffer(header)
|
||||
hybiFrame.length = len(header) + int(hybiFrame.header.Length)
|
||||
return
|
||||
}
|
||||
|
||||
// A HybiFrameWriter is a writer for hybi frame.
|
||||
type hybiFrameWriter struct {
|
||||
writer *bufio.Writer
|
||||
|
||||
header *hybiFrameHeader
|
||||
}
|
||||
|
||||
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
|
||||
var header []byte
|
||||
var b byte
|
||||
if frame.header.Fin {
|
||||
b |= 0x80
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
if frame.header.Rsv[i] {
|
||||
j := uint(6 - i)
|
||||
b |= 1 << j
|
||||
}
|
||||
}
|
||||
b |= frame.header.OpCode
|
||||
header = append(header, b)
|
||||
if frame.header.MaskingKey != nil {
|
||||
b = 0x80
|
||||
} else {
|
||||
b = 0
|
||||
}
|
||||
lengthFields := 0
|
||||
length := len(msg)
|
||||
switch {
|
||||
case length <= 125:
|
||||
b |= byte(length)
|
||||
case length < 65536:
|
||||
b |= 126
|
||||
lengthFields = 2
|
||||
default:
|
||||
b |= 127
|
||||
lengthFields = 8
|
||||
}
|
||||
header = append(header, b)
|
||||
for i := 0; i < lengthFields; i++ {
|
||||
j := uint((lengthFields - i - 1) * 8)
|
||||
b = byte((length >> j) & 0xff)
|
||||
header = append(header, b)
|
||||
}
|
||||
if frame.header.MaskingKey != nil {
|
||||
if len(frame.header.MaskingKey) != 4 {
|
||||
return 0, ErrBadMaskingKey
|
||||
}
|
||||
header = append(header, frame.header.MaskingKey...)
|
||||
frame.writer.Write(header)
|
||||
data := make([]byte, length)
|
||||
for i := range data {
|
||||
data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
|
||||
}
|
||||
frame.writer.Write(data)
|
||||
err = frame.writer.Flush()
|
||||
return length, err
|
||||
}
|
||||
frame.writer.Write(header)
|
||||
frame.writer.Write(msg)
|
||||
err = frame.writer.Flush()
|
||||
return length, err
|
||||
}
|
||||
|
||||
func (frame *hybiFrameWriter) Close() error { return nil }
|
||||
|
||||
type hybiFrameWriterFactory struct {
|
||||
*bufio.Writer
|
||||
needMaskingKey bool
|
||||
}
|
||||
|
||||
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
|
||||
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
|
||||
if buf.needMaskingKey {
|
||||
frameHeader.MaskingKey, err = generateMaskingKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
|
||||
}
|
||||
|
||||
type hybiFrameHandler struct {
|
||||
conn *Conn
|
||||
payloadType byte
|
||||
}
|
||||
|
||||
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
|
||||
if handler.conn.IsServerConn() {
|
||||
// The client MUST mask all frames sent to the server.
|
||||
if frame.(*hybiFrameReader).header.MaskingKey == nil {
|
||||
handler.WriteClose(closeStatusProtocolError)
|
||||
return nil, io.EOF
|
||||
}
|
||||
} else {
|
||||
// The server MUST NOT mask all frames.
|
||||
if frame.(*hybiFrameReader).header.MaskingKey != nil {
|
||||
handler.WriteClose(closeStatusProtocolError)
|
||||
return nil, io.EOF
|
||||
}
|
||||
}
|
||||
if header := frame.HeaderReader(); header != nil {
|
||||
io.Copy(ioutil.Discard, header)
|
||||
}
|
||||
switch frame.PayloadType() {
|
||||
case ContinuationFrame:
|
||||
frame.(*hybiFrameReader).header.OpCode = handler.payloadType
|
||||
case TextFrame, BinaryFrame:
|
||||
handler.payloadType = frame.PayloadType()
|
||||
case CloseFrame:
|
||||
return nil, io.EOF
|
||||
case PingFrame, PongFrame:
|
||||
b := make([]byte, maxControlFramePayloadLength)
|
||||
n, err := io.ReadFull(frame, b)
|
||||
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
|
||||
return nil, err
|
||||
}
|
||||
io.Copy(ioutil.Discard, frame)
|
||||
if frame.PayloadType() == PingFrame {
|
||||
if _, err := handler.WritePong(b[:n]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
|
||||
handler.conn.wio.Lock()
|
||||
defer handler.conn.wio.Unlock()
|
||||
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(msg, uint16(status))
|
||||
_, err = w.Write(msg)
|
||||
w.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
|
||||
handler.conn.wio.Lock()
|
||||
defer handler.conn.wio.Unlock()
|
||||
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err = w.Write(msg)
|
||||
w.Close()
|
||||
return n, err
|
||||
}
|
||||
|
||||
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
|
||||
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
||||
if buf == nil {
|
||||
br := bufio.NewReader(rwc)
|
||||
bw := bufio.NewWriter(rwc)
|
||||
buf = bufio.NewReadWriter(br, bw)
|
||||
}
|
||||
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
|
||||
frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
|
||||
frameWriterFactory: hybiFrameWriterFactory{
|
||||
buf.Writer, request == nil},
|
||||
PayloadType: TextFrame,
|
||||
defaultCloseStatus: closeStatusNormal}
|
||||
ws.frameHandler = &hybiFrameHandler{conn: ws}
|
||||
return ws
|
||||
}
|
||||
|
||||
// generateMaskingKey generates a masking key for a frame.
|
||||
func generateMaskingKey() (maskingKey []byte, err error) {
|
||||
maskingKey = make([]byte, 4)
|
||||
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// generateNonce generates a nonce consisting of a randomly selected 16-byte
|
||||
// value that has been base64-encoded.
|
||||
func generateNonce() (nonce []byte) {
|
||||
key := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
nonce = make([]byte, 24)
|
||||
base64.StdEncoding.Encode(nonce, key)
|
||||
return
|
||||
}
|
||||
|
||||
// removeZone removes IPv6 zone identifer from host.
|
||||
// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080"
|
||||
func removeZone(host string) string {
|
||||
if !strings.HasPrefix(host, "[") {
|
||||
return host
|
||||
}
|
||||
i := strings.LastIndex(host, "]")
|
||||
if i < 0 {
|
||||
return host
|
||||
}
|
||||
j := strings.LastIndex(host[:i], "%")
|
||||
if j < 0 {
|
||||
return host
|
||||
}
|
||||
return host[:j] + host[i:]
|
||||
}
|
||||
|
||||
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
|
||||
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
|
||||
func getNonceAccept(nonce []byte) (expected []byte, err error) {
|
||||
h := sha1.New()
|
||||
if _, err = h.Write(nonce); err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = h.Write([]byte(websocketGUID)); err != nil {
|
||||
return
|
||||
}
|
||||
expected = make([]byte, 28)
|
||||
base64.StdEncoding.Encode(expected, h.Sum(nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
|
||||
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
|
||||
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
|
||||
|
||||
// According to RFC 6874, an HTTP client, proxy, or other
|
||||
// intermediary must remove any IPv6 zone identifier attached
|
||||
// to an outgoing URI.
|
||||
bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
|
||||
bw.WriteString("Upgrade: websocket\r\n")
|
||||
bw.WriteString("Connection: Upgrade\r\n")
|
||||
nonce := generateNonce()
|
||||
if config.handshakeData != nil {
|
||||
nonce = []byte(config.handshakeData["key"])
|
||||
}
|
||||
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
|
||||
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
|
||||
|
||||
if config.Version != ProtocolVersionHybi13 {
|
||||
return ErrBadProtocolVersion
|
||||
}
|
||||
|
||||
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
|
||||
if len(config.Protocol) > 0 {
|
||||
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
|
||||
}
|
||||
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||
err = config.Header.WriteSubset(bw, handshakeHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bw.WriteString("\r\n")
|
||||
if err = bw.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != 101 {
|
||||
return ErrBadStatus
|
||||
}
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
|
||||
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
|
||||
return ErrBadUpgrade
|
||||
}
|
||||
expectedAccept, err := getNonceAccept(nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
|
||||
return ErrChallengeResponse
|
||||
}
|
||||
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
|
||||
return ErrUnsupportedExtensions
|
||||
}
|
||||
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
|
||||
if offeredProtocol != "" {
|
||||
protocolMatched := false
|
||||
for i := 0; i < len(config.Protocol); i++ {
|
||||
if config.Protocol[i] == offeredProtocol {
|
||||
protocolMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !protocolMatched {
|
||||
return ErrBadWebSocketProtocol
|
||||
}
|
||||
config.Protocol = []string{offeredProtocol}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// newHybiClientConn creates a client WebSocket connection after handshake.
|
||||
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
|
||||
return newHybiConn(config, buf, rwc, nil)
|
||||
}
|
||||
|
||||
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
|
||||
type hybiServerHandshaker struct {
|
||||
*Config
|
||||
accept []byte
|
||||
}
|
||||
|
||||
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
|
||||
c.Version = ProtocolVersionHybi13
|
||||
if req.Method != "GET" {
|
||||
return http.StatusMethodNotAllowed, ErrBadRequestMethod
|
||||
}
|
||||
// HTTP version can be safely ignored.
|
||||
|
||||
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
|
||||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||
return http.StatusBadRequest, ErrNotWebSocket
|
||||
}
|
||||
|
||||
key := req.Header.Get("Sec-Websocket-Key")
|
||||
if key == "" {
|
||||
return http.StatusBadRequest, ErrChallengeResponse
|
||||
}
|
||||
version := req.Header.Get("Sec-Websocket-Version")
|
||||
switch version {
|
||||
case "13":
|
||||
c.Version = ProtocolVersionHybi13
|
||||
default:
|
||||
return http.StatusBadRequest, ErrBadWebSocketVersion
|
||||
}
|
||||
var scheme string
|
||||
if req.TLS != nil {
|
||||
scheme = "wss"
|
||||
} else {
|
||||
scheme = "ws"
|
||||
}
|
||||
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
|
||||
if err != nil {
|
||||
return http.StatusBadRequest, err
|
||||
}
|
||||
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
|
||||
if protocol != "" {
|
||||
protocols := strings.Split(protocol, ",")
|
||||
for i := 0; i < len(protocols); i++ {
|
||||
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
|
||||
}
|
||||
}
|
||||
c.accept, err = getNonceAccept([]byte(key))
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
return http.StatusSwitchingProtocols, nil
|
||||
}
|
||||
|
||||
// Origin parses the Origin header in req.
|
||||
// If the Origin header is not set, it returns nil and nil.
|
||||
func Origin(config *Config, req *http.Request) (*url.URL, error) {
|
||||
var origin string
|
||||
switch config.Version {
|
||||
case ProtocolVersionHybi13:
|
||||
origin = req.Header.Get("Origin")
|
||||
}
|
||||
if origin == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return url.ParseRequestURI(origin)
|
||||
}
|
||||
|
||||
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
|
||||
if len(c.Protocol) > 0 {
|
||||
if len(c.Protocol) != 1 {
|
||||
// You need choose a Protocol in Handshake func in Server.
|
||||
return ErrBadWebSocketProtocol
|
||||
}
|
||||
}
|
||||
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
|
||||
buf.WriteString("Upgrade: websocket\r\n")
|
||||
buf.WriteString("Connection: Upgrade\r\n")
|
||||
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
|
||||
if len(c.Protocol) > 0 {
|
||||
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
|
||||
}
|
||||
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||
if c.Header != nil {
|
||||
err := c.Header.WriteSubset(buf, handshakeHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
buf.WriteString("\r\n")
|
||||
return buf.Flush()
|
||||
}
|
||||
|
||||
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
||||
return newHybiServerConn(c.Config, buf, rwc, request)
|
||||
}
|
||||
|
||||
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
|
||||
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
||||
return newHybiConn(config, buf, rwc, request)
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
|
||||
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
|
||||
code, err := hs.ReadHandshake(buf.Reader, req)
|
||||
if err == ErrBadWebSocketVersion {
|
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
|
||||
buf.WriteString("\r\n")
|
||||
buf.WriteString(err.Error())
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||
buf.WriteString("\r\n")
|
||||
buf.WriteString(err.Error())
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
if handshake != nil {
|
||||
err = handshake(config, req)
|
||||
if err != nil {
|
||||
code = http.StatusForbidden
|
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||
buf.WriteString("\r\n")
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
}
|
||||
err = hs.AcceptHandshake(buf.Writer)
|
||||
if err != nil {
|
||||
code = http.StatusBadRequest
|
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||
buf.WriteString("\r\n")
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
conn = hs.NewServerConn(buf, rwc, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Server represents a server of a WebSocket.
|
||||
type Server struct {
|
||||
// Config is a WebSocket configuration for new WebSocket connection.
|
||||
Config
|
||||
|
||||
// Handshake is an optional function in WebSocket handshake.
|
||||
// For example, you can check, or don't check Origin header.
|
||||
// Another example, you can select config.Protocol.
|
||||
Handshake func(*Config, *http.Request) error
|
||||
|
||||
// Handler handles a WebSocket connection.
|
||||
Handler
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
s.serveWebSocket(w, req)
|
||||
}
|
||||
|
||||
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
|
||||
rwc, buf, err := w.(http.Hijacker).Hijack()
|
||||
if err != nil {
|
||||
panic("Hijack failed: " + err.Error())
|
||||
}
|
||||
// The server should abort the WebSocket connection if it finds
|
||||
// the client did not send a handshake that matches with protocol
|
||||
// specification.
|
||||
defer rwc.Close()
|
||||
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
panic("unexpected nil conn")
|
||||
}
|
||||
s.Handler(conn)
|
||||
}
|
||||
|
||||
// Handler is a simple interface to a WebSocket browser client.
|
||||
// It checks if Origin header is valid URL by default.
|
||||
// You might want to verify websocket.Conn.Config().Origin in the func.
|
||||
// If you use Server instead of Handler, you could call websocket.Origin and
|
||||
// check the origin in your Handshake func. So, if you want to accept
|
||||
// non-browser clients, which do not send an Origin header, set a
|
||||
// Server.Handshake that does not check the origin.
|
||||
type Handler func(*Conn)
|
||||
|
||||
func checkOrigin(config *Config, req *http.Request) (err error) {
|
||||
config.Origin, err = Origin(config, req)
|
||||
if err == nil && config.Origin == nil {
|
||||
return fmt.Errorf("null origin")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
s := Server{Handler: h, Handshake: checkOrigin}
|
||||
s.serveWebSocket(w, req)
|
||||
}
|
|
@ -0,0 +1,448 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package websocket implements a client and server for the WebSocket protocol
|
||||
// as specified in RFC 6455.
|
||||
//
|
||||
// This package currently lacks some features found in an alternative
|
||||
// and more actively maintained WebSocket package:
|
||||
//
|
||||
// https://godoc.org/github.com/gorilla/websocket
|
||||
//
|
||||
package websocket // import "golang.org/x/net/websocket"
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolVersionHybi13 = 13
|
||||
ProtocolVersionHybi = ProtocolVersionHybi13
|
||||
SupportedProtocolVersion = "13"
|
||||
|
||||
ContinuationFrame = 0
|
||||
TextFrame = 1
|
||||
BinaryFrame = 2
|
||||
CloseFrame = 8
|
||||
PingFrame = 9
|
||||
PongFrame = 10
|
||||
UnknownFrame = 255
|
||||
|
||||
DefaultMaxPayloadBytes = 32 << 20 // 32MB
|
||||
)
|
||||
|
||||
// ProtocolError represents WebSocket protocol errors.
|
||||
type ProtocolError struct {
|
||||
ErrorString string
|
||||
}
|
||||
|
||||
func (err *ProtocolError) Error() string { return err.ErrorString }
|
||||
|
||||
var (
|
||||
ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
|
||||
ErrBadScheme = &ProtocolError{"bad scheme"}
|
||||
ErrBadStatus = &ProtocolError{"bad status"}
|
||||
ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
|
||||
ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
|
||||
ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
|
||||
ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
|
||||
ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
|
||||
ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
|
||||
ErrBadFrame = &ProtocolError{"bad frame"}
|
||||
ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
|
||||
ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
|
||||
ErrBadRequestMethod = &ProtocolError{"bad method"}
|
||||
ErrNotSupported = &ProtocolError{"not supported"}
|
||||
)
|
||||
|
||||
// ErrFrameTooLarge is returned by Codec's Receive method if payload size
|
||||
// exceeds limit set by Conn.MaxPayloadBytes
|
||||
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
|
||||
|
||||
// Addr is an implementation of net.Addr for WebSocket.
|
||||
type Addr struct {
|
||||
*url.URL
|
||||
}
|
||||
|
||||
// Network returns the network type for a WebSocket, "websocket".
|
||||
func (addr *Addr) Network() string { return "websocket" }
|
||||
|
||||
// Config is a WebSocket configuration
|
||||
type Config struct {
|
||||
// A WebSocket server address.
|
||||
Location *url.URL
|
||||
|
||||
// A Websocket client origin.
|
||||
Origin *url.URL
|
||||
|
||||
// WebSocket subprotocols.
|
||||
Protocol []string
|
||||
|
||||
// WebSocket protocol version.
|
||||
Version int
|
||||
|
||||
// TLS config for secure WebSocket (wss).
|
||||
TlsConfig *tls.Config
|
||||
|
||||
// Additional header fields to be sent in WebSocket opening handshake.
|
||||
Header http.Header
|
||||
|
||||
// Dialer used when opening websocket connections.
|
||||
Dialer *net.Dialer
|
||||
|
||||
handshakeData map[string]string
|
||||
}
|
||||
|
||||
// serverHandshaker is an interface to handle WebSocket server side handshake.
|
||||
type serverHandshaker interface {
|
||||
// ReadHandshake reads handshake request message from client.
|
||||
// Returns http response code and error if any.
|
||||
ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
|
||||
|
||||
// AcceptHandshake accepts the client handshake request and sends
|
||||
// handshake response back to client.
|
||||
AcceptHandshake(buf *bufio.Writer) (err error)
|
||||
|
||||
// NewServerConn creates a new WebSocket connection.
|
||||
NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
|
||||
}
|
||||
|
||||
// frameReader is an interface to read a WebSocket frame.
|
||||
type frameReader interface {
|
||||
// Reader is to read payload of the frame.
|
||||
io.Reader
|
||||
|
||||
// PayloadType returns payload type.
|
||||
PayloadType() byte
|
||||
|
||||
// HeaderReader returns a reader to read header of the frame.
|
||||
HeaderReader() io.Reader
|
||||
|
||||
// TrailerReader returns a reader to read trailer of the frame.
|
||||
// If it returns nil, there is no trailer in the frame.
|
||||
TrailerReader() io.Reader
|
||||
|
||||
// Len returns total length of the frame, including header and trailer.
|
||||
Len() int
|
||||
}
|
||||
|
||||
// frameReaderFactory is an interface to creates new frame reader.
|
||||
type frameReaderFactory interface {
|
||||
NewFrameReader() (r frameReader, err error)
|
||||
}
|
||||
|
||||
// frameWriter is an interface to write a WebSocket frame.
|
||||
type frameWriter interface {
|
||||
// Writer is to write payload of the frame.
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
// frameWriterFactory is an interface to create new frame writer.
|
||||
type frameWriterFactory interface {
|
||||
NewFrameWriter(payloadType byte) (w frameWriter, err error)
|
||||
}
|
||||
|
||||
type frameHandler interface {
|
||||
HandleFrame(frame frameReader) (r frameReader, err error)
|
||||
WriteClose(status int) (err error)
|
||||
}
|
||||
|
||||
// Conn represents a WebSocket connection.
|
||||
//
|
||||
// Multiple goroutines may invoke methods on a Conn simultaneously.
|
||||
type Conn struct {
|
||||
config *Config
|
||||
request *http.Request
|
||||
|
||||
buf *bufio.ReadWriter
|
||||
rwc io.ReadWriteCloser
|
||||
|
||||
rio sync.Mutex
|
||||
frameReaderFactory
|
||||
frameReader
|
||||
|
||||
wio sync.Mutex
|
||||
frameWriterFactory
|
||||
|
||||
frameHandler
|
||||
PayloadType byte
|
||||
defaultCloseStatus int
|
||||
|
||||
// MaxPayloadBytes limits the size of frame payload received over Conn
|
||||
// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
|
||||
MaxPayloadBytes int
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface:
|
||||
// it reads data of a frame from the WebSocket connection.
|
||||
// if msg is not large enough for the frame data, it fills the msg and next Read
|
||||
// will read the rest of the frame data.
|
||||
// it reads Text frame or Binary frame.
|
||||
func (ws *Conn) Read(msg []byte) (n int, err error) {
|
||||
ws.rio.Lock()
|
||||
defer ws.rio.Unlock()
|
||||
again:
|
||||
if ws.frameReader == nil {
|
||||
frame, err := ws.frameReaderFactory.NewFrameReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if ws.frameReader == nil {
|
||||
goto again
|
||||
}
|
||||
}
|
||||
n, err = ws.frameReader.Read(msg)
|
||||
if err == io.EOF {
|
||||
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
|
||||
io.Copy(ioutil.Discard, trailer)
|
||||
}
|
||||
ws.frameReader = nil
|
||||
goto again
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write implements the io.Writer interface:
|
||||
// it writes data as a frame to the WebSocket connection.
|
||||
func (ws *Conn) Write(msg []byte) (n int, err error) {
|
||||
ws.wio.Lock()
|
||||
defer ws.wio.Unlock()
|
||||
w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err = w.Write(msg)
|
||||
w.Close()
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close implements the io.Closer interface.
|
||||
func (ws *Conn) Close() error {
|
||||
err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
|
||||
err1 := ws.rwc.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err1
|
||||
}
|
||||
|
||||
func (ws *Conn) IsClientConn() bool { return ws.request == nil }
|
||||
func (ws *Conn) IsServerConn() bool { return ws.request != nil }
|
||||
|
||||
// LocalAddr returns the WebSocket Origin for the connection for client, or
|
||||
// the WebSocket location for server.
|
||||
func (ws *Conn) LocalAddr() net.Addr {
|
||||
if ws.IsClientConn() {
|
||||
return &Addr{ws.config.Origin}
|
||||
}
|
||||
return &Addr{ws.config.Location}
|
||||
}
|
||||
|
||||
// RemoteAddr returns the WebSocket location for the connection for client, or
|
||||
// the Websocket Origin for server.
|
||||
func (ws *Conn) RemoteAddr() net.Addr {
|
||||
if ws.IsClientConn() {
|
||||
return &Addr{ws.config.Location}
|
||||
}
|
||||
return &Addr{ws.config.Origin}
|
||||
}
|
||||
|
||||
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
|
||||
|
||||
// SetDeadline sets the connection's network read & write deadlines.
|
||||
func (ws *Conn) SetDeadline(t time.Time) error {
|
||||
if conn, ok := ws.rwc.(net.Conn); ok {
|
||||
return conn.SetDeadline(t)
|
||||
}
|
||||
return errSetDeadline
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the connection's network read deadline.
|
||||
func (ws *Conn) SetReadDeadline(t time.Time) error {
|
||||
if conn, ok := ws.rwc.(net.Conn); ok {
|
||||
return conn.SetReadDeadline(t)
|
||||
}
|
||||
return errSetDeadline
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the connection's network write deadline.
|
||||
func (ws *Conn) SetWriteDeadline(t time.Time) error {
|
||||
if conn, ok := ws.rwc.(net.Conn); ok {
|
||||
return conn.SetWriteDeadline(t)
|
||||
}
|
||||
return errSetDeadline
|
||||
}
|
||||
|
||||
// Config returns the WebSocket config.
|
||||
func (ws *Conn) Config() *Config { return ws.config }
|
||||
|
||||
// Request returns the http request upgraded to the WebSocket.
|
||||
// It is nil for client side.
|
||||
func (ws *Conn) Request() *http.Request { return ws.request }
|
||||
|
||||
// Codec represents a symmetric pair of functions that implement a codec.
|
||||
type Codec struct {
|
||||
Marshal func(v interface{}) (data []byte, payloadType byte, err error)
|
||||
Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
|
||||
}
|
||||
|
||||
// Send sends v marshaled by cd.Marshal as single frame to ws.
|
||||
func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
|
||||
data, payloadType, err := cd.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ws.wio.Lock()
|
||||
defer ws.wio.Unlock()
|
||||
w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(data)
|
||||
w.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
|
||||
// in v. The whole frame payload is read to an in-memory buffer; max size of
|
||||
// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
|
||||
// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
|
||||
// completely. The next call to Receive would read and discard leftover data of
|
||||
// previous oversized frame before processing next frame.
|
||||
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
|
||||
ws.rio.Lock()
|
||||
defer ws.rio.Unlock()
|
||||
if ws.frameReader != nil {
|
||||
_, err = io.Copy(ioutil.Discard, ws.frameReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ws.frameReader = nil
|
||||
}
|
||||
again:
|
||||
frame, err := ws.frameReaderFactory.NewFrameReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
frame, err = ws.frameHandler.HandleFrame(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if frame == nil {
|
||||
goto again
|
||||
}
|
||||
maxPayloadBytes := ws.MaxPayloadBytes
|
||||
if maxPayloadBytes == 0 {
|
||||
maxPayloadBytes = DefaultMaxPayloadBytes
|
||||
}
|
||||
if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
|
||||
// payload size exceeds limit, no need to call Unmarshal
|
||||
//
|
||||
// set frameReader to current oversized frame so that
|
||||
// the next call to this function can drain leftover
|
||||
// data before processing the next frame
|
||||
ws.frameReader = frame
|
||||
return ErrFrameTooLarge
|
||||
}
|
||||
payloadType := frame.PayloadType()
|
||||
data, err := ioutil.ReadAll(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cd.Unmarshal(data, payloadType, v)
|
||||
}
|
||||
|
||||
func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
|
||||
switch data := v.(type) {
|
||||
case string:
|
||||
return []byte(data), TextFrame, nil
|
||||
case []byte:
|
||||
return data, BinaryFrame, nil
|
||||
}
|
||||
return nil, UnknownFrame, ErrNotSupported
|
||||
}
|
||||
|
||||
func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
|
||||
switch data := v.(type) {
|
||||
case *string:
|
||||
*data = string(msg)
|
||||
return nil
|
||||
case *[]byte:
|
||||
*data = msg
|
||||
return nil
|
||||
}
|
||||
return ErrNotSupported
|
||||
}
|
||||
|
||||
/*
|
||||
Message is a codec to send/receive text/binary data in a frame on WebSocket connection.
|
||||
To send/receive text frame, use string type.
|
||||
To send/receive binary frame, use []byte type.
|
||||
|
||||
Trivial usage:
|
||||
|
||||
import "websocket"
|
||||
|
||||
// receive text frame
|
||||
var message string
|
||||
websocket.Message.Receive(ws, &message)
|
||||
|
||||
// send text frame
|
||||
message = "hello"
|
||||
websocket.Message.Send(ws, message)
|
||||
|
||||
// receive binary frame
|
||||
var data []byte
|
||||
websocket.Message.Receive(ws, &data)
|
||||
|
||||
// send binary frame
|
||||
data = []byte{0, 1, 2}
|
||||
websocket.Message.Send(ws, data)
|
||||
|
||||
*/
|
||||
var Message = Codec{marshal, unmarshal}
|
||||
|
||||
func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
|
||||
msg, err = json.Marshal(v)
|
||||
return msg, TextFrame, err
|
||||
}
|
||||
|
||||
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
|
||||
return json.Unmarshal(msg, v)
|
||||
}
|
||||
|
||||
/*
|
||||
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
|
||||
|
||||
Trivial usage:
|
||||
|
||||
import "websocket"
|
||||
|
||||
type T struct {
|
||||
Msg string
|
||||
Count int
|
||||
}
|
||||
|
||||
// receive JSON type T
|
||||
var data T
|
||||
websocket.JSON.Receive(ws, &data)
|
||||
|
||||
// send JSON type T
|
||||
websocket.JSON.Send(ws, data)
|
||||
*/
|
||||
var JSON = Codec{jsonMarshal, jsonUnmarshal}
|
|
@ -0,0 +1,131 @@
|
|||
// Copyright 2017 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package semaphore provides a weighted semaphore implementation.
|
||||
package semaphore // import "golang.org/x/sync/semaphore"
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
|
||||
// Use the old context because packages that depend on this one
|
||||
// (e.g. cloud.google.com/go/...) must run on Go 1.6.
|
||||
// TODO(jba): update to "context" when possible.
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type waiter struct {
|
||||
n int64
|
||||
ready chan<- struct{} // Closed when semaphore acquired.
|
||||
}
|
||||
|
||||
// NewWeighted creates a new weighted semaphore with the given
|
||||
// maximum combined weight for concurrent access.
|
||||
func NewWeighted(n int64) *Weighted {
|
||||
w := &Weighted{size: n}
|
||||
return w
|
||||
}
|
||||
|
||||
// Weighted provides a way to bound concurrent access to a resource.
|
||||
// The callers can request access with a given weight.
|
||||
type Weighted struct {
|
||||
size int64
|
||||
cur int64
|
||||
mu sync.Mutex
|
||||
waiters list.List
|
||||
}
|
||||
|
||||
// Acquire acquires the semaphore with a weight of n, blocking only until ctx
|
||||
// is done. On success, returns nil. On failure, returns ctx.Err() and leaves
|
||||
// the semaphore unchanged.
|
||||
//
|
||||
// If ctx is already done, Acquire may still succeed without blocking.
|
||||
func (s *Weighted) Acquire(ctx context.Context, n int64) error {
|
||||
s.mu.Lock()
|
||||
if s.size-s.cur >= n && s.waiters.Len() == 0 {
|
||||
s.cur += n
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
if n > s.size {
|
||||
// Don't make other Acquire calls block on one that's doomed to fail.
|
||||
s.mu.Unlock()
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
ready := make(chan struct{})
|
||||
w := waiter{n: n, ready: ready}
|
||||
elem := s.waiters.PushBack(w)
|
||||
s.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := ctx.Err()
|
||||
s.mu.Lock()
|
||||
select {
|
||||
case <-ready:
|
||||
// Acquired the semaphore after we were canceled. Rather than trying to
|
||||
// fix up the queue, just pretend we didn't notice the cancelation.
|
||||
err = nil
|
||||
default:
|
||||
s.waiters.Remove(elem)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return err
|
||||
|
||||
case <-ready:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// TryAcquire acquires the semaphore with a weight of n without blocking.
|
||||
// On success, returns true. On failure, returns false and leaves the semaphore unchanged.
|
||||
func (s *Weighted) TryAcquire(n int64) bool {
|
||||
s.mu.Lock()
|
||||
success := s.size-s.cur >= n && s.waiters.Len() == 0
|
||||
if success {
|
||||
s.cur += n
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return success
|
||||
}
|
||||
|
||||
// Release releases the semaphore with a weight of n.
|
||||
func (s *Weighted) Release(n int64) {
|
||||
s.mu.Lock()
|
||||
s.cur -= n
|
||||
if s.cur < 0 {
|
||||
s.mu.Unlock()
|
||||
panic("semaphore: bad release")
|
||||
}
|
||||
for {
|
||||
next := s.waiters.Front()
|
||||
if next == nil {
|
||||
break // No more waiters blocked.
|
||||
}
|
||||
|
||||
w := next.Value.(waiter)
|
||||
if s.size-s.cur < w.n {
|
||||
// Not enough tokens for the next waiter. We could keep going (to try to
|
||||
// find a waiter with a smaller request), but under load that could cause
|
||||
// starvation for large requests; instead, we leave all remaining waiters
|
||||
// blocked.
|
||||
//
|
||||
// Consider a semaphore used as a read-write lock, with N tokens, N
|
||||
// readers, and one writer. Each reader can Acquire(1) to obtain a read
|
||||
// lock. The writer can Acquire(N) to obtain a write lock, excluding all
|
||||
// of the readers. If we allow the readers to jump ahead in the queue,
|
||||
// the writer will starve — there is always one token available for every
|
||||
// reader.
|
||||
break
|
||||
}
|
||||
|
||||
s.cur += w.n
|
||||
s.waiters.Remove(next)
|
||||
close(w.ready)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
|
@ -0,0 +1,258 @@
|
|||
// Copyright 2016 Google Inc. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package bundler supports bundling (batching) of items. Bundling amortizes an
|
||||
// action with fixed costs over multiple items. For example, if an API provides
|
||||
// an RPC that accepts a list of items as input, but clients would prefer
|
||||
// adding items one at a time, then a Bundler can accept individual items from
|
||||
// the client and bundle many of them into a single RPC.
|
||||
//
|
||||
// This package is experimental and subject to change without notice.
|
||||
package bundler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultDelayThreshold = time.Second
|
||||
DefaultBundleCountThreshold = 10
|
||||
DefaultBundleByteThreshold = 1e6 // 1M
|
||||
DefaultBufferedByteLimit = 1e9 // 1G
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrOverflow indicates that Bundler's stored bytes exceeds its BufferedByteLimit.
|
||||
ErrOverflow = errors.New("bundler reached buffered byte limit")
|
||||
|
||||
// ErrOversizedItem indicates that an item's size exceeds the maximum bundle size.
|
||||
ErrOversizedItem = errors.New("item size exceeds bundle byte limit")
|
||||
)
|
||||
|
||||
// A Bundler collects items added to it into a bundle until the bundle
|
||||
// exceeds a given size, then calls a user-provided function to handle the bundle.
|
||||
type Bundler struct {
|
||||
// Starting from the time that the first message is added to a bundle, once
|
||||
// this delay has passed, handle the bundle. The default is DefaultDelayThreshold.
|
||||
DelayThreshold time.Duration
|
||||
|
||||
// Once a bundle has this many items, handle the bundle. Since only one
|
||||
// item at a time is added to a bundle, no bundle will exceed this
|
||||
// threshold, so it also serves as a limit. The default is
|
||||
// DefaultBundleCountThreshold.
|
||||
BundleCountThreshold int
|
||||
|
||||
// Once the number of bytes in current bundle reaches this threshold, handle
|
||||
// the bundle. The default is DefaultBundleByteThreshold. This triggers handling,
|
||||
// but does not cap the total size of a bundle.
|
||||
BundleByteThreshold int
|
||||
|
||||
// The maximum size of a bundle, in bytes. Zero means unlimited.
|
||||
BundleByteLimit int
|
||||
|
||||
// The maximum number of bytes that the Bundler will keep in memory before
|
||||
// returning ErrOverflow. The default is DefaultBufferedByteLimit.
|
||||
BufferedByteLimit int
|
||||
|
||||
handler func(interface{}) // called to handle a bundle
|
||||
itemSliceZero reflect.Value // nil (zero value) for slice of items
|
||||
flushTimer *time.Timer // implements DelayThreshold
|
||||
|
||||
mu sync.Mutex
|
||||
sem *semaphore.Weighted // enforces BufferedByteLimit
|
||||
semOnce sync.Once
|
||||
curBundle bundle // incoming items added to this bundle
|
||||
handlingc <-chan struct{} // set to non-nil while a handler is running; closed when it returns
|
||||
}
|
||||
|
||||
type bundle struct {
|
||||
items reflect.Value // slice of item type
|
||||
size int // size in bytes of all items
|
||||
}
|
||||
|
||||
// NewBundler creates a new Bundler.
|
||||
//
|
||||
// itemExample is a value of the type that will be bundled. For example, if you
|
||||
// want to create bundles of *Entry, you could pass &Entry{} for itemExample.
|
||||
//
|
||||
// handler is a function that will be called on each bundle. If itemExample is
|
||||
// of type T, the argument to handler is of type []T. handler is always called
|
||||
// sequentially for each bundle, and never in parallel.
|
||||
//
|
||||
// Configure the Bundler by setting its thresholds and limits before calling
|
||||
// any of its methods.
|
||||
func NewBundler(itemExample interface{}, handler func(interface{})) *Bundler {
|
||||
b := &Bundler{
|
||||
DelayThreshold: DefaultDelayThreshold,
|
||||
BundleCountThreshold: DefaultBundleCountThreshold,
|
||||
BundleByteThreshold: DefaultBundleByteThreshold,
|
||||
BufferedByteLimit: DefaultBufferedByteLimit,
|
||||
|
||||
handler: handler,
|
||||
itemSliceZero: reflect.Zero(reflect.SliceOf(reflect.TypeOf(itemExample))),
|
||||
}
|
||||
b.curBundle.items = b.itemSliceZero
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Bundler) sema() *semaphore.Weighted {
|
||||
// Create the semaphore lazily, because the user may set BufferedByteLimit
|
||||
// after NewBundler.
|
||||
b.semOnce.Do(func() {
|
||||
b.sem = semaphore.NewWeighted(int64(b.BufferedByteLimit))
|
||||
})
|
||||
return b.sem
|
||||
}
|
||||
|
||||
// Add adds item to the current bundle. It marks the bundle for handling and
|
||||
// starts a new one if any of the thresholds or limits are exceeded.
|
||||
//
|
||||
// If the item's size exceeds the maximum bundle size (Bundler.BundleByteLimit), then
|
||||
// the item can never be handled. Add returns ErrOversizedItem in this case.
|
||||
//
|
||||
// If adding the item would exceed the maximum memory allowed
|
||||
// (Bundler.BufferedByteLimit) or an AddWait call is blocked waiting for
|
||||
// memory, Add returns ErrOverflow.
|
||||
//
|
||||
// Add never blocks.
|
||||
func (b *Bundler) Add(item interface{}, size int) error {
|
||||
// If this item exceeds the maximum size of a bundle,
|
||||
// we can never send it.
|
||||
if b.BundleByteLimit > 0 && size > b.BundleByteLimit {
|
||||
return ErrOversizedItem
|
||||
}
|
||||
// If adding this item would exceed our allotted memory
|
||||
// footprint, we can't accept it.
|
||||
// (TryAcquire also returns false if anything is waiting on the semaphore,
|
||||
// so calls to Add and AddWait shouldn't be mixed.)
|
||||
if !b.sema().TryAcquire(int64(size)) {
|
||||
return ErrOverflow
|
||||
}
|
||||
b.add(item, size)
|
||||
return nil
|
||||
}
|
||||
|
||||
// add adds item to the current bundle. It marks the bundle for handling and
|
||||
// starts a new one if any of the thresholds or limits are exceeded.
|
||||
func (b *Bundler) add(item interface{}, size int) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// If adding this item to the current bundle would cause it to exceed the
|
||||
// maximum bundle size, close the current bundle and start a new one.
|
||||
if b.BundleByteLimit > 0 && b.curBundle.size+size > b.BundleByteLimit {
|
||||
b.startFlushLocked()
|
||||
}
|
||||
// Add the item.
|
||||
b.curBundle.items = reflect.Append(b.curBundle.items, reflect.ValueOf(item))
|
||||
b.curBundle.size += size
|
||||
|
||||
// Start a timer to flush the item if one isn't already running.
|
||||
// startFlushLocked clears the timer and closes the bundle at the same time,
|
||||
// so we only allocate a new timer for the first item in each bundle.
|
||||
// (We could try to call Reset on the timer instead, but that would add a lot
|
||||
// of complexity to the code just to save one small allocation.)
|
||||
if b.flushTimer == nil {
|
||||
b.flushTimer = time.AfterFunc(b.DelayThreshold, b.Flush)
|
||||
}
|
||||
|
||||
// If the current bundle equals the count threshold, close it.
|
||||
if b.curBundle.items.Len() == b.BundleCountThreshold {
|
||||
b.startFlushLocked()
|
||||
}
|
||||
// If the current bundle equals or exceeds the byte threshold, close it.
|
||||
if b.curBundle.size >= b.BundleByteThreshold {
|
||||
b.startFlushLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// AddWait adds item to the current bundle. It marks the bundle for handling and
|
||||
// starts a new one if any of the thresholds or limits are exceeded.
|
||||
//
|
||||
// If the item's size exceeds the maximum bundle size (Bundler.BundleByteLimit), then
|
||||
// the item can never be handled. AddWait returns ErrOversizedItem in this case.
|
||||
//
|
||||
// If adding the item would exceed the maximum memory allowed (Bundler.BufferedByteLimit),
|
||||
// AddWait blocks until space is available or ctx is done.
|
||||
//
|
||||
// Calls to Add and AddWait should not be mixed on the same Bundler.
|
||||
func (b *Bundler) AddWait(ctx context.Context, item interface{}, size int) error {
|
||||
// If this item exceeds the maximum size of a bundle,
|
||||
// we can never send it.
|
||||
if b.BundleByteLimit > 0 && size > b.BundleByteLimit {
|
||||
return ErrOversizedItem
|
||||
}
|
||||
// If adding this item would exceed our allotted memory footprint, block
|
||||
// until space is available. The semaphore is FIFO, so there will be no
|
||||
// starvation.
|
||||
if err := b.sema().Acquire(ctx, int64(size)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Here, we've reserved space for item. Other goroutines can call AddWait
|
||||
// and even acquire space, but no one can take away our reservation
|
||||
// (assuming sem.Release is used correctly). So there is no race condition
|
||||
// resulting from locking the mutex after sem.Acquire returns.
|
||||
b.add(item, size)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush invokes the handler for all remaining items in the Bundler and waits
|
||||
// for it to return.
|
||||
func (b *Bundler) Flush() {
|
||||
b.mu.Lock()
|
||||
b.startFlushLocked()
|
||||
done := b.handlingc
|
||||
b.mu.Unlock()
|
||||
|
||||
if done != nil {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bundler) startFlushLocked() {
|
||||
if b.flushTimer != nil {
|
||||
b.flushTimer.Stop()
|
||||
b.flushTimer = nil
|
||||
}
|
||||
|
||||
if b.curBundle.items.Len() == 0 {
|
||||
return
|
||||
}
|
||||
bun := b.curBundle
|
||||
b.curBundle = bundle{items: b.itemSliceZero}
|
||||
|
||||
done := make(chan struct{})
|
||||
var running <-chan struct{}
|
||||
running, b.handlingc = b.handlingc, done
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
b.sem.Release(int64(bun.size))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if running != nil {
|
||||
// Wait for our turn to call the handler.
|
||||
<-running
|
||||
}
|
||||
|
||||
b.handler(bun.items.Interface())
|
||||
}()
|
||||
}
|
Loading…
Reference in New Issue