add logworker
This commit is contained in:
parent
b7e4d5c99f
commit
2d46bec18b
17
box.rb
17
box.rb
|
@ -9,25 +9,32 @@ def foldercopy dir
|
||||||
end
|
end
|
||||||
|
|
||||||
def gobuild pkg
|
def gobuild pkg
|
||||||
run "mkdir -p /root/go/bin && cd /root/go/bin && go#{$gover} build #{$repo}/#{pkg} && go#{$gover} install #{$repo}/#{pkg}"
|
run "mkdir -p /root/go/bin && cd /root/go/bin && go#{$gover} build -v #{$repo}/#{pkg}"
|
||||||
end
|
end
|
||||||
|
|
||||||
[
|
[
|
||||||
"bot",
|
|
||||||
"cmd",
|
|
||||||
"internal",
|
|
||||||
"vendor",
|
"vendor",
|
||||||
"vendor-log",
|
"vendor-log",
|
||||||
].each { |x| foldercopy x }
|
].each { |x| foldercopy x }
|
||||||
|
|
||||||
|
[
|
||||||
|
"bot",
|
||||||
|
"cmd",
|
||||||
|
"internal",
|
||||||
|
].each { |x| foldercopy x }
|
||||||
|
|
||||||
[
|
[
|
||||||
"cmd/vyvanse",
|
"cmd/vyvanse",
|
||||||
|
"cmd/logworker",
|
||||||
].each { |x| gobuild x }
|
].each { |x| gobuild x }
|
||||||
|
|
||||||
cmd "/root/go/bin/vyvanse"
|
cmd "/root/go/bin/vyvanse"
|
||||||
|
|
||||||
run "rm -rf $HOME/sdk /root/go/pkg ||:"
|
run "rm -rf $HOME/sdk /root/go/pkg ||:"
|
||||||
run "apk del go#{$gover}"
|
run "apk del go#{$gover}"
|
||||||
|
|
||||||
|
tag "xena/vyvanse:thick"
|
||||||
|
|
||||||
flatten
|
flatten
|
||||||
|
|
||||||
tag "xena/vyvanse"
|
tag "xena/vyvanse:latest"
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"git.xeserv.us/xena/gorqlite"
|
||||||
|
"git.xeserv.us/xena/vyvanse/internal/dao"
|
||||||
|
|
||||||
|
"github.com/Xe/ln"
|
||||||
|
"github.com/bwmarrin/discordgo"
|
||||||
|
"github.com/drone/mq/stomp"
|
||||||
|
"github.com/namsral/flag"
|
||||||
|
opentracing "github.com/opentracing/opentracing-go"
|
||||||
|
zipkin "github.com/openzipkin/zipkin-go-opentracing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
token = flag.String("token", "", "discord bot token")
|
||||||
|
zipkinURL = flag.String("zipkin-url", "", "URL for Zipkin traces")
|
||||||
|
databaseURL = flag.String("database-url", "http://", "URL for database (rqlite)")
|
||||||
|
mqURL = flag.String("mq-url", "tcp://mq:9000", "URL for STOMP server")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if *zipkinURL != "" {
|
||||||
|
collector, err := zipkin.NewHTTPCollector(*zipkinURL)
|
||||||
|
if err != nil {
|
||||||
|
ln.FatalErr(context.Background(), err)
|
||||||
|
}
|
||||||
|
tracer, err := zipkin.NewTracer(
|
||||||
|
zipkin.NewRecorder(collector, false, "logworker:5000", "logworker"))
|
||||||
|
if err != nil {
|
||||||
|
ln.FatalErr(context.Background(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opentracing.SetGlobalTracer(tracer)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
db, err := gorqlite.Open(*databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
ln.FatalErr(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ls := dao.NewLogs(db)
|
||||||
|
err = ls.Migrate(ctx)
|
||||||
|
if err != nil {
|
||||||
|
ln.FatalErr(ctx, err, ln.F{"action": "migrate logs table"})
|
||||||
|
}
|
||||||
|
|
||||||
|
mq, err := stomp.Dial(*mqURL)
|
||||||
|
if err != nil {
|
||||||
|
ln.FatalErr(ctx, err, ln.F{"url": *mqURL})
|
||||||
|
}
|
||||||
|
|
||||||
|
mq.Subscribe("/topic/message_create", stomp.HandlerFunc(func(m *stomp.Message) {
|
||||||
|
sp, ctx := opentracing.StartSpanFromContext(m.Context(), "logworker.topic.message.create")
|
||||||
|
defer sp.Finish()
|
||||||
|
|
||||||
|
msg := &discordgo.Message{}
|
||||||
|
err := json.Unmarshal(m.Msg, msg)
|
||||||
|
if err != nil {
|
||||||
|
ln.Error(ctx, err, ln.F{"action": "can't unmarshal message body to a discordgo message"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f := ln.F{
|
||||||
|
"stomp_id": string(m.ID),
|
||||||
|
"channel_id": msg.ChannelID,
|
||||||
|
"message_id": msg.ID,
|
||||||
|
"message_author": msg.Author.ID,
|
||||||
|
"message_author_name": msg.Author.Username,
|
||||||
|
"message_author_is_bot": msg.Author.Bot,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ls.Add(ctx, msg)
|
||||||
|
if err != nil {
|
||||||
|
ln.Error(ctx, err, f, ln.F{"action": "can't add discordgo message to the database"})
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {}
|
||||||
|
}
|
||||||
|
}
|
|
@ -13,6 +13,7 @@ import (
|
||||||
|
|
||||||
"github.com/Xe/ln"
|
"github.com/Xe/ln"
|
||||||
"github.com/bwmarrin/discordgo"
|
"github.com/bwmarrin/discordgo"
|
||||||
|
"github.com/drone/mq/stomp"
|
||||||
_ "github.com/joho/godotenv/autoload"
|
_ "github.com/joho/godotenv/autoload"
|
||||||
"github.com/namsral/flag"
|
"github.com/namsral/flag"
|
||||||
xkcd "github.com/nishanths/go-xkcd"
|
xkcd "github.com/nishanths/go-xkcd"
|
||||||
|
@ -29,11 +30,26 @@ var (
|
||||||
token = flag.String("token", "", "discord bot token")
|
token = flag.String("token", "", "discord bot token")
|
||||||
zipkinURL = flag.String("zipkin-url", "", "URL for Zipkin traces")
|
zipkinURL = flag.String("zipkin-url", "", "URL for Zipkin traces")
|
||||||
databaseURL = flag.String("database-url", "http://", "URL for database (rqlite)")
|
databaseURL = flag.String("database-url", "http://", "URL for database (rqlite)")
|
||||||
|
mqURL = flag.String("mq-url", "tcp://mq:9000", "URL for STOMP server")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
if *zipkinURL != "" {
|
||||||
|
collector, err := zipkin.NewHTTPCollector(*zipkinURL)
|
||||||
|
if err != nil {
|
||||||
|
ln.FatalErr(context.Background(), err)
|
||||||
|
}
|
||||||
|
tracer, err := zipkin.NewTracer(
|
||||||
|
zipkin.NewRecorder(collector, false, "vyvanse:5000", "vyvanse"))
|
||||||
|
if err != nil {
|
||||||
|
ln.FatalErr(context.Background(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opentracing.SetGlobalTracer(tracer)
|
||||||
|
}
|
||||||
|
|
||||||
xk := xkcd.NewClient()
|
xk := xkcd.NewClient()
|
||||||
dg, err := discordgo.New("Bot " + *token)
|
dg, err := discordgo.New("Bot " + *token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -62,6 +78,14 @@ func main() {
|
||||||
}
|
}
|
||||||
sp.Finish()
|
sp.Finish()
|
||||||
|
|
||||||
|
ctx = context.Background()
|
||||||
|
|
||||||
|
mq, err := stomp.Dial(*mqURL)
|
||||||
|
if err != nil {
|
||||||
|
ln.FatalErr(ctx, err, ln.F{"url": *mqURL})
|
||||||
|
}
|
||||||
|
_ = mq
|
||||||
|
|
||||||
c := cron.New()
|
c := cron.New()
|
||||||
|
|
||||||
comic, err := xk.Latest()
|
comic, err := xk.Latest()
|
||||||
|
@ -111,20 +135,6 @@ func main() {
|
||||||
|
|
||||||
c.Start()
|
c.Start()
|
||||||
|
|
||||||
if *zipkinURL != "" {
|
|
||||||
collector, err := zipkin.NewHTTPCollector(*zipkinURL)
|
|
||||||
if err != nil {
|
|
||||||
ln.FatalErr(context.Background(), err)
|
|
||||||
}
|
|
||||||
tracer, err := zipkin.NewTracer(
|
|
||||||
zipkin.NewRecorder(collector, false, "vyvanse:5000", "vyvanse"))
|
|
||||||
if err != nil {
|
|
||||||
ln.FatalErr(context.Background(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opentracing.SetGlobalTracer(tracer)
|
|
||||||
}
|
|
||||||
|
|
||||||
cs := bot.NewCommandSet()
|
cs := bot.NewCommandSet()
|
||||||
cs.Prefix = ">"
|
cs.Prefix = ">"
|
||||||
|
|
||||||
|
@ -134,6 +144,61 @@ func main() {
|
||||||
cs.AddCmd("splattus", "splatoon 2 map rotation status", bot.NoPermissions, spla2nMaps)
|
cs.AddCmd("splattus", "splatoon 2 map rotation status", bot.NoPermissions, spla2nMaps)
|
||||||
cs.AddCmd("top10", "shows the top 10 chatters on this server", bot.NoPermissions, top10(us))
|
cs.AddCmd("top10", "shows the top 10 chatters on this server", bot.NoPermissions, top10(us))
|
||||||
|
|
||||||
|
dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||||
|
sp, ctx := opentracing.StartSpanFromContext(context.Background(), "message.create.post.stomp")
|
||||||
|
defer sp.Finish()
|
||||||
|
|
||||||
|
f := ln.F{
|
||||||
|
"channel_id": m.ChannelID,
|
||||||
|
"message_id": m.ID,
|
||||||
|
"message_author": m.Author.ID,
|
||||||
|
"message_author_name": m.Author.Username,
|
||||||
|
"message_author_is_bot": m.Author.Bot,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mq.SendJSON("/topic/message_create", m.Message)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == "EOF" {
|
||||||
|
mq, err = stomp.Dial(*mqURL)
|
||||||
|
if err != nil {
|
||||||
|
ln.Error(ctx, err, f, ln.F{"url": *mqURL, "action": "reconnect to mq"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = mq.SendJSON("/topic/message_create", m.Message)
|
||||||
|
if err != nil {
|
||||||
|
ln.Error(ctx, err, f, ln.F{"action": "retry message_create post to message queue"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ln.Error(ctx, err, f, ln.F{"action": "send created message to queue"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ln.Log(ctx, f, ln.F{"action": "message_create"})
|
||||||
|
})
|
||||||
|
|
||||||
|
dg.AddHandler(func(s *discordgo.Session, m *discordgo.GuildMemberAdd) {
|
||||||
|
sp, ctx := opentracing.StartSpanFromContext(context.Background(), "member.add.post.stomp")
|
||||||
|
defer sp.Finish()
|
||||||
|
|
||||||
|
f := ln.F{
|
||||||
|
"guild_id": m.GuildID,
|
||||||
|
"user_id": m.User.ID,
|
||||||
|
"user_name": m.User.Username,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mq.SendJSON("/topic/member_add", m.Member)
|
||||||
|
if err != nil {
|
||||||
|
ln.Error(ctx, err, f, ln.F{"action": "send added member to queue"})
|
||||||
|
}
|
||||||
|
|
||||||
|
ln.Log(ctx, f, ln.F{"action": "member_add"})
|
||||||
|
})
|
||||||
|
|
||||||
dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
|
dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||||
if m.Author.ID == s.State.User.ID {
|
if m.Author.ID == s.State.User.ID {
|
||||||
return
|
return
|
||||||
|
@ -165,8 +230,9 @@ func main() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("Bot is now running. Press CTRL-C to exit.")
|
ln.Log(ctx, ln.F{"action": "bot is running"})
|
||||||
// Simple way to keep program running until CTRL-C is pressed.
|
|
||||||
<-make(chan struct{})
|
for {
|
||||||
return
|
select {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,5 +33,15 @@ services:
|
||||||
- mq
|
- mq
|
||||||
- rqlite
|
- rqlite
|
||||||
|
|
||||||
|
logworker:
|
||||||
|
restart: always
|
||||||
|
image: xena/vyvanse
|
||||||
|
env_file: ./.env
|
||||||
|
depends_on:
|
||||||
|
- zipkin
|
||||||
|
- mq
|
||||||
|
- rqlite
|
||||||
|
command: /root/go/bin/logworker
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
rqlite:
|
rqlite:
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
package dao
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"git.xeserv.us/xena/gorqlite"
|
||||||
|
"github.com/Xe/ln"
|
||||||
|
"github.com/bwmarrin/discordgo"
|
||||||
|
opentracing "github.com/opentracing/opentracing-go"
|
||||||
|
splog "github.com/opentracing/opentracing-go/log"
|
||||||
|
"google.golang.org/api/support/bundler"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Logs struct {
|
||||||
|
conn gorqlite.Connection
|
||||||
|
bdl *bundler.Bundler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogs(conn gorqlite.Connection) *Logs {
|
||||||
|
l := &Logs{conn: conn}
|
||||||
|
l.bdl = bundler.NewBundler("string", l.writeLines)
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logs) writeLines(sqlLines interface{}) {
|
||||||
|
sp, ctx := opentracing.StartSpanFromContext(context.Background(), "logs.write.lines")
|
||||||
|
defer sp.Finish()
|
||||||
|
|
||||||
|
_, err := l.conn.Write(sqlLines.([]string))
|
||||||
|
if err != nil {
|
||||||
|
ln.Error(ctx, err, ln.F{"action": "write lines that are batched"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logs) Migrate(ctx context.Context) error {
|
||||||
|
sp, ctx := opentracing.StartSpanFromContext(ctx, "logs.migrate")
|
||||||
|
defer sp.Finish()
|
||||||
|
|
||||||
|
migrationDDL := []string{
|
||||||
|
`CREATE TABLE IF NOT EXISTS logs(id INTEGER PRIMARY KEY, discord_id TEXT UNIQUE, channel_id TEXT, content TEXT, timestamp INTEGER, mention_everyone INTEGER, author_id TEXT, author_username TEXT)`,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := l.conn.Write(migrationDDL)
|
||||||
|
if err != nil {
|
||||||
|
sp.LogFields(splog.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, re := range res {
|
||||||
|
if re.Err != nil {
|
||||||
|
sp.LogFields(splog.Error(re.Err))
|
||||||
|
return re.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
sp.LogFields(splog.Int("migration.step", i), splog.Float64("timing", re.Timing), splog.Int64("rows.affected", re.RowsAffected))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logs) Add(ctx context.Context, m *discordgo.Message) error {
|
||||||
|
sp, ctx := opentracing.StartSpanFromContext(ctx, "logs.add")
|
||||||
|
defer sp.Finish()
|
||||||
|
|
||||||
|
stmt := gorqlite.NewPreparedStatement(`INSERT INTO logs (discord_id, channel_id, content, timestamp, mention_everyone, author_id, author_username) VALUES (%s, %s, %s, %d, %d, %s, %s)`)
|
||||||
|
ts, err := m.Timestamp.Parse()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var me int
|
||||||
|
if m.MentionEveryone {
|
||||||
|
me = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
bd := stmt.Bind(m.ID, m.ChannelID, m.Content, ts, me, m.Author.ID, m.Author.Username)
|
||||||
|
|
||||||
|
l.bdl.Add(bd, len(bd))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -28,7 +28,6 @@ func (u *Users) Migrate(ctx context.Context) error {
|
||||||
res, err := u.conn.Write(migrationDDL)
|
res, err := u.conn.Write(migrationDDL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sp.LogFields(splog.Error(err))
|
sp.LogFields(splog.Error(err))
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, re := range res {
|
for i, re := range res {
|
||||||
|
|
|
@ -42,3 +42,11 @@ ae77be60afb1dcacde03767a8c37337fad28ac14 github.com/kardianos/osext
|
||||||
40a5e952d22c3ef520c6ab7bdb9b1a010ec9a524 git.xeserv.us/xena/gorqlite
|
40a5e952d22c3ef520c6ab7bdb9b1a010ec9a524 git.xeserv.us/xena/gorqlite
|
||||||
97311d9f7767e3d6f422ea06661bc2c7a19e8a5d github.com/mattn/go-runewidth
|
97311d9f7767e3d6f422ea06661bc2c7a19e8a5d github.com/mattn/go-runewidth
|
||||||
be5337e7b39e64e5f91445ce7e721888dbab7387 github.com/olekukonko/tablewriter
|
be5337e7b39e64e5f91445ce7e721888dbab7387 github.com/olekukonko/tablewriter
|
||||||
|
280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/logger
|
||||||
|
280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/stomp
|
||||||
|
280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/stomp/dialer
|
||||||
|
f5079bd7f6f74e23c4d65efa0f4ce14cbd6a3c0f golang.org/x/net/context
|
||||||
|
f5079bd7f6f74e23c4d65efa0f4ce14cbd6a3c0f golang.org/x/net/websocket
|
||||||
|
66aacef3dd8a676686c7ae3716979581e8b03c47 golang.org/x/net/context
|
||||||
|
f52d1811a62927559de87708c8913c1650ce4f26 golang.org/x/sync/semaphore
|
||||||
|
e0e0e6e500066ff47335c7717e2a090ad127adec google.golang.org/api/support/bundler
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
package logger
|
||||||
|
|
||||||
|
var std Logger = new(none)
|
||||||
|
|
||||||
|
// Debugf writes a debug message to the standard logger.
|
||||||
|
func Debugf(format string, args ...interface{}) {
|
||||||
|
std.Debugf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verbosef writes a verbose message to the standard logger.
|
||||||
|
func Verbosef(format string, args ...interface{}) {
|
||||||
|
std.Verbosef(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Noticef writes a notice message to the standard logger.
|
||||||
|
func Noticef(format string, args ...interface{}) {
|
||||||
|
std.Noticef(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warningf writes a warning message to the standard logger.
|
||||||
|
func Warningf(format string, args ...interface{}) {
|
||||||
|
std.Warningf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Printf writes a default message to the standard logger.
|
||||||
|
func Printf(format string, args ...interface{}) {
|
||||||
|
std.Printf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogger sets the standard logger.
|
||||||
|
func SetLogger(logger Logger) {
|
||||||
|
std = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger represents a logger.
|
||||||
|
type Logger interface {
|
||||||
|
|
||||||
|
// Debugf writes a debug message.
|
||||||
|
Debugf(string, ...interface{})
|
||||||
|
|
||||||
|
// Verbosef writes a verbose message.
|
||||||
|
Verbosef(string, ...interface{})
|
||||||
|
|
||||||
|
// Noticef writes a notice message.
|
||||||
|
Noticef(string, ...interface{})
|
||||||
|
|
||||||
|
// Warningf writes a warning message.
|
||||||
|
Warningf(string, ...interface{})
|
||||||
|
|
||||||
|
// Printf writes a default message.
|
||||||
|
Printf(string, ...interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// none is a logger that silently ignores all writes.
|
||||||
|
type none struct{}
|
||||||
|
|
||||||
|
func (*none) Debugf(string, ...interface{}) {}
|
||||||
|
func (*none) Verbosef(string, ...interface{}) {}
|
||||||
|
func (*none) Noticef(string, ...interface{}) {}
|
||||||
|
func (*none) Warningf(string, ...interface{}) {}
|
||||||
|
func (*none) Printf(string, ...interface{}) {}
|
|
@ -0,0 +1,259 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"runtime/debug"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/drone/mq/logger"
|
||||||
|
"github.com/drone/mq/stomp/dialer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client defines a client connection to a STOMP server.
|
||||||
|
type Client struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
peer Peer
|
||||||
|
subs map[string]Handler
|
||||||
|
wait map[string]chan struct{}
|
||||||
|
done chan error
|
||||||
|
|
||||||
|
seq int64
|
||||||
|
|
||||||
|
skipVerify bool
|
||||||
|
readBufferSize int
|
||||||
|
writeBufferSize int
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new STOMP client using the given connection.
|
||||||
|
func New(peer Peer) *Client {
|
||||||
|
return &Client{
|
||||||
|
peer: peer,
|
||||||
|
subs: make(map[string]Handler),
|
||||||
|
wait: make(map[string]chan struct{}),
|
||||||
|
done: make(chan error, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial creates a client connection to the given target.
|
||||||
|
func Dial(target string) (*Client, error) {
|
||||||
|
conn, err := dialer.Dial(target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return New(Conn(conn)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends the data to the given destination.
|
||||||
|
func (c *Client) Send(dest string, data []byte, opts ...MessageOption) error {
|
||||||
|
m := NewMessage()
|
||||||
|
m.Method = MethodSend
|
||||||
|
m.Dest = []byte(dest)
|
||||||
|
m.Body = data
|
||||||
|
m.Apply(opts...)
|
||||||
|
return c.sendMessage(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendJSON sends the JSON encoding of v to the given destination.
|
||||||
|
func (c *Client) SendJSON(dest string, v interface{}, opts ...MessageOption) error {
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
opts = append(opts,
|
||||||
|
WithHeader("content-type", "application/json"),
|
||||||
|
)
|
||||||
|
return c.Send(dest, data, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe subscribes to the given destination.
|
||||||
|
func (c *Client) Subscribe(dest string, handler Handler, opts ...MessageOption) (id []byte, err error) {
|
||||||
|
id = c.incr()
|
||||||
|
|
||||||
|
m := NewMessage()
|
||||||
|
m.Method = MethodSubscribe
|
||||||
|
m.ID = id
|
||||||
|
m.Dest = []byte(dest)
|
||||||
|
m.Apply(opts...)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.subs[string(id)] = handler
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
err = c.sendMessage(m)
|
||||||
|
if err != nil {
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.subs, string(id))
|
||||||
|
c.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe unsubscribes to the destination.
|
||||||
|
func (c *Client) Unsubscribe(id []byte, opts ...MessageOption) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.subs, string(id))
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
m := NewMessage()
|
||||||
|
m.Method = MethodUnsubscribe
|
||||||
|
m.ID = id
|
||||||
|
m.Apply(opts...)
|
||||||
|
|
||||||
|
return c.sendMessage(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ack acknowledges the messages with the given id.
|
||||||
|
func (c *Client) Ack(id []byte, opts ...MessageOption) error {
|
||||||
|
m := NewMessage()
|
||||||
|
m.Method = MethodAck
|
||||||
|
m.ID = id
|
||||||
|
m.Apply(opts...)
|
||||||
|
|
||||||
|
return c.sendMessage(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nack negative-acknowledges the messages with the given id.
|
||||||
|
func (c *Client) Nack(id []byte, opts ...MessageOption) error {
|
||||||
|
m := NewMessage()
|
||||||
|
m.Method = MethodNack
|
||||||
|
m.ID = id
|
||||||
|
m.Apply(opts...)
|
||||||
|
|
||||||
|
return c.peer.Send(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect opens the connection and establishes the session.
|
||||||
|
func (c *Client) Connect(opts ...MessageOption) error {
|
||||||
|
m := NewMessage()
|
||||||
|
m.Proto = STOMP
|
||||||
|
m.Method = MethodStomp
|
||||||
|
m.Apply(opts...)
|
||||||
|
if err := c.sendMessage(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m, ok := <-c.peer.Receive()
|
||||||
|
if !ok {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
defer m.Release()
|
||||||
|
|
||||||
|
if !bytes.Equal(m.Method, MethodConnected) {
|
||||||
|
return fmt.Errorf("stomp: inbound message: unexpected method, want connected")
|
||||||
|
}
|
||||||
|
go c.listen()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect terminates the session and closes the connection.
|
||||||
|
func (c *Client) Disconnect() error {
|
||||||
|
m := NewMessage()
|
||||||
|
m.Method = MethodDisconnect
|
||||||
|
c.sendMessage(m)
|
||||||
|
return c.peer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Done returns a channel
|
||||||
|
func (c *Client) Done() <-chan error {
|
||||||
|
return c.done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) incr() []byte {
|
||||||
|
c.mu.Lock()
|
||||||
|
i := c.seq
|
||||||
|
c.seq++
|
||||||
|
c.mu.Unlock()
|
||||||
|
return strconv.AppendInt(nil, i, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) listen() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Warningf("stomp client: recover panic: %s", r)
|
||||||
|
err, ok := r.(error)
|
||||||
|
if !ok {
|
||||||
|
logger.Warningf("%v: %s", r, debug.Stack())
|
||||||
|
c.done <- fmt.Errorf("%v", r)
|
||||||
|
} else {
|
||||||
|
logger.Warningf("%s", err)
|
||||||
|
c.done <- err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
m, ok := <-c.peer.Receive()
|
||||||
|
if !ok {
|
||||||
|
c.done <- io.EOF
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case bytes.Equal(m.Method, MethodMessage):
|
||||||
|
c.handleMessage(m)
|
||||||
|
case bytes.Equal(m.Method, MethodRecipet):
|
||||||
|
c.handleReceipt(m)
|
||||||
|
default:
|
||||||
|
logger.Noticef("stomp client: unknown message type: %s",
|
||||||
|
string(m.Method),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleReceipt(m *Message) {
|
||||||
|
c.mu.Lock()
|
||||||
|
receiptc, ok := c.wait[string(m.Receipt)]
|
||||||
|
c.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
logger.Noticef("stomp client: unknown read receipt: %s",
|
||||||
|
string(m.Receipt),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
receiptc <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleMessage(m *Message) {
|
||||||
|
c.mu.Lock()
|
||||||
|
handler, ok := c.subs[string(m.Subs)]
|
||||||
|
c.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
logger.Noticef("stomp client: subscription not found: %s",
|
||||||
|
string(m.Subs),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler.Handle(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) sendMessage(m *Message) error {
|
||||||
|
if len(m.Receipt) == 0 {
|
||||||
|
return c.peer.Send(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
receiptc := make(chan struct{}, 1)
|
||||||
|
c.wait[string(m.Receipt)] = receiptc
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
delete(c.wait, string(m.Receipt))
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := c.peer.Send(m)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-receiptc:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,156 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/drone/mq/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bufferSize = 32 << 10 // default buffer size 32KB
|
||||||
|
bufferLimit = 32 << 15 // default buffer limit 1MB
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
never time.Time
|
||||||
|
deadline = time.Second * 5
|
||||||
|
|
||||||
|
heartbeatTime = time.Second * 30
|
||||||
|
heartbeatWait = time.Second * 60
|
||||||
|
)
|
||||||
|
|
||||||
|
type connPeer struct {
|
||||||
|
conn net.Conn
|
||||||
|
done chan bool
|
||||||
|
|
||||||
|
reader *bufio.Reader
|
||||||
|
writer *bufio.Writer
|
||||||
|
incoming chan *Message
|
||||||
|
outgoing chan *Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn creates a network-connected peer that reads and writes
|
||||||
|
// messages using net.Conn c.
|
||||||
|
func Conn(c net.Conn) Peer {
|
||||||
|
p := &connPeer{
|
||||||
|
reader: bufio.NewReaderSize(c, bufferSize),
|
||||||
|
writer: bufio.NewWriterSize(c, bufferSize),
|
||||||
|
incoming: make(chan *Message),
|
||||||
|
outgoing: make(chan *Message),
|
||||||
|
done: make(chan bool),
|
||||||
|
conn: c,
|
||||||
|
}
|
||||||
|
|
||||||
|
go p.readInto(p.incoming)
|
||||||
|
go p.writeFrom(p.outgoing)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connPeer) Receive() <-chan *Message {
|
||||||
|
return c.incoming
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connPeer) Send(message *Message) error {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return io.EOF
|
||||||
|
default:
|
||||||
|
c.outgoing <- message
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connPeer) Addr() string {
|
||||||
|
return c.conn.RemoteAddr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connPeer) Close() error {
|
||||||
|
return c.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connPeer) close() error {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return io.EOF
|
||||||
|
default:
|
||||||
|
close(c.done)
|
||||||
|
close(c.incoming)
|
||||||
|
close(c.outgoing)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connPeer) readInto(messages chan<- *Message) {
|
||||||
|
defer c.close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
// lim := io.LimitReader(c.conn, bufferLimit)
|
||||||
|
// buf := bufio.NewReaderSize(lim, bufferSize)
|
||||||
|
|
||||||
|
buf, err := c.reader.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if len(buf) == 1 {
|
||||||
|
c.conn.SetReadDeadline(time.Now().Add(heartbeatWait))
|
||||||
|
logger.Verbosef("stomp: received heart-beat")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := NewMessage()
|
||||||
|
msg.Parse(buf[:len(buf)-1])
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
messages <- msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connPeer) writeFrom(messages <-chan *Message) {
|
||||||
|
tick := time.NewTicker(time.Millisecond * 100).C
|
||||||
|
heartbeat := time.NewTicker(heartbeatTime).C
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
break loop
|
||||||
|
case <-heartbeat:
|
||||||
|
logger.Verbosef("stomp: send heart-beat.")
|
||||||
|
c.writer.WriteByte(0)
|
||||||
|
case <-tick:
|
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(deadline))
|
||||||
|
if err := c.writer.Flush(); err != nil {
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
c.conn.SetWriteDeadline(never)
|
||||||
|
case msg, ok := <-messages:
|
||||||
|
if !ok {
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
writeTo(c.writer, msg)
|
||||||
|
c.writer.WriteByte(0)
|
||||||
|
msg.Release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.drain()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connPeer) drain() error {
|
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(deadline))
|
||||||
|
for msg := range c.outgoing {
|
||||||
|
writeTo(c.writer, msg)
|
||||||
|
c.writer.WriteByte(0)
|
||||||
|
msg.Release()
|
||||||
|
}
|
||||||
|
c.conn.SetWriteDeadline(never)
|
||||||
|
c.writer.Flush()
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
// STOMP protocol version.
|
||||||
|
var STOMP = []byte("1.2")
|
||||||
|
|
||||||
|
// STOMP protocol methods.
|
||||||
|
var (
|
||||||
|
MethodStomp = []byte("STOMP")
|
||||||
|
MethodConnect = []byte("CONNECT")
|
||||||
|
MethodConnected = []byte("CONNECTED")
|
||||||
|
MethodSend = []byte("SEND")
|
||||||
|
MethodSubscribe = []byte("SUBSCRIBE")
|
||||||
|
MethodUnsubscribe = []byte("UNSUBSCRIBE")
|
||||||
|
MethodAck = []byte("ACK")
|
||||||
|
MethodNack = []byte("NACK")
|
||||||
|
MethodDisconnect = []byte("DISCONNECT")
|
||||||
|
MethodMessage = []byte("MESSAGE")
|
||||||
|
MethodRecipet = []byte("RECEIPT")
|
||||||
|
MethodError = []byte("ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
// STOMP protocol headers.
|
||||||
|
var (
|
||||||
|
HeaderAccept = []byte("accept-version")
|
||||||
|
HeaderAck = []byte("ack")
|
||||||
|
HeaderExpires = []byte("expires")
|
||||||
|
HeaderDest = []byte("destination")
|
||||||
|
HeaderHost = []byte("host")
|
||||||
|
HeaderLogin = []byte("login")
|
||||||
|
HeaderPass = []byte("passcode")
|
||||||
|
HeaderID = []byte("id")
|
||||||
|
HeaderMessageID = []byte("message-id")
|
||||||
|
HeaderPersist = []byte("persist")
|
||||||
|
HeaderPrefetch = []byte("prefetch-count")
|
||||||
|
HeaderReceipt = []byte("receipt")
|
||||||
|
HeaderReceiptID = []byte("receipt-id")
|
||||||
|
HeaderRetain = []byte("retain")
|
||||||
|
HeaderSelector = []byte("selector")
|
||||||
|
HeaderServer = []byte("server")
|
||||||
|
HeaderSession = []byte("session")
|
||||||
|
HeaderSubscription = []byte("subscription")
|
||||||
|
HeaderVersion = []byte("version")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Common STOMP header values.
|
||||||
|
var (
|
||||||
|
AckAuto = []byte("auto")
|
||||||
|
AckClient = []byte("client")
|
||||||
|
PersistTrue = []byte("true")
|
||||||
|
RetainTrue = []byte("true")
|
||||||
|
RetainLast = []byte("last")
|
||||||
|
RetainAll = []byte("all")
|
||||||
|
RetainRemove = []byte("remove")
|
||||||
|
)
|
||||||
|
|
||||||
|
var headerLookup = map[string]struct{}{
|
||||||
|
"accept-version": struct{}{},
|
||||||
|
"ack": struct{}{},
|
||||||
|
"expires": struct{}{},
|
||||||
|
"destination": struct{}{},
|
||||||
|
"host": struct{}{},
|
||||||
|
"login": struct{}{},
|
||||||
|
"passcode": struct{}{},
|
||||||
|
"id": struct{}{},
|
||||||
|
"message-id": struct{}{},
|
||||||
|
"persist": struct{}{},
|
||||||
|
"prefetch-count": struct{}{},
|
||||||
|
"receipt": struct{}{},
|
||||||
|
"receipt-id": struct{}{},
|
||||||
|
"retain": struct{}{},
|
||||||
|
"selector": struct{}{},
|
||||||
|
"server": struct{}{},
|
||||||
|
"session": struct{}{},
|
||||||
|
"subscription": struct{}{},
|
||||||
|
"version": struct{}{},
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
import "golang.org/x/net/context"
|
||||||
|
|
||||||
|
const clientKey = "stomp.client"
|
||||||
|
|
||||||
|
// NewContext adds the client to the context.
|
||||||
|
func (c *Client) NewContext(ctx context.Context, client *Client) context.Context {
|
||||||
|
// HACK for use with gin and echo
|
||||||
|
if s, ok := ctx.(setter); ok {
|
||||||
|
s.Set(clientKey, clientKey)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, clientKey, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext retrieves the client from context
|
||||||
|
func FromContext(ctx context.Context) (*Client, bool) {
|
||||||
|
client, ok := ctx.Value(clientKey).(*Client)
|
||||||
|
return client, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFromContext retrieves the client from context. Panics if not found
|
||||||
|
func MustFromContext(ctx context.Context) *Client {
|
||||||
|
client, ok := FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
panic("stomp.Client not found in context")
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
// HACK setter defines a context that enables setting values. This is a
|
||||||
|
// temporary workaround for use with gin and echo and will eventually
|
||||||
|
// be removed. DO NOT depend on this.
|
||||||
|
type setter interface {
|
||||||
|
Set(string, interface{})
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
package dialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"golang.org/x/net/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
protoHTTP = "http"
|
||||||
|
protoHTTPS = "https"
|
||||||
|
protoWS = "ws"
|
||||||
|
protoWSS = "wss"
|
||||||
|
protoTCP = "tcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dial creates a client connection to the given target.
|
||||||
|
func Dial(target string) (net.Conn, error) {
|
||||||
|
u, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch u.Scheme {
|
||||||
|
case protoHTTP, protoHTTPS, protoWS, protoWSS:
|
||||||
|
return dialWebsocket(u)
|
||||||
|
case protoTCP:
|
||||||
|
return dialSocket(u)
|
||||||
|
default:
|
||||||
|
panic("stomp: invalid protocol")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialWebsocket(target *url.URL) (net.Conn, error) {
|
||||||
|
origin, err := target.Parse("/")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch origin.Scheme {
|
||||||
|
case protoWS:
|
||||||
|
origin.Scheme = protoHTTP
|
||||||
|
case protoWSS:
|
||||||
|
origin.Scheme = protoHTTPS
|
||||||
|
}
|
||||||
|
return websocket.Dial(target.String(), "", origin.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialSocket(target *url.URL) (net.Conn, error) {
|
||||||
|
return net.Dial(protoTCP, target.Host)
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
// Handler handles a STOMP message.
|
||||||
|
type Handler interface {
|
||||||
|
Handle(*Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The HandlerFunc type is an adapter to allow the use of an ordinary
|
||||||
|
// function as a STOMP message handler.
|
||||||
|
type HandlerFunc func(*Message)
|
||||||
|
|
||||||
|
// Handle calls f(m).
|
||||||
|
func (f HandlerFunc) Handle(m *Message) { f(m) }
|
|
@ -0,0 +1,109 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultHeaderLen = 5
|
||||||
|
|
||||||
|
type item struct {
|
||||||
|
name []byte
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header represents the header section of the STOMP message.
|
||||||
|
type Header struct {
|
||||||
|
items []item
|
||||||
|
itemc int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHeader() *Header {
|
||||||
|
return &Header{
|
||||||
|
items: make([]item, defaultHeaderLen),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the named header value.
|
||||||
|
func (h *Header) Get(name []byte) (b []byte) {
|
||||||
|
for i := 0; i < h.itemc; i++ {
|
||||||
|
if v := h.items[i]; bytes.Equal(v.name, name) {
|
||||||
|
return v.data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetString returns the named header value.
|
||||||
|
func (h *Header) GetString(name string) string {
|
||||||
|
k := []byte(name)
|
||||||
|
v := h.Get(k)
|
||||||
|
return string(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBool returns the named header value.
|
||||||
|
func (h *Header) GetBool(name string) bool {
|
||||||
|
s := h.GetString(name)
|
||||||
|
b, _ := strconv.ParseBool(s)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInt returns the named header value.
|
||||||
|
func (h *Header) GetInt(name string) int {
|
||||||
|
s := h.GetString(name)
|
||||||
|
i, _ := strconv.Atoi(s)
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInt64 returns the named header value.
|
||||||
|
func (h *Header) GetInt64(name string) int64 {
|
||||||
|
s := h.GetString(name)
|
||||||
|
i, _ := strconv.ParseInt(s, 10, 64)
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// Field returns the named header value in string format. This is used to
|
||||||
|
// provide compatibility with the SQL expression evaluation package.
|
||||||
|
func (h *Header) Field(name []byte) []byte {
|
||||||
|
return h.Get(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add appens the key value pair to the header.
|
||||||
|
func (h *Header) Add(name, data []byte) {
|
||||||
|
h.grow()
|
||||||
|
h.items[h.itemc].name = name
|
||||||
|
h.items[h.itemc].data = data
|
||||||
|
h.itemc++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Index returns the keypair at index i.
|
||||||
|
func (h *Header) Index(i int) (k, v []byte) {
|
||||||
|
if i > h.itemc {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
k = h.items[i].name
|
||||||
|
v = h.items[i].data
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the header length.
|
||||||
|
func (h *Header) Len() int {
|
||||||
|
return h.itemc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) grow() {
|
||||||
|
if h.itemc > defaultHeaderLen-1 {
|
||||||
|
h.items = append(h.items, item{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) reset() {
|
||||||
|
h.itemc = 0
|
||||||
|
h.items = h.items[:defaultHeaderLen]
|
||||||
|
for i := range h.items {
|
||||||
|
h.items[i].name = zeroBytes
|
||||||
|
h.items[i].data = zeroBytes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var zeroBytes []byte
|
|
@ -0,0 +1,146 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"math/rand"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message represents a parsed STOMP message.
|
||||||
|
type Message struct {
|
||||||
|
ID []byte // id header
|
||||||
|
Proto []byte // stomp version
|
||||||
|
Method []byte // stomp method
|
||||||
|
User []byte // username header
|
||||||
|
Pass []byte // password header
|
||||||
|
Dest []byte // destination header
|
||||||
|
Subs []byte // subscription id
|
||||||
|
Ack []byte // ack id
|
||||||
|
Msg []byte // message-id header
|
||||||
|
Persist []byte // persist header
|
||||||
|
Retain []byte // retain header
|
||||||
|
Prefetch []byte // prefetch count
|
||||||
|
Expires []byte // expires header
|
||||||
|
Receipt []byte // receipt header
|
||||||
|
Selector []byte // selector header
|
||||||
|
Body []byte
|
||||||
|
Header *Header // custom headers
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy returns a copy of the Message.
|
||||||
|
func (m *Message) Copy() *Message {
|
||||||
|
c := NewMessage()
|
||||||
|
c.ID = m.ID
|
||||||
|
c.Proto = m.Proto
|
||||||
|
c.Method = m.Method
|
||||||
|
c.User = m.User
|
||||||
|
c.Pass = m.Pass
|
||||||
|
c.Dest = m.Dest
|
||||||
|
c.Subs = m.Subs
|
||||||
|
c.Ack = m.Ack
|
||||||
|
c.Prefetch = m.Prefetch
|
||||||
|
c.Selector = m.Selector
|
||||||
|
c.Persist = m.Persist
|
||||||
|
c.Retain = m.Retain
|
||||||
|
c.Receipt = m.Receipt
|
||||||
|
c.Expires = m.Expires
|
||||||
|
c.Body = m.Body
|
||||||
|
c.ctx = m.ctx
|
||||||
|
c.Header.itemc = m.Header.itemc
|
||||||
|
copy(c.Header.items, m.Header.items)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies the options to the message.
|
||||||
|
func (m *Message) Apply(opts ...MessageOption) {
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse parses the raw bytes into the message.
|
||||||
|
func (m *Message) Parse(b []byte) error {
|
||||||
|
return read(b, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes returns the Message in raw byte format.
|
||||||
|
func (m *Message) Bytes() []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writeTo(&buf, m)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the Message in string format.
|
||||||
|
func (m *Message) String() string {
|
||||||
|
return string(m.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release releases the message back to the message pool.
|
||||||
|
func (m *Message) Release() {
|
||||||
|
m.Reset()
|
||||||
|
pool.Put(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets the meesage fields to their zero values.
|
||||||
|
func (m *Message) Reset() {
|
||||||
|
m.ID = m.ID[:0]
|
||||||
|
m.Proto = m.Proto[:0]
|
||||||
|
m.Method = m.Method[:0]
|
||||||
|
m.User = m.User[:0]
|
||||||
|
m.Pass = m.Pass[:0]
|
||||||
|
m.Dest = m.Dest[:0]
|
||||||
|
m.Subs = m.Subs[:0]
|
||||||
|
m.Ack = m.Ack[:0]
|
||||||
|
m.Prefetch = m.Prefetch[:0]
|
||||||
|
m.Selector = m.Selector[:0]
|
||||||
|
m.Persist = m.Persist[:0]
|
||||||
|
m.Retain = m.Retain[:0]
|
||||||
|
m.Receipt = m.Receipt[:0]
|
||||||
|
m.Expires = m.Expires[:0]
|
||||||
|
m.Body = m.Body[:0]
|
||||||
|
m.ctx = nil
|
||||||
|
m.Header.reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context returns the request's context.
|
||||||
|
func (m *Message) Context() context.Context {
|
||||||
|
if m.ctx != nil {
|
||||||
|
return m.ctx
|
||||||
|
}
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContext returns a shallow copy of m with its context changed
|
||||||
|
// to ctx. The provided ctx must be non-nil.
|
||||||
|
func (m *Message) WithContext(ctx context.Context) *Message {
|
||||||
|
c := m.Copy()
|
||||||
|
c.ctx = ctx
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal parses the JSON-encoded body of the message and
|
||||||
|
// stores the result in the value pointed to by v.
|
||||||
|
func (m *Message) Unmarshal(v interface{}) error {
|
||||||
|
return json.Unmarshal(m.Body, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMessage returns an empty message from the message pool.
|
||||||
|
func NewMessage() *Message {
|
||||||
|
return pool.Get().(*Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pool = sync.Pool{New: func() interface{} {
|
||||||
|
return &Message{Header: newHeader()}
|
||||||
|
}}
|
||||||
|
|
||||||
|
// Rand returns a random int64 number as a []byte of
|
||||||
|
// ascii characters.
|
||||||
|
func Rand() []byte {
|
||||||
|
return strconv.AppendInt(nil, rand.Int63(), 10)
|
||||||
|
}
|
|
@ -0,0 +1,96 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageOption configures message options.
|
||||||
|
type MessageOption func(*Message)
|
||||||
|
|
||||||
|
// WithCredentials returns a MessageOption which sets credentials.
|
||||||
|
func WithCredentials(username, password string) MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
m.User = []byte(username)
|
||||||
|
m.Pass = []byte(password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHeader returns a MessageOption which sets a header.
|
||||||
|
func WithHeader(key, value string) MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
_, ok := headerLookup[strings.ToLower(key)]
|
||||||
|
if !ok {
|
||||||
|
m.Header.Add(
|
||||||
|
[]byte(key),
|
||||||
|
[]byte(value),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHeaders returns a MessageOption which sets headers.
|
||||||
|
func WithHeaders(headers map[string]string) MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
for key, value := range headers {
|
||||||
|
_, ok := headerLookup[strings.ToLower(key)]
|
||||||
|
if !ok {
|
||||||
|
m.Header.Add(
|
||||||
|
[]byte(key),
|
||||||
|
[]byte(value),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithExpires returns a MessageOption configured with an expiration.
|
||||||
|
func WithExpires(exp int64) MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
m.Expires = strconv.AppendInt(nil, exp, 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPrefetch returns a MessageOption configured with a prefetch count.
|
||||||
|
func WithPrefetch(prefetch int) MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
m.Prefetch = strconv.AppendInt(nil, int64(prefetch), 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithReceipt returns a MessageOption configured with a receipt request.
|
||||||
|
func WithReceipt() MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
m.Receipt = strconv.AppendInt(nil, rand.Int63(), 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPersistence returns a MessageOption configured to persist.
|
||||||
|
func WithPersistence() MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
m.Persist = PersistTrue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRetain returns a MessageOption configured to retain the message.
|
||||||
|
func WithRetain(retain string) MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
m.Retain = []byte(retain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSelector returns a MessageOption configured to filter messages
|
||||||
|
// using a sql-like evaluation string.
|
||||||
|
func WithSelector(selector string) MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
m.Selector = []byte(selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAck returns a MessageOption configured with an ack policy.
|
||||||
|
func WithAck(ack string) MessageOption {
|
||||||
|
return func(m *Message) {
|
||||||
|
m.Ack = []byte(ack)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,86 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Peer defines a peer-to-peer connection.
|
||||||
|
type Peer interface {
|
||||||
|
// Send sends a message.
|
||||||
|
Send(*Message) error
|
||||||
|
|
||||||
|
// Receive returns a channel of inbound messages.
|
||||||
|
Receive() <-chan *Message
|
||||||
|
|
||||||
|
// Close closes the connection.
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// Addr returns the peer address.
|
||||||
|
Addr() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pipe creates a synchronous in-memory pipe, where reads on one end are
|
||||||
|
// matched with writes on the other. This is useful for direct, in-memory
|
||||||
|
// client-server communication.
|
||||||
|
func Pipe() (Peer, Peer) {
|
||||||
|
atob := make(chan *Message, 10)
|
||||||
|
btoa := make(chan *Message, 10)
|
||||||
|
|
||||||
|
a := &localPeer{
|
||||||
|
incoming: btoa,
|
||||||
|
outgoing: atob,
|
||||||
|
finished: make(chan bool),
|
||||||
|
}
|
||||||
|
b := &localPeer{
|
||||||
|
incoming: atob,
|
||||||
|
outgoing: btoa,
|
||||||
|
finished: make(chan bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
return a, b
|
||||||
|
}
|
||||||
|
|
||||||
|
type localPeer struct {
|
||||||
|
finished chan bool
|
||||||
|
outgoing chan<- *Message
|
||||||
|
incoming <-chan *Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *localPeer) Receive() <-chan *Message {
|
||||||
|
return p.incoming
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *localPeer) Send(m *Message) error {
|
||||||
|
select {
|
||||||
|
case <-p.finished:
|
||||||
|
return io.EOF
|
||||||
|
default:
|
||||||
|
p.outgoing <- m
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *localPeer) Close() error {
|
||||||
|
close(p.finished)
|
||||||
|
close(p.outgoing)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *localPeer) Addr() string {
|
||||||
|
peerAddrOnce.Do(func() {
|
||||||
|
// get the local address list
|
||||||
|
addr, _ := net.InterfaceAddrs()
|
||||||
|
if len(addr) != 0 {
|
||||||
|
// use the last address in the list
|
||||||
|
peerAddr = addr[len(addr)-1].String()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return peerAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
var peerAddrOnce sync.Once
|
||||||
|
|
||||||
|
// default address displayed for local pipes
|
||||||
|
var peerAddr = "127.0.0.1/8"
|
|
@ -0,0 +1,139 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func read(input []byte, m *Message) (err error) {
|
||||||
|
var (
|
||||||
|
pos int
|
||||||
|
off int
|
||||||
|
tot = len(input)
|
||||||
|
)
|
||||||
|
|
||||||
|
// parse the stomp message
|
||||||
|
for ; ; off++ {
|
||||||
|
if off == tot {
|
||||||
|
return fmt.Errorf("stomp: invalid method")
|
||||||
|
}
|
||||||
|
if input[off] == '\n' {
|
||||||
|
m.Method = input[pos:off]
|
||||||
|
off++
|
||||||
|
pos = off
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the stomp headers
|
||||||
|
for {
|
||||||
|
if off == tot {
|
||||||
|
return fmt.Errorf("stomp: unexpected eof")
|
||||||
|
}
|
||||||
|
if input[off] == '\n' {
|
||||||
|
off++
|
||||||
|
pos = off
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
name []byte
|
||||||
|
value []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
loop:
|
||||||
|
// parse each individual header
|
||||||
|
for ; ; off++ {
|
||||||
|
if off >= tot {
|
||||||
|
return fmt.Errorf("stomp: unexpected eof")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch input[off] {
|
||||||
|
case '\n':
|
||||||
|
value = input[pos:off]
|
||||||
|
off++
|
||||||
|
pos = off
|
||||||
|
break loop
|
||||||
|
case ':':
|
||||||
|
name = input[pos:off]
|
||||||
|
off++
|
||||||
|
pos = off
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case bytes.Equal(name, HeaderAccept):
|
||||||
|
m.Proto = value
|
||||||
|
case bytes.Equal(name, HeaderAck):
|
||||||
|
m.Ack = value
|
||||||
|
case bytes.Equal(name, HeaderDest):
|
||||||
|
m.Dest = value
|
||||||
|
case bytes.Equal(name, HeaderExpires):
|
||||||
|
m.Expires = value
|
||||||
|
case bytes.Equal(name, HeaderLogin):
|
||||||
|
m.User = value
|
||||||
|
case bytes.Equal(name, HeaderPass):
|
||||||
|
m.Pass = value
|
||||||
|
case bytes.Equal(name, HeaderID):
|
||||||
|
m.ID = value
|
||||||
|
case bytes.Equal(name, HeaderMessageID):
|
||||||
|
m.ID = value
|
||||||
|
case bytes.Equal(name, HeaderPersist):
|
||||||
|
m.Persist = value
|
||||||
|
case bytes.Equal(name, HeaderPrefetch):
|
||||||
|
m.Prefetch = value
|
||||||
|
case bytes.Equal(name, HeaderReceipt):
|
||||||
|
m.Receipt = value
|
||||||
|
case bytes.Equal(name, HeaderReceiptID):
|
||||||
|
m.Receipt = value
|
||||||
|
case bytes.Equal(name, HeaderRetain):
|
||||||
|
m.Retain = value
|
||||||
|
case bytes.Equal(name, HeaderSelector):
|
||||||
|
m.Selector = value
|
||||||
|
case bytes.Equal(name, HeaderSubscription):
|
||||||
|
m.Subs = value
|
||||||
|
case bytes.Equal(name, HeaderVersion):
|
||||||
|
m.Proto = value
|
||||||
|
default:
|
||||||
|
m.Header.Add(name, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tot > pos {
|
||||||
|
m.Body = input[pos:]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
asciiZero = 48
|
||||||
|
asciiNine = 57
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseInt returns the ascii integer value.
|
||||||
|
func ParseInt(d []byte) (n int) {
|
||||||
|
if len(d) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
for _, dec := range d {
|
||||||
|
if dec < asciiZero || dec > asciiNine {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
n = n*10 + (int(dec) - asciiZero)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseInt64 returns the ascii integer value.
|
||||||
|
func ParseInt64(d []byte) (n int64) {
|
||||||
|
if len(d) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
for _, dec := range d {
|
||||||
|
if dec < asciiZero || dec > asciiNine {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
n = n*10 + (int64(dec) - asciiZero)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
|
@ -0,0 +1,173 @@
|
||||||
|
package stomp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
crlf = []byte{'\r', '\n'}
|
||||||
|
newline = []byte{'\n'}
|
||||||
|
separator = []byte{':'}
|
||||||
|
terminator = []byte{0}
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeTo(w io.Writer, m *Message) {
|
||||||
|
w.Write(m.Method)
|
||||||
|
w.Write(newline)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case bytes.Equal(m.Method, MethodStomp):
|
||||||
|
// version
|
||||||
|
w.Write(HeaderAccept)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Proto)
|
||||||
|
w.Write(newline)
|
||||||
|
// login
|
||||||
|
if len(m.User) != 0 {
|
||||||
|
w.Write(HeaderLogin)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.User)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
// passcode
|
||||||
|
if len(m.Pass) != 0 {
|
||||||
|
w.Write(HeaderPass)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Pass)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
case bytes.Equal(m.Method, MethodConnected):
|
||||||
|
// version
|
||||||
|
w.Write(HeaderVersion)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Proto)
|
||||||
|
w.Write(newline)
|
||||||
|
case bytes.Equal(m.Method, MethodSend):
|
||||||
|
// dest
|
||||||
|
w.Write(HeaderDest)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Dest)
|
||||||
|
w.Write(newline)
|
||||||
|
if len(m.Expires) != 0 {
|
||||||
|
w.Write(HeaderExpires)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Expires)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
if len(m.Retain) != 0 {
|
||||||
|
w.Write(HeaderRetain)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Retain)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
if len(m.Persist) != 0 {
|
||||||
|
w.Write(HeaderPersist)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Persist)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
case bytes.Equal(m.Method, MethodSubscribe):
|
||||||
|
// id
|
||||||
|
w.Write(HeaderID)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.ID)
|
||||||
|
w.Write(newline)
|
||||||
|
// destination
|
||||||
|
w.Write(HeaderDest)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Dest)
|
||||||
|
w.Write(newline)
|
||||||
|
// selector
|
||||||
|
if len(m.Selector) != 0 {
|
||||||
|
w.Write(HeaderSelector)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Selector)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
// prefetch
|
||||||
|
if len(m.Prefetch) != 0 {
|
||||||
|
w.Write(HeaderPrefetch)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Prefetch)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
if len(m.Ack) != 0 {
|
||||||
|
w.Write(HeaderAck)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Ack)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
case bytes.Equal(m.Method, MethodUnsubscribe):
|
||||||
|
// id
|
||||||
|
w.Write(HeaderID)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.ID)
|
||||||
|
w.Write(newline)
|
||||||
|
case bytes.Equal(m.Method, MethodAck):
|
||||||
|
// id
|
||||||
|
w.Write(HeaderID)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.ID)
|
||||||
|
w.Write(newline)
|
||||||
|
case bytes.Equal(m.Method, MethodNack):
|
||||||
|
// id
|
||||||
|
w.Write(HeaderID)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.ID)
|
||||||
|
w.Write(newline)
|
||||||
|
case bytes.Equal(m.Method, MethodMessage):
|
||||||
|
// message-id
|
||||||
|
w.Write(HeaderMessageID)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.ID)
|
||||||
|
w.Write(newline)
|
||||||
|
// destination
|
||||||
|
w.Write(HeaderDest)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Dest)
|
||||||
|
w.Write(newline)
|
||||||
|
// subscription
|
||||||
|
w.Write(HeaderSubscription)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Subs)
|
||||||
|
w.Write(newline)
|
||||||
|
// ack
|
||||||
|
if len(m.Ack) != 0 {
|
||||||
|
w.Write(HeaderAck)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Ack)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
case bytes.Equal(m.Method, MethodRecipet):
|
||||||
|
// receipt-id
|
||||||
|
w.Write(HeaderReceiptID)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Receipt)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
|
||||||
|
// receipt header
|
||||||
|
if includeReceiptHeader(m) {
|
||||||
|
w.Write(HeaderReceipt)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(m.Receipt)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, item := range m.Header.items {
|
||||||
|
if m.Header.itemc == i {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
w.Write(item.name)
|
||||||
|
w.Write(separator)
|
||||||
|
w.Write(item.data)
|
||||||
|
w.Write(newline)
|
||||||
|
}
|
||||||
|
w.Write(newline)
|
||||||
|
w.Write(m.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func includeReceiptHeader(m *Message) bool {
|
||||||
|
return len(m.Receipt) != 0 && !bytes.Equal(m.Method, MethodRecipet)
|
||||||
|
}
|
|
@ -0,0 +1,106 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DialError is an error that occurs while dialling a websocket server.
|
||||||
|
type DialError struct {
|
||||||
|
*Config
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DialError) Error() string {
|
||||||
|
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfig creates a new WebSocket config for client connection.
|
||||||
|
func NewConfig(server, origin string) (config *Config, err error) {
|
||||||
|
config = new(Config)
|
||||||
|
config.Version = ProtocolVersionHybi13
|
||||||
|
config.Location, err = url.ParseRequestURI(server)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.Origin, err = url.ParseRequestURI(origin)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.Header = http.Header(make(map[string][]string))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new WebSocket client connection over rwc.
|
||||||
|
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
|
||||||
|
br := bufio.NewReader(rwc)
|
||||||
|
bw := bufio.NewWriter(rwc)
|
||||||
|
err = hybiClientHandshake(config, br, bw)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buf := bufio.NewReadWriter(br, bw)
|
||||||
|
ws = newHybiClientConn(config, buf, rwc)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial opens a new client connection to a WebSocket.
|
||||||
|
func Dial(url_, protocol, origin string) (ws *Conn, err error) {
|
||||||
|
config, err := NewConfig(url_, origin)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if protocol != "" {
|
||||||
|
config.Protocol = []string{protocol}
|
||||||
|
}
|
||||||
|
return DialConfig(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
var portMap = map[string]string{
|
||||||
|
"ws": "80",
|
||||||
|
"wss": "443",
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAuthority(location *url.URL) string {
|
||||||
|
if _, ok := portMap[location.Scheme]; ok {
|
||||||
|
if _, _, err := net.SplitHostPort(location.Host); err != nil {
|
||||||
|
return net.JoinHostPort(location.Host, portMap[location.Scheme])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return location.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialConfig opens a new client connection to a WebSocket with a config.
|
||||||
|
func DialConfig(config *Config) (ws *Conn, err error) {
|
||||||
|
var client net.Conn
|
||||||
|
if config.Location == nil {
|
||||||
|
return nil, &DialError{config, ErrBadWebSocketLocation}
|
||||||
|
}
|
||||||
|
if config.Origin == nil {
|
||||||
|
return nil, &DialError{config, ErrBadWebSocketOrigin}
|
||||||
|
}
|
||||||
|
dialer := config.Dialer
|
||||||
|
if dialer == nil {
|
||||||
|
dialer = &net.Dialer{}
|
||||||
|
}
|
||||||
|
client, err = dialWithDialer(dialer, config)
|
||||||
|
if err != nil {
|
||||||
|
goto Error
|
||||||
|
}
|
||||||
|
ws, err = NewClient(config, client)
|
||||||
|
if err != nil {
|
||||||
|
client.Close()
|
||||||
|
goto Error
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
Error:
|
||||||
|
return nil, &DialError{config, err}
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
|
||||||
|
switch config.Location.Scheme {
|
||||||
|
case "ws":
|
||||||
|
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
|
||||||
|
|
||||||
|
case "wss":
|
||||||
|
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = ErrBadScheme
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,583 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
// This file implements a protocol of hybi draft.
|
||||||
|
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||||
|
|
||||||
|
closeStatusNormal = 1000
|
||||||
|
closeStatusGoingAway = 1001
|
||||||
|
closeStatusProtocolError = 1002
|
||||||
|
closeStatusUnsupportedData = 1003
|
||||||
|
closeStatusFrameTooLarge = 1004
|
||||||
|
closeStatusNoStatusRcvd = 1005
|
||||||
|
closeStatusAbnormalClosure = 1006
|
||||||
|
closeStatusBadMessageData = 1007
|
||||||
|
closeStatusPolicyViolation = 1008
|
||||||
|
closeStatusTooBigData = 1009
|
||||||
|
closeStatusExtensionMismatch = 1010
|
||||||
|
|
||||||
|
maxControlFramePayloadLength = 125
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBadMaskingKey = &ProtocolError{"bad masking key"}
|
||||||
|
ErrBadPongMessage = &ProtocolError{"bad pong message"}
|
||||||
|
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
|
||||||
|
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
|
||||||
|
ErrNotImplemented = &ProtocolError{"not implemented"}
|
||||||
|
|
||||||
|
handshakeHeader = map[string]bool{
|
||||||
|
"Host": true,
|
||||||
|
"Upgrade": true,
|
||||||
|
"Connection": true,
|
||||||
|
"Sec-Websocket-Key": true,
|
||||||
|
"Sec-Websocket-Origin": true,
|
||||||
|
"Sec-Websocket-Version": true,
|
||||||
|
"Sec-Websocket-Protocol": true,
|
||||||
|
"Sec-Websocket-Accept": true,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// A hybiFrameHeader is a frame header as defined in hybi draft.
|
||||||
|
type hybiFrameHeader struct {
|
||||||
|
Fin bool
|
||||||
|
Rsv [3]bool
|
||||||
|
OpCode byte
|
||||||
|
Length int64
|
||||||
|
MaskingKey []byte
|
||||||
|
|
||||||
|
data *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// A hybiFrameReader is a reader for hybi frame.
|
||||||
|
type hybiFrameReader struct {
|
||||||
|
reader io.Reader
|
||||||
|
|
||||||
|
header hybiFrameHeader
|
||||||
|
pos int64
|
||||||
|
length int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
|
||||||
|
n, err = frame.reader.Read(msg)
|
||||||
|
if frame.header.MaskingKey != nil {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
|
||||||
|
frame.pos++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
|
||||||
|
|
||||||
|
func (frame *hybiFrameReader) HeaderReader() io.Reader {
|
||||||
|
if frame.header.data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if frame.header.data.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return frame.header.data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
|
||||||
|
|
||||||
|
func (frame *hybiFrameReader) Len() (n int) { return frame.length }
|
||||||
|
|
||||||
|
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
|
||||||
|
type hybiFrameReaderFactory struct {
|
||||||
|
*bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
|
||||||
|
// See Section 5.2 Base Framing protocol for detail.
|
||||||
|
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
|
||||||
|
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
|
||||||
|
hybiFrame := new(hybiFrameReader)
|
||||||
|
frame = hybiFrame
|
||||||
|
var header []byte
|
||||||
|
var b byte
|
||||||
|
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
|
||||||
|
b, err = buf.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
header = append(header, b)
|
||||||
|
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
j := uint(6 - i)
|
||||||
|
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
|
||||||
|
}
|
||||||
|
hybiFrame.header.OpCode = header[0] & 0x0f
|
||||||
|
|
||||||
|
// Second byte. Mask/Payload len(7bits)
|
||||||
|
b, err = buf.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
header = append(header, b)
|
||||||
|
mask := (b & 0x80) != 0
|
||||||
|
b &= 0x7f
|
||||||
|
lengthFields := 0
|
||||||
|
switch {
|
||||||
|
case b <= 125: // Payload length 7bits.
|
||||||
|
hybiFrame.header.Length = int64(b)
|
||||||
|
case b == 126: // Payload length 7+16bits
|
||||||
|
lengthFields = 2
|
||||||
|
case b == 127: // Payload length 7+64bits
|
||||||
|
lengthFields = 8
|
||||||
|
}
|
||||||
|
for i := 0; i < lengthFields; i++ {
|
||||||
|
b, err = buf.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
|
||||||
|
b &= 0x7f
|
||||||
|
}
|
||||||
|
header = append(header, b)
|
||||||
|
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
|
||||||
|
}
|
||||||
|
if mask {
|
||||||
|
// Masking key. 4 bytes.
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
b, err = buf.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
header = append(header, b)
|
||||||
|
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
|
||||||
|
hybiFrame.header.data = bytes.NewBuffer(header)
|
||||||
|
hybiFrame.length = len(header) + int(hybiFrame.header.Length)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// A HybiFrameWriter is a writer for hybi frame.
|
||||||
|
type hybiFrameWriter struct {
|
||||||
|
writer *bufio.Writer
|
||||||
|
|
||||||
|
header *hybiFrameHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
|
||||||
|
var header []byte
|
||||||
|
var b byte
|
||||||
|
if frame.header.Fin {
|
||||||
|
b |= 0x80
|
||||||
|
}
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if frame.header.Rsv[i] {
|
||||||
|
j := uint(6 - i)
|
||||||
|
b |= 1 << j
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b |= frame.header.OpCode
|
||||||
|
header = append(header, b)
|
||||||
|
if frame.header.MaskingKey != nil {
|
||||||
|
b = 0x80
|
||||||
|
} else {
|
||||||
|
b = 0
|
||||||
|
}
|
||||||
|
lengthFields := 0
|
||||||
|
length := len(msg)
|
||||||
|
switch {
|
||||||
|
case length <= 125:
|
||||||
|
b |= byte(length)
|
||||||
|
case length < 65536:
|
||||||
|
b |= 126
|
||||||
|
lengthFields = 2
|
||||||
|
default:
|
||||||
|
b |= 127
|
||||||
|
lengthFields = 8
|
||||||
|
}
|
||||||
|
header = append(header, b)
|
||||||
|
for i := 0; i < lengthFields; i++ {
|
||||||
|
j := uint((lengthFields - i - 1) * 8)
|
||||||
|
b = byte((length >> j) & 0xff)
|
||||||
|
header = append(header, b)
|
||||||
|
}
|
||||||
|
if frame.header.MaskingKey != nil {
|
||||||
|
if len(frame.header.MaskingKey) != 4 {
|
||||||
|
return 0, ErrBadMaskingKey
|
||||||
|
}
|
||||||
|
header = append(header, frame.header.MaskingKey...)
|
||||||
|
frame.writer.Write(header)
|
||||||
|
data := make([]byte, length)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
|
||||||
|
}
|
||||||
|
frame.writer.Write(data)
|
||||||
|
err = frame.writer.Flush()
|
||||||
|
return length, err
|
||||||
|
}
|
||||||
|
frame.writer.Write(header)
|
||||||
|
frame.writer.Write(msg)
|
||||||
|
err = frame.writer.Flush()
|
||||||
|
return length, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (frame *hybiFrameWriter) Close() error { return nil }
|
||||||
|
|
||||||
|
type hybiFrameWriterFactory struct {
|
||||||
|
*bufio.Writer
|
||||||
|
needMaskingKey bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
|
||||||
|
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
|
||||||
|
if buf.needMaskingKey {
|
||||||
|
frameHeader.MaskingKey, err = generateMaskingKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type hybiFrameHandler struct {
|
||||||
|
conn *Conn
|
||||||
|
payloadType byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
|
||||||
|
if handler.conn.IsServerConn() {
|
||||||
|
// The client MUST mask all frames sent to the server.
|
||||||
|
if frame.(*hybiFrameReader).header.MaskingKey == nil {
|
||||||
|
handler.WriteClose(closeStatusProtocolError)
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// The server MUST NOT mask all frames.
|
||||||
|
if frame.(*hybiFrameReader).header.MaskingKey != nil {
|
||||||
|
handler.WriteClose(closeStatusProtocolError)
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if header := frame.HeaderReader(); header != nil {
|
||||||
|
io.Copy(ioutil.Discard, header)
|
||||||
|
}
|
||||||
|
switch frame.PayloadType() {
|
||||||
|
case ContinuationFrame:
|
||||||
|
frame.(*hybiFrameReader).header.OpCode = handler.payloadType
|
||||||
|
case TextFrame, BinaryFrame:
|
||||||
|
handler.payloadType = frame.PayloadType()
|
||||||
|
case CloseFrame:
|
||||||
|
return nil, io.EOF
|
||||||
|
case PingFrame, PongFrame:
|
||||||
|
b := make([]byte, maxControlFramePayloadLength)
|
||||||
|
n, err := io.ReadFull(frame, b)
|
||||||
|
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
io.Copy(ioutil.Discard, frame)
|
||||||
|
if frame.PayloadType() == PingFrame {
|
||||||
|
if _, err := handler.WritePong(b[:n]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return frame, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
|
||||||
|
handler.conn.wio.Lock()
|
||||||
|
defer handler.conn.wio.Unlock()
|
||||||
|
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
msg := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(msg, uint16(status))
|
||||||
|
_, err = w.Write(msg)
|
||||||
|
w.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
|
||||||
|
handler.conn.wio.Lock()
|
||||||
|
defer handler.conn.wio.Unlock()
|
||||||
|
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
n, err = w.Write(msg)
|
||||||
|
w.Close()
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
|
||||||
|
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
||||||
|
if buf == nil {
|
||||||
|
br := bufio.NewReader(rwc)
|
||||||
|
bw := bufio.NewWriter(rwc)
|
||||||
|
buf = bufio.NewReadWriter(br, bw)
|
||||||
|
}
|
||||||
|
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
|
||||||
|
frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
|
||||||
|
frameWriterFactory: hybiFrameWriterFactory{
|
||||||
|
buf.Writer, request == nil},
|
||||||
|
PayloadType: TextFrame,
|
||||||
|
defaultCloseStatus: closeStatusNormal}
|
||||||
|
ws.frameHandler = &hybiFrameHandler{conn: ws}
|
||||||
|
return ws
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateMaskingKey generates a masking key for a frame.
|
||||||
|
func generateMaskingKey() (maskingKey []byte, err error) {
|
||||||
|
maskingKey = make([]byte, 4)
|
||||||
|
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateNonce generates a nonce consisting of a randomly selected 16-byte
|
||||||
|
// value that has been base64-encoded.
|
||||||
|
func generateNonce() (nonce []byte) {
|
||||||
|
key := make([]byte, 16)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
nonce = make([]byte, 24)
|
||||||
|
base64.StdEncoding.Encode(nonce, key)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeZone removes IPv6 zone identifer from host.
|
||||||
|
// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080"
|
||||||
|
func removeZone(host string) string {
|
||||||
|
if !strings.HasPrefix(host, "[") {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
i := strings.LastIndex(host, "]")
|
||||||
|
if i < 0 {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
j := strings.LastIndex(host[:i], "%")
|
||||||
|
if j < 0 {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
return host[:j] + host[i:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
|
||||||
|
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
|
||||||
|
func getNonceAccept(nonce []byte) (expected []byte, err error) {
|
||||||
|
h := sha1.New()
|
||||||
|
if _, err = h.Write(nonce); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err = h.Write([]byte(websocketGUID)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected = make([]byte, 28)
|
||||||
|
base64.StdEncoding.Encode(expected, h.Sum(nil))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
|
||||||
|
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
|
||||||
|
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
|
||||||
|
|
||||||
|
// According to RFC 6874, an HTTP client, proxy, or other
|
||||||
|
// intermediary must remove any IPv6 zone identifier attached
|
||||||
|
// to an outgoing URI.
|
||||||
|
bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
|
||||||
|
bw.WriteString("Upgrade: websocket\r\n")
|
||||||
|
bw.WriteString("Connection: Upgrade\r\n")
|
||||||
|
nonce := generateNonce()
|
||||||
|
if config.handshakeData != nil {
|
||||||
|
nonce = []byte(config.handshakeData["key"])
|
||||||
|
}
|
||||||
|
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
|
||||||
|
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
|
||||||
|
|
||||||
|
if config.Version != ProtocolVersionHybi13 {
|
||||||
|
return ErrBadProtocolVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
|
||||||
|
if len(config.Protocol) > 0 {
|
||||||
|
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
|
||||||
|
}
|
||||||
|
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||||
|
err = config.Header.WriteSubset(bw, handshakeHeader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
bw.WriteString("\r\n")
|
||||||
|
if err = bw.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 101 {
|
||||||
|
return ErrBadStatus
|
||||||
|
}
|
||||||
|
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
|
||||||
|
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
|
||||||
|
return ErrBadUpgrade
|
||||||
|
}
|
||||||
|
expectedAccept, err := getNonceAccept(nonce)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
|
||||||
|
return ErrChallengeResponse
|
||||||
|
}
|
||||||
|
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
|
||||||
|
return ErrUnsupportedExtensions
|
||||||
|
}
|
||||||
|
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
|
||||||
|
if offeredProtocol != "" {
|
||||||
|
protocolMatched := false
|
||||||
|
for i := 0; i < len(config.Protocol); i++ {
|
||||||
|
if config.Protocol[i] == offeredProtocol {
|
||||||
|
protocolMatched = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !protocolMatched {
|
||||||
|
return ErrBadWebSocketProtocol
|
||||||
|
}
|
||||||
|
config.Protocol = []string{offeredProtocol}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHybiClientConn creates a client WebSocket connection after handshake.
|
||||||
|
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
|
||||||
|
return newHybiConn(config, buf, rwc, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
|
||||||
|
type hybiServerHandshaker struct {
|
||||||
|
*Config
|
||||||
|
accept []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
|
||||||
|
c.Version = ProtocolVersionHybi13
|
||||||
|
if req.Method != "GET" {
|
||||||
|
return http.StatusMethodNotAllowed, ErrBadRequestMethod
|
||||||
|
}
|
||||||
|
// HTTP version can be safely ignored.
|
||||||
|
|
||||||
|
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
|
||||||
|
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||||
|
return http.StatusBadRequest, ErrNotWebSocket
|
||||||
|
}
|
||||||
|
|
||||||
|
key := req.Header.Get("Sec-Websocket-Key")
|
||||||
|
if key == "" {
|
||||||
|
return http.StatusBadRequest, ErrChallengeResponse
|
||||||
|
}
|
||||||
|
version := req.Header.Get("Sec-Websocket-Version")
|
||||||
|
switch version {
|
||||||
|
case "13":
|
||||||
|
c.Version = ProtocolVersionHybi13
|
||||||
|
default:
|
||||||
|
return http.StatusBadRequest, ErrBadWebSocketVersion
|
||||||
|
}
|
||||||
|
var scheme string
|
||||||
|
if req.TLS != nil {
|
||||||
|
scheme = "wss"
|
||||||
|
} else {
|
||||||
|
scheme = "ws"
|
||||||
|
}
|
||||||
|
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
|
||||||
|
if err != nil {
|
||||||
|
return http.StatusBadRequest, err
|
||||||
|
}
|
||||||
|
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
|
||||||
|
if protocol != "" {
|
||||||
|
protocols := strings.Split(protocol, ",")
|
||||||
|
for i := 0; i < len(protocols); i++ {
|
||||||
|
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.accept, err = getNonceAccept([]byte(key))
|
||||||
|
if err != nil {
|
||||||
|
return http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
return http.StatusSwitchingProtocols, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Origin parses the Origin header in req.
|
||||||
|
// If the Origin header is not set, it returns nil and nil.
|
||||||
|
func Origin(config *Config, req *http.Request) (*url.URL, error) {
|
||||||
|
var origin string
|
||||||
|
switch config.Version {
|
||||||
|
case ProtocolVersionHybi13:
|
||||||
|
origin = req.Header.Get("Origin")
|
||||||
|
}
|
||||||
|
if origin == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return url.ParseRequestURI(origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
|
||||||
|
if len(c.Protocol) > 0 {
|
||||||
|
if len(c.Protocol) != 1 {
|
||||||
|
// You need choose a Protocol in Handshake func in Server.
|
||||||
|
return ErrBadWebSocketProtocol
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
|
||||||
|
buf.WriteString("Upgrade: websocket\r\n")
|
||||||
|
buf.WriteString("Connection: Upgrade\r\n")
|
||||||
|
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
|
||||||
|
if len(c.Protocol) > 0 {
|
||||||
|
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
|
||||||
|
}
|
||||||
|
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||||
|
if c.Header != nil {
|
||||||
|
err := c.Header.WriteSubset(buf, handshakeHeader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
return buf.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
||||||
|
return newHybiServerConn(c.Config, buf, rwc, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
|
||||||
|
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
||||||
|
return newHybiConn(config, buf, rwc, request)
|
||||||
|
}
|
|
@ -0,0 +1,113 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
|
||||||
|
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
|
||||||
|
code, err := hs.ReadHandshake(buf.Reader, req)
|
||||||
|
if err == ErrBadWebSocketVersion {
|
||||||
|
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||||
|
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
buf.WriteString(err.Error())
|
||||||
|
buf.Flush()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
buf.WriteString(err.Error())
|
||||||
|
buf.Flush()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if handshake != nil {
|
||||||
|
err = handshake(config, req)
|
||||||
|
if err != nil {
|
||||||
|
code = http.StatusForbidden
|
||||||
|
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
buf.Flush()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = hs.AcceptHandshake(buf.Writer)
|
||||||
|
if err != nil {
|
||||||
|
code = http.StatusBadRequest
|
||||||
|
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
buf.Flush()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn = hs.NewServerConn(buf, rwc, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server represents a server of a WebSocket.
|
||||||
|
type Server struct {
|
||||||
|
// Config is a WebSocket configuration for new WebSocket connection.
|
||||||
|
Config
|
||||||
|
|
||||||
|
// Handshake is an optional function in WebSocket handshake.
|
||||||
|
// For example, you can check, or don't check Origin header.
|
||||||
|
// Another example, you can select config.Protocol.
|
||||||
|
Handshake func(*Config, *http.Request) error
|
||||||
|
|
||||||
|
// Handler handles a WebSocket connection.
|
||||||
|
Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||||
|
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
s.serveWebSocket(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
|
||||||
|
rwc, buf, err := w.(http.Hijacker).Hijack()
|
||||||
|
if err != nil {
|
||||||
|
panic("Hijack failed: " + err.Error())
|
||||||
|
}
|
||||||
|
// The server should abort the WebSocket connection if it finds
|
||||||
|
// the client did not send a handshake that matches with protocol
|
||||||
|
// specification.
|
||||||
|
defer rwc.Close()
|
||||||
|
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
panic("unexpected nil conn")
|
||||||
|
}
|
||||||
|
s.Handler(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler is a simple interface to a WebSocket browser client.
|
||||||
|
// It checks if Origin header is valid URL by default.
|
||||||
|
// You might want to verify websocket.Conn.Config().Origin in the func.
|
||||||
|
// If you use Server instead of Handler, you could call websocket.Origin and
|
||||||
|
// check the origin in your Handshake func. So, if you want to accept
|
||||||
|
// non-browser clients, which do not send an Origin header, set a
|
||||||
|
// Server.Handshake that does not check the origin.
|
||||||
|
type Handler func(*Conn)
|
||||||
|
|
||||||
|
func checkOrigin(config *Config, req *http.Request) (err error) {
|
||||||
|
config.Origin, err = Origin(config, req)
|
||||||
|
if err == nil && config.Origin == nil {
|
||||||
|
return fmt.Errorf("null origin")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||||
|
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
s := Server{Handler: h, Handshake: checkOrigin}
|
||||||
|
s.serveWebSocket(w, req)
|
||||||
|
}
|
|
@ -0,0 +1,448 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package websocket implements a client and server for the WebSocket protocol
|
||||||
|
// as specified in RFC 6455.
|
||||||
|
//
|
||||||
|
// This package currently lacks some features found in an alternative
|
||||||
|
// and more actively maintained WebSocket package:
|
||||||
|
//
|
||||||
|
// https://godoc.org/github.com/gorilla/websocket
|
||||||
|
//
|
||||||
|
package websocket // import "golang.org/x/net/websocket"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProtocolVersionHybi13 = 13
|
||||||
|
ProtocolVersionHybi = ProtocolVersionHybi13
|
||||||
|
SupportedProtocolVersion = "13"
|
||||||
|
|
||||||
|
ContinuationFrame = 0
|
||||||
|
TextFrame = 1
|
||||||
|
BinaryFrame = 2
|
||||||
|
CloseFrame = 8
|
||||||
|
PingFrame = 9
|
||||||
|
PongFrame = 10
|
||||||
|
UnknownFrame = 255
|
||||||
|
|
||||||
|
DefaultMaxPayloadBytes = 32 << 20 // 32MB
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProtocolError represents WebSocket protocol errors.
|
||||||
|
type ProtocolError struct {
|
||||||
|
ErrorString string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err *ProtocolError) Error() string { return err.ErrorString }
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
|
||||||
|
ErrBadScheme = &ProtocolError{"bad scheme"}
|
||||||
|
ErrBadStatus = &ProtocolError{"bad status"}
|
||||||
|
ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
|
||||||
|
ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
|
||||||
|
ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
|
||||||
|
ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
|
||||||
|
ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
|
||||||
|
ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
|
||||||
|
ErrBadFrame = &ProtocolError{"bad frame"}
|
||||||
|
ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
|
||||||
|
ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
|
||||||
|
ErrBadRequestMethod = &ProtocolError{"bad method"}
|
||||||
|
ErrNotSupported = &ProtocolError{"not supported"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrFrameTooLarge is returned by Codec's Receive method if payload size
|
||||||
|
// exceeds limit set by Conn.MaxPayloadBytes
|
||||||
|
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
|
||||||
|
|
||||||
|
// Addr is an implementation of net.Addr for WebSocket.
|
||||||
|
type Addr struct {
|
||||||
|
*url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Network returns the network type for a WebSocket, "websocket".
|
||||||
|
func (addr *Addr) Network() string { return "websocket" }
|
||||||
|
|
||||||
|
// Config is a WebSocket configuration
|
||||||
|
type Config struct {
|
||||||
|
// A WebSocket server address.
|
||||||
|
Location *url.URL
|
||||||
|
|
||||||
|
// A Websocket client origin.
|
||||||
|
Origin *url.URL
|
||||||
|
|
||||||
|
// WebSocket subprotocols.
|
||||||
|
Protocol []string
|
||||||
|
|
||||||
|
// WebSocket protocol version.
|
||||||
|
Version int
|
||||||
|
|
||||||
|
// TLS config for secure WebSocket (wss).
|
||||||
|
TlsConfig *tls.Config
|
||||||
|
|
||||||
|
// Additional header fields to be sent in WebSocket opening handshake.
|
||||||
|
Header http.Header
|
||||||
|
|
||||||
|
// Dialer used when opening websocket connections.
|
||||||
|
Dialer *net.Dialer
|
||||||
|
|
||||||
|
handshakeData map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// serverHandshaker is an interface to handle WebSocket server side handshake.
|
||||||
|
type serverHandshaker interface {
|
||||||
|
// ReadHandshake reads handshake request message from client.
|
||||||
|
// Returns http response code and error if any.
|
||||||
|
ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
|
||||||
|
|
||||||
|
// AcceptHandshake accepts the client handshake request and sends
|
||||||
|
// handshake response back to client.
|
||||||
|
AcceptHandshake(buf *bufio.Writer) (err error)
|
||||||
|
|
||||||
|
// NewServerConn creates a new WebSocket connection.
|
||||||
|
NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// frameReader is an interface to read a WebSocket frame.
|
||||||
|
type frameReader interface {
|
||||||
|
// Reader is to read payload of the frame.
|
||||||
|
io.Reader
|
||||||
|
|
||||||
|
// PayloadType returns payload type.
|
||||||
|
PayloadType() byte
|
||||||
|
|
||||||
|
// HeaderReader returns a reader to read header of the frame.
|
||||||
|
HeaderReader() io.Reader
|
||||||
|
|
||||||
|
// TrailerReader returns a reader to read trailer of the frame.
|
||||||
|
// If it returns nil, there is no trailer in the frame.
|
||||||
|
TrailerReader() io.Reader
|
||||||
|
|
||||||
|
// Len returns total length of the frame, including header and trailer.
|
||||||
|
Len() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// frameReaderFactory is an interface to creates new frame reader.
|
||||||
|
type frameReaderFactory interface {
|
||||||
|
NewFrameReader() (r frameReader, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// frameWriter is an interface to write a WebSocket frame.
|
||||||
|
type frameWriter interface {
|
||||||
|
// Writer is to write payload of the frame.
|
||||||
|
io.WriteCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
// frameWriterFactory is an interface to create new frame writer.
|
||||||
|
type frameWriterFactory interface {
|
||||||
|
NewFrameWriter(payloadType byte) (w frameWriter, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type frameHandler interface {
|
||||||
|
HandleFrame(frame frameReader) (r frameReader, err error)
|
||||||
|
WriteClose(status int) (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn represents a WebSocket connection.
|
||||||
|
//
|
||||||
|
// Multiple goroutines may invoke methods on a Conn simultaneously.
|
||||||
|
type Conn struct {
|
||||||
|
config *Config
|
||||||
|
request *http.Request
|
||||||
|
|
||||||
|
buf *bufio.ReadWriter
|
||||||
|
rwc io.ReadWriteCloser
|
||||||
|
|
||||||
|
rio sync.Mutex
|
||||||
|
frameReaderFactory
|
||||||
|
frameReader
|
||||||
|
|
||||||
|
wio sync.Mutex
|
||||||
|
frameWriterFactory
|
||||||
|
|
||||||
|
frameHandler
|
||||||
|
PayloadType byte
|
||||||
|
defaultCloseStatus int
|
||||||
|
|
||||||
|
// MaxPayloadBytes limits the size of frame payload received over Conn
|
||||||
|
// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
|
||||||
|
MaxPayloadBytes int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements the io.Reader interface:
|
||||||
|
// it reads data of a frame from the WebSocket connection.
|
||||||
|
// if msg is not large enough for the frame data, it fills the msg and next Read
|
||||||
|
// will read the rest of the frame data.
|
||||||
|
// it reads Text frame or Binary frame.
|
||||||
|
func (ws *Conn) Read(msg []byte) (n int, err error) {
|
||||||
|
ws.rio.Lock()
|
||||||
|
defer ws.rio.Unlock()
|
||||||
|
again:
|
||||||
|
if ws.frameReader == nil {
|
||||||
|
frame, err := ws.frameReaderFactory.NewFrameReader()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if ws.frameReader == nil {
|
||||||
|
goto again
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n, err = ws.frameReader.Read(msg)
|
||||||
|
if err == io.EOF {
|
||||||
|
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
|
||||||
|
io.Copy(ioutil.Discard, trailer)
|
||||||
|
}
|
||||||
|
ws.frameReader = nil
|
||||||
|
goto again
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements the io.Writer interface:
|
||||||
|
// it writes data as a frame to the WebSocket connection.
|
||||||
|
func (ws *Conn) Write(msg []byte) (n int, err error) {
|
||||||
|
ws.wio.Lock()
|
||||||
|
defer ws.wio.Unlock()
|
||||||
|
w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
n, err = w.Write(msg)
|
||||||
|
w.Close()
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the io.Closer interface.
|
||||||
|
func (ws *Conn) Close() error {
|
||||||
|
err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
|
||||||
|
err1 := ws.rwc.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *Conn) IsClientConn() bool { return ws.request == nil }
|
||||||
|
func (ws *Conn) IsServerConn() bool { return ws.request != nil }
|
||||||
|
|
||||||
|
// LocalAddr returns the WebSocket Origin for the connection for client, or
|
||||||
|
// the WebSocket location for server.
|
||||||
|
func (ws *Conn) LocalAddr() net.Addr {
|
||||||
|
if ws.IsClientConn() {
|
||||||
|
return &Addr{ws.config.Origin}
|
||||||
|
}
|
||||||
|
return &Addr{ws.config.Location}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the WebSocket location for the connection for client, or
|
||||||
|
// the Websocket Origin for server.
|
||||||
|
func (ws *Conn) RemoteAddr() net.Addr {
|
||||||
|
if ws.IsClientConn() {
|
||||||
|
return &Addr{ws.config.Location}
|
||||||
|
}
|
||||||
|
return &Addr{ws.config.Origin}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
|
||||||
|
|
||||||
|
// SetDeadline sets the connection's network read & write deadlines.
|
||||||
|
func (ws *Conn) SetDeadline(t time.Time) error {
|
||||||
|
if conn, ok := ws.rwc.(net.Conn); ok {
|
||||||
|
return conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
return errSetDeadline
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the connection's network read deadline.
|
||||||
|
func (ws *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
if conn, ok := ws.rwc.(net.Conn); ok {
|
||||||
|
return conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
return errSetDeadline
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline sets the connection's network write deadline.
|
||||||
|
func (ws *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
if conn, ok := ws.rwc.(net.Conn); ok {
|
||||||
|
return conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
return errSetDeadline
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config returns the WebSocket config.
|
||||||
|
func (ws *Conn) Config() *Config { return ws.config }
|
||||||
|
|
||||||
|
// Request returns the http request upgraded to the WebSocket.
|
||||||
|
// It is nil for client side.
|
||||||
|
func (ws *Conn) Request() *http.Request { return ws.request }
|
||||||
|
|
||||||
|
// Codec represents a symmetric pair of functions that implement a codec.
|
||||||
|
type Codec struct {
|
||||||
|
Marshal func(v interface{}) (data []byte, payloadType byte, err error)
|
||||||
|
Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends v marshaled by cd.Marshal as single frame to ws.
|
||||||
|
func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
|
||||||
|
data, payloadType, err := cd.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ws.wio.Lock()
|
||||||
|
defer ws.wio.Unlock()
|
||||||
|
w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = w.Write(data)
|
||||||
|
w.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
|
||||||
|
// in v. The whole frame payload is read to an in-memory buffer; max size of
|
||||||
|
// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
|
||||||
|
// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
|
||||||
|
// completely. The next call to Receive would read and discard leftover data of
|
||||||
|
// previous oversized frame before processing next frame.
|
||||||
|
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
|
||||||
|
ws.rio.Lock()
|
||||||
|
defer ws.rio.Unlock()
|
||||||
|
if ws.frameReader != nil {
|
||||||
|
_, err = io.Copy(ioutil.Discard, ws.frameReader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ws.frameReader = nil
|
||||||
|
}
|
||||||
|
again:
|
||||||
|
frame, err := ws.frameReaderFactory.NewFrameReader()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
frame, err = ws.frameHandler.HandleFrame(frame)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if frame == nil {
|
||||||
|
goto again
|
||||||
|
}
|
||||||
|
maxPayloadBytes := ws.MaxPayloadBytes
|
||||||
|
if maxPayloadBytes == 0 {
|
||||||
|
maxPayloadBytes = DefaultMaxPayloadBytes
|
||||||
|
}
|
||||||
|
if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
|
||||||
|
// payload size exceeds limit, no need to call Unmarshal
|
||||||
|
//
|
||||||
|
// set frameReader to current oversized frame so that
|
||||||
|
// the next call to this function can drain leftover
|
||||||
|
// data before processing the next frame
|
||||||
|
ws.frameReader = frame
|
||||||
|
return ErrFrameTooLarge
|
||||||
|
}
|
||||||
|
payloadType := frame.PayloadType()
|
||||||
|
data, err := ioutil.ReadAll(frame)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return cd.Unmarshal(data, payloadType, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
|
||||||
|
switch data := v.(type) {
|
||||||
|
case string:
|
||||||
|
return []byte(data), TextFrame, nil
|
||||||
|
case []byte:
|
||||||
|
return data, BinaryFrame, nil
|
||||||
|
}
|
||||||
|
return nil, UnknownFrame, ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
|
||||||
|
switch data := v.(type) {
|
||||||
|
case *string:
|
||||||
|
*data = string(msg)
|
||||||
|
return nil
|
||||||
|
case *[]byte:
|
||||||
|
*data = msg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Message is a codec to send/receive text/binary data in a frame on WebSocket connection.
|
||||||
|
To send/receive text frame, use string type.
|
||||||
|
To send/receive binary frame, use []byte type.
|
||||||
|
|
||||||
|
Trivial usage:
|
||||||
|
|
||||||
|
import "websocket"
|
||||||
|
|
||||||
|
// receive text frame
|
||||||
|
var message string
|
||||||
|
websocket.Message.Receive(ws, &message)
|
||||||
|
|
||||||
|
// send text frame
|
||||||
|
message = "hello"
|
||||||
|
websocket.Message.Send(ws, message)
|
||||||
|
|
||||||
|
// receive binary frame
|
||||||
|
var data []byte
|
||||||
|
websocket.Message.Receive(ws, &data)
|
||||||
|
|
||||||
|
// send binary frame
|
||||||
|
data = []byte{0, 1, 2}
|
||||||
|
websocket.Message.Send(ws, data)
|
||||||
|
|
||||||
|
*/
|
||||||
|
var Message = Codec{marshal, unmarshal}
|
||||||
|
|
||||||
|
func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
|
||||||
|
msg, err = json.Marshal(v)
|
||||||
|
return msg, TextFrame, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
|
||||||
|
return json.Unmarshal(msg, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
|
||||||
|
|
||||||
|
Trivial usage:
|
||||||
|
|
||||||
|
import "websocket"
|
||||||
|
|
||||||
|
type T struct {
|
||||||
|
Msg string
|
||||||
|
Count int
|
||||||
|
}
|
||||||
|
|
||||||
|
// receive JSON type T
|
||||||
|
var data T
|
||||||
|
websocket.JSON.Receive(ws, &data)
|
||||||
|
|
||||||
|
// send JSON type T
|
||||||
|
websocket.JSON.Send(ws, data)
|
||||||
|
*/
|
||||||
|
var JSON = Codec{jsonMarshal, jsonUnmarshal}
|
|
@ -0,0 +1,131 @@
|
||||||
|
// Copyright 2017 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package semaphore provides a weighted semaphore implementation.
|
||||||
|
package semaphore // import "golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
// Use the old context because packages that depend on this one
|
||||||
|
// (e.g. cloud.google.com/go/...) must run on Go 1.6.
|
||||||
|
// TODO(jba): update to "context" when possible.
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type waiter struct {
|
||||||
|
n int64
|
||||||
|
ready chan<- struct{} // Closed when semaphore acquired.
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWeighted creates a new weighted semaphore with the given
|
||||||
|
// maximum combined weight for concurrent access.
|
||||||
|
func NewWeighted(n int64) *Weighted {
|
||||||
|
w := &Weighted{size: n}
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Weighted provides a way to bound concurrent access to a resource.
|
||||||
|
// The callers can request access with a given weight.
|
||||||
|
type Weighted struct {
|
||||||
|
size int64
|
||||||
|
cur int64
|
||||||
|
mu sync.Mutex
|
||||||
|
waiters list.List
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire acquires the semaphore with a weight of n, blocking only until ctx
|
||||||
|
// is done. On success, returns nil. On failure, returns ctx.Err() and leaves
|
||||||
|
// the semaphore unchanged.
|
||||||
|
//
|
||||||
|
// If ctx is already done, Acquire may still succeed without blocking.
|
||||||
|
func (s *Weighted) Acquire(ctx context.Context, n int64) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.size-s.cur >= n && s.waiters.Len() == 0 {
|
||||||
|
s.cur += n
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > s.size {
|
||||||
|
// Don't make other Acquire calls block on one that's doomed to fail.
|
||||||
|
s.mu.Unlock()
|
||||||
|
<-ctx.Done()
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
ready := make(chan struct{})
|
||||||
|
w := waiter{n: n, ready: ready}
|
||||||
|
elem := s.waiters.PushBack(w)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err := ctx.Err()
|
||||||
|
s.mu.Lock()
|
||||||
|
select {
|
||||||
|
case <-ready:
|
||||||
|
// Acquired the semaphore after we were canceled. Rather than trying to
|
||||||
|
// fix up the queue, just pretend we didn't notice the cancelation.
|
||||||
|
err = nil
|
||||||
|
default:
|
||||||
|
s.waiters.Remove(elem)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
return err
|
||||||
|
|
||||||
|
case <-ready:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryAcquire acquires the semaphore with a weight of n without blocking.
|
||||||
|
// On success, returns true. On failure, returns false and leaves the semaphore unchanged.
|
||||||
|
func (s *Weighted) TryAcquire(n int64) bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
success := s.size-s.cur >= n && s.waiters.Len() == 0
|
||||||
|
if success {
|
||||||
|
s.cur += n
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
return success
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release releases the semaphore with a weight of n.
|
||||||
|
func (s *Weighted) Release(n int64) {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.cur -= n
|
||||||
|
if s.cur < 0 {
|
||||||
|
s.mu.Unlock()
|
||||||
|
panic("semaphore: bad release")
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
next := s.waiters.Front()
|
||||||
|
if next == nil {
|
||||||
|
break // No more waiters blocked.
|
||||||
|
}
|
||||||
|
|
||||||
|
w := next.Value.(waiter)
|
||||||
|
if s.size-s.cur < w.n {
|
||||||
|
// Not enough tokens for the next waiter. We could keep going (to try to
|
||||||
|
// find a waiter with a smaller request), but under load that could cause
|
||||||
|
// starvation for large requests; instead, we leave all remaining waiters
|
||||||
|
// blocked.
|
||||||
|
//
|
||||||
|
// Consider a semaphore used as a read-write lock, with N tokens, N
|
||||||
|
// readers, and one writer. Each reader can Acquire(1) to obtain a read
|
||||||
|
// lock. The writer can Acquire(N) to obtain a write lock, excluding all
|
||||||
|
// of the readers. If we allow the readers to jump ahead in the queue,
|
||||||
|
// the writer will starve — there is always one token available for every
|
||||||
|
// reader.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
s.cur += w.n
|
||||||
|
s.waiters.Remove(next)
|
||||||
|
close(w.ready)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
|
@ -0,0 +1,258 @@
|
||||||
|
// Copyright 2016 Google Inc. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Package bundler supports bundling (batching) of items. Bundling amortizes an
|
||||||
|
// action with fixed costs over multiple items. For example, if an API provides
|
||||||
|
// an RPC that accepts a list of items as input, but clients would prefer
|
||||||
|
// adding items one at a time, then a Bundler can accept individual items from
|
||||||
|
// the client and bundle many of them into a single RPC.
|
||||||
|
//
|
||||||
|
// This package is experimental and subject to change without notice.
|
||||||
|
package bundler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultDelayThreshold = time.Second
|
||||||
|
DefaultBundleCountThreshold = 10
|
||||||
|
DefaultBundleByteThreshold = 1e6 // 1M
|
||||||
|
DefaultBufferedByteLimit = 1e9 // 1G
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrOverflow indicates that Bundler's stored bytes exceeds its BufferedByteLimit.
|
||||||
|
ErrOverflow = errors.New("bundler reached buffered byte limit")
|
||||||
|
|
||||||
|
// ErrOversizedItem indicates that an item's size exceeds the maximum bundle size.
|
||||||
|
ErrOversizedItem = errors.New("item size exceeds bundle byte limit")
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Bundler collects items added to it into a bundle until the bundle
|
||||||
|
// exceeds a given size, then calls a user-provided function to handle the bundle.
|
||||||
|
type Bundler struct {
|
||||||
|
// Starting from the time that the first message is added to a bundle, once
|
||||||
|
// this delay has passed, handle the bundle. The default is DefaultDelayThreshold.
|
||||||
|
DelayThreshold time.Duration
|
||||||
|
|
||||||
|
// Once a bundle has this many items, handle the bundle. Since only one
|
||||||
|
// item at a time is added to a bundle, no bundle will exceed this
|
||||||
|
// threshold, so it also serves as a limit. The default is
|
||||||
|
// DefaultBundleCountThreshold.
|
||||||
|
BundleCountThreshold int
|
||||||
|
|
||||||
|
// Once the number of bytes in current bundle reaches this threshold, handle
|
||||||
|
// the bundle. The default is DefaultBundleByteThreshold. This triggers handling,
|
||||||
|
// but does not cap the total size of a bundle.
|
||||||
|
BundleByteThreshold int
|
||||||
|
|
||||||
|
// The maximum size of a bundle, in bytes. Zero means unlimited.
|
||||||
|
BundleByteLimit int
|
||||||
|
|
||||||
|
// The maximum number of bytes that the Bundler will keep in memory before
|
||||||
|
// returning ErrOverflow. The default is DefaultBufferedByteLimit.
|
||||||
|
BufferedByteLimit int
|
||||||
|
|
||||||
|
handler func(interface{}) // called to handle a bundle
|
||||||
|
itemSliceZero reflect.Value // nil (zero value) for slice of items
|
||||||
|
flushTimer *time.Timer // implements DelayThreshold
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
sem *semaphore.Weighted // enforces BufferedByteLimit
|
||||||
|
semOnce sync.Once
|
||||||
|
curBundle bundle // incoming items added to this bundle
|
||||||
|
handlingc <-chan struct{} // set to non-nil while a handler is running; closed when it returns
|
||||||
|
}
|
||||||
|
|
||||||
|
type bundle struct {
|
||||||
|
items reflect.Value // slice of item type
|
||||||
|
size int // size in bytes of all items
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBundler creates a new Bundler.
|
||||||
|
//
|
||||||
|
// itemExample is a value of the type that will be bundled. For example, if you
|
||||||
|
// want to create bundles of *Entry, you could pass &Entry{} for itemExample.
|
||||||
|
//
|
||||||
|
// handler is a function that will be called on each bundle. If itemExample is
|
||||||
|
// of type T, the argument to handler is of type []T. handler is always called
|
||||||
|
// sequentially for each bundle, and never in parallel.
|
||||||
|
//
|
||||||
|
// Configure the Bundler by setting its thresholds and limits before calling
|
||||||
|
// any of its methods.
|
||||||
|
func NewBundler(itemExample interface{}, handler func(interface{})) *Bundler {
|
||||||
|
b := &Bundler{
|
||||||
|
DelayThreshold: DefaultDelayThreshold,
|
||||||
|
BundleCountThreshold: DefaultBundleCountThreshold,
|
||||||
|
BundleByteThreshold: DefaultBundleByteThreshold,
|
||||||
|
BufferedByteLimit: DefaultBufferedByteLimit,
|
||||||
|
|
||||||
|
handler: handler,
|
||||||
|
itemSliceZero: reflect.Zero(reflect.SliceOf(reflect.TypeOf(itemExample))),
|
||||||
|
}
|
||||||
|
b.curBundle.items = b.itemSliceZero
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bundler) sema() *semaphore.Weighted {
|
||||||
|
// Create the semaphore lazily, because the user may set BufferedByteLimit
|
||||||
|
// after NewBundler.
|
||||||
|
b.semOnce.Do(func() {
|
||||||
|
b.sem = semaphore.NewWeighted(int64(b.BufferedByteLimit))
|
||||||
|
})
|
||||||
|
return b.sem
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds item to the current bundle. It marks the bundle for handling and
|
||||||
|
// starts a new one if any of the thresholds or limits are exceeded.
|
||||||
|
//
|
||||||
|
// If the item's size exceeds the maximum bundle size (Bundler.BundleByteLimit), then
|
||||||
|
// the item can never be handled. Add returns ErrOversizedItem in this case.
|
||||||
|
//
|
||||||
|
// If adding the item would exceed the maximum memory allowed
|
||||||
|
// (Bundler.BufferedByteLimit) or an AddWait call is blocked waiting for
|
||||||
|
// memory, Add returns ErrOverflow.
|
||||||
|
//
|
||||||
|
// Add never blocks.
|
||||||
|
func (b *Bundler) Add(item interface{}, size int) error {
|
||||||
|
// If this item exceeds the maximum size of a bundle,
|
||||||
|
// we can never send it.
|
||||||
|
if b.BundleByteLimit > 0 && size > b.BundleByteLimit {
|
||||||
|
return ErrOversizedItem
|
||||||
|
}
|
||||||
|
// If adding this item would exceed our allotted memory
|
||||||
|
// footprint, we can't accept it.
|
||||||
|
// (TryAcquire also returns false if anything is waiting on the semaphore,
|
||||||
|
// so calls to Add and AddWait shouldn't be mixed.)
|
||||||
|
if !b.sema().TryAcquire(int64(size)) {
|
||||||
|
return ErrOverflow
|
||||||
|
}
|
||||||
|
b.add(item, size)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// add adds item to the current bundle. It marks the bundle for handling and
|
||||||
|
// starts a new one if any of the thresholds or limits are exceeded.
|
||||||
|
func (b *Bundler) add(item interface{}, size int) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
// If adding this item to the current bundle would cause it to exceed the
|
||||||
|
// maximum bundle size, close the current bundle and start a new one.
|
||||||
|
if b.BundleByteLimit > 0 && b.curBundle.size+size > b.BundleByteLimit {
|
||||||
|
b.startFlushLocked()
|
||||||
|
}
|
||||||
|
// Add the item.
|
||||||
|
b.curBundle.items = reflect.Append(b.curBundle.items, reflect.ValueOf(item))
|
||||||
|
b.curBundle.size += size
|
||||||
|
|
||||||
|
// Start a timer to flush the item if one isn't already running.
|
||||||
|
// startFlushLocked clears the timer and closes the bundle at the same time,
|
||||||
|
// so we only allocate a new timer for the first item in each bundle.
|
||||||
|
// (We could try to call Reset on the timer instead, but that would add a lot
|
||||||
|
// of complexity to the code just to save one small allocation.)
|
||||||
|
if b.flushTimer == nil {
|
||||||
|
b.flushTimer = time.AfterFunc(b.DelayThreshold, b.Flush)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the current bundle equals the count threshold, close it.
|
||||||
|
if b.curBundle.items.Len() == b.BundleCountThreshold {
|
||||||
|
b.startFlushLocked()
|
||||||
|
}
|
||||||
|
// If the current bundle equals or exceeds the byte threshold, close it.
|
||||||
|
if b.curBundle.size >= b.BundleByteThreshold {
|
||||||
|
b.startFlushLocked()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddWait adds item to the current bundle. It marks the bundle for handling and
|
||||||
|
// starts a new one if any of the thresholds or limits are exceeded.
|
||||||
|
//
|
||||||
|
// If the item's size exceeds the maximum bundle size (Bundler.BundleByteLimit), then
|
||||||
|
// the item can never be handled. AddWait returns ErrOversizedItem in this case.
|
||||||
|
//
|
||||||
|
// If adding the item would exceed the maximum memory allowed (Bundler.BufferedByteLimit),
|
||||||
|
// AddWait blocks until space is available or ctx is done.
|
||||||
|
//
|
||||||
|
// Calls to Add and AddWait should not be mixed on the same Bundler.
|
||||||
|
func (b *Bundler) AddWait(ctx context.Context, item interface{}, size int) error {
|
||||||
|
// If this item exceeds the maximum size of a bundle,
|
||||||
|
// we can never send it.
|
||||||
|
if b.BundleByteLimit > 0 && size > b.BundleByteLimit {
|
||||||
|
return ErrOversizedItem
|
||||||
|
}
|
||||||
|
// If adding this item would exceed our allotted memory footprint, block
|
||||||
|
// until space is available. The semaphore is FIFO, so there will be no
|
||||||
|
// starvation.
|
||||||
|
if err := b.sema().Acquire(ctx, int64(size)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Here, we've reserved space for item. Other goroutines can call AddWait
|
||||||
|
// and even acquire space, but no one can take away our reservation
|
||||||
|
// (assuming sem.Release is used correctly). So there is no race condition
|
||||||
|
// resulting from locking the mutex after sem.Acquire returns.
|
||||||
|
b.add(item, size)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush invokes the handler for all remaining items in the Bundler and waits
|
||||||
|
// for it to return.
|
||||||
|
func (b *Bundler) Flush() {
|
||||||
|
b.mu.Lock()
|
||||||
|
b.startFlushLocked()
|
||||||
|
done := b.handlingc
|
||||||
|
b.mu.Unlock()
|
||||||
|
|
||||||
|
if done != nil {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bundler) startFlushLocked() {
|
||||||
|
if b.flushTimer != nil {
|
||||||
|
b.flushTimer.Stop()
|
||||||
|
b.flushTimer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.curBundle.items.Len() == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bun := b.curBundle
|
||||||
|
b.curBundle = bundle{items: b.itemSliceZero}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
var running <-chan struct{}
|
||||||
|
running, b.handlingc = b.handlingc, done
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
b.sem.Release(int64(bun.size))
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if running != nil {
|
||||||
|
// Wait for our turn to call the handler.
|
||||||
|
<-running
|
||||||
|
}
|
||||||
|
|
||||||
|
b.handler(bun.items.Interface())
|
||||||
|
}()
|
||||||
|
}
|
Loading…
Reference in New Issue