add logworker
This commit is contained in:
parent
b7e4d5c99f
commit
2d46bec18b
17
box.rb
17
box.rb
|
@ -9,25 +9,32 @@ def foldercopy dir
|
|||
end
|
||||
|
||||
def gobuild pkg
|
||||
run "mkdir -p /root/go/bin && cd /root/go/bin && go#{$gover} build #{$repo}/#{pkg} && go#{$gover} install #{$repo}/#{pkg}"
|
||||
run "mkdir -p /root/go/bin && cd /root/go/bin && go#{$gover} build -v #{$repo}/#{pkg}"
|
||||
end
|
||||
|
||||
[
|
||||
"bot",
|
||||
"cmd",
|
||||
"internal",
|
||||
"vendor",
|
||||
"vendor-log",
|
||||
].each { |x| foldercopy x }
|
||||
|
||||
[
|
||||
"bot",
|
||||
"cmd",
|
||||
"internal",
|
||||
].each { |x| foldercopy x }
|
||||
|
||||
[
|
||||
"cmd/vyvanse",
|
||||
"cmd/logworker",
|
||||
].each { |x| gobuild x }
|
||||
|
||||
cmd "/root/go/bin/vyvanse"
|
||||
|
||||
run "rm -rf $HOME/sdk /root/go/pkg ||:"
|
||||
run "apk del go#{$gover}"
|
||||
|
||||
tag "xena/vyvanse:thick"
|
||||
|
||||
flatten
|
||||
|
||||
tag "xena/vyvanse"
|
||||
tag "xena/vyvanse:latest"
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"git.xeserv.us/xena/gorqlite"
|
||||
"git.xeserv.us/xena/vyvanse/internal/dao"
|
||||
|
||||
"github.com/Xe/ln"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/drone/mq/stomp"
|
||||
"github.com/namsral/flag"
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
zipkin "github.com/openzipkin/zipkin-go-opentracing"
|
||||
)
|
||||
|
||||
var (
|
||||
token = flag.String("token", "", "discord bot token")
|
||||
zipkinURL = flag.String("zipkin-url", "", "URL for Zipkin traces")
|
||||
databaseURL = flag.String("database-url", "http://", "URL for database (rqlite)")
|
||||
mqURL = flag.String("mq-url", "tcp://mq:9000", "URL for STOMP server")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if *zipkinURL != "" {
|
||||
collector, err := zipkin.NewHTTPCollector(*zipkinURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
tracer, err := zipkin.NewTracer(
|
||||
zipkin.NewRecorder(collector, false, "logworker:5000", "logworker"))
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
db, err := gorqlite.Open(*databaseURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(ctx, err)
|
||||
}
|
||||
|
||||
ls := dao.NewLogs(db)
|
||||
err = ls.Migrate(ctx)
|
||||
if err != nil {
|
||||
ln.FatalErr(ctx, err, ln.F{"action": "migrate logs table"})
|
||||
}
|
||||
|
||||
mq, err := stomp.Dial(*mqURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(ctx, err, ln.F{"url": *mqURL})
|
||||
}
|
||||
|
||||
mq.Subscribe("/topic/message_create", stomp.HandlerFunc(func(m *stomp.Message) {
|
||||
sp, ctx := opentracing.StartSpanFromContext(m.Context(), "logworker.topic.message.create")
|
||||
defer sp.Finish()
|
||||
|
||||
msg := &discordgo.Message{}
|
||||
err := json.Unmarshal(m.Msg, msg)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{"action": "can't unmarshal message body to a discordgo message"})
|
||||
return
|
||||
}
|
||||
|
||||
f := ln.F{
|
||||
"stomp_id": string(m.ID),
|
||||
"channel_id": msg.ChannelID,
|
||||
"message_id": msg.ID,
|
||||
"message_author": msg.Author.ID,
|
||||
"message_author_name": msg.Author.Username,
|
||||
"message_author_is_bot": msg.Author.Bot,
|
||||
}
|
||||
|
||||
err = ls.Add(ctx, msg)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.F{"action": "can't add discordgo message to the database"})
|
||||
}
|
||||
}))
|
||||
|
||||
for {
|
||||
select {}
|
||||
}
|
||||
}
|
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"github.com/Xe/ln"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/drone/mq/stomp"
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"github.com/namsral/flag"
|
||||
xkcd "github.com/nishanths/go-xkcd"
|
||||
|
@ -29,11 +30,26 @@ var (
|
|||
token = flag.String("token", "", "discord bot token")
|
||||
zipkinURL = flag.String("zipkin-url", "", "URL for Zipkin traces")
|
||||
databaseURL = flag.String("database-url", "http://", "URL for database (rqlite)")
|
||||
mqURL = flag.String("mq-url", "tcp://mq:9000", "URL for STOMP server")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if *zipkinURL != "" {
|
||||
collector, err := zipkin.NewHTTPCollector(*zipkinURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
tracer, err := zipkin.NewTracer(
|
||||
zipkin.NewRecorder(collector, false, "vyvanse:5000", "vyvanse"))
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
}
|
||||
|
||||
xk := xkcd.NewClient()
|
||||
dg, err := discordgo.New("Bot " + *token)
|
||||
if err != nil {
|
||||
|
@ -62,6 +78,14 @@ func main() {
|
|||
}
|
||||
sp.Finish()
|
||||
|
||||
ctx = context.Background()
|
||||
|
||||
mq, err := stomp.Dial(*mqURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(ctx, err, ln.F{"url": *mqURL})
|
||||
}
|
||||
_ = mq
|
||||
|
||||
c := cron.New()
|
||||
|
||||
comic, err := xk.Latest()
|
||||
|
@ -111,20 +135,6 @@ func main() {
|
|||
|
||||
c.Start()
|
||||
|
||||
if *zipkinURL != "" {
|
||||
collector, err := zipkin.NewHTTPCollector(*zipkinURL)
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
tracer, err := zipkin.NewTracer(
|
||||
zipkin.NewRecorder(collector, false, "vyvanse:5000", "vyvanse"))
|
||||
if err != nil {
|
||||
ln.FatalErr(context.Background(), err)
|
||||
}
|
||||
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
}
|
||||
|
||||
cs := bot.NewCommandSet()
|
||||
cs.Prefix = ">"
|
||||
|
||||
|
@ -134,6 +144,61 @@ func main() {
|
|||
cs.AddCmd("splattus", "splatoon 2 map rotation status", bot.NoPermissions, spla2nMaps)
|
||||
cs.AddCmd("top10", "shows the top 10 chatters on this server", bot.NoPermissions, top10(us))
|
||||
|
||||
dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
sp, ctx := opentracing.StartSpanFromContext(context.Background(), "message.create.post.stomp")
|
||||
defer sp.Finish()
|
||||
|
||||
f := ln.F{
|
||||
"channel_id": m.ChannelID,
|
||||
"message_id": m.ID,
|
||||
"message_author": m.Author.ID,
|
||||
"message_author_name": m.Author.Username,
|
||||
"message_author_is_bot": m.Author.Bot,
|
||||
}
|
||||
|
||||
err := mq.SendJSON("/topic/message_create", m.Message)
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
mq, err = stomp.Dial(*mqURL)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.F{"url": *mqURL, "action": "reconnect to mq"})
|
||||
return
|
||||
}
|
||||
|
||||
err = mq.SendJSON("/topic/message_create", m.Message)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.F{"action": "retry message_create post to message queue"})
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ln.Error(ctx, err, f, ln.F{"action": "send created message to queue"})
|
||||
return
|
||||
}
|
||||
|
||||
ln.Log(ctx, f, ln.F{"action": "message_create"})
|
||||
})
|
||||
|
||||
dg.AddHandler(func(s *discordgo.Session, m *discordgo.GuildMemberAdd) {
|
||||
sp, ctx := opentracing.StartSpanFromContext(context.Background(), "member.add.post.stomp")
|
||||
defer sp.Finish()
|
||||
|
||||
f := ln.F{
|
||||
"guild_id": m.GuildID,
|
||||
"user_id": m.User.ID,
|
||||
"user_name": m.User.Username,
|
||||
}
|
||||
|
||||
err := mq.SendJSON("/topic/member_add", m.Member)
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, f, ln.F{"action": "send added member to queue"})
|
||||
}
|
||||
|
||||
ln.Log(ctx, f, ln.F{"action": "member_add"})
|
||||
})
|
||||
|
||||
dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
if m.Author.ID == s.State.User.ID {
|
||||
return
|
||||
|
@ -165,8 +230,9 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
fmt.Println("Bot is now running. Press CTRL-C to exit.")
|
||||
// Simple way to keep program running until CTRL-C is pressed.
|
||||
<-make(chan struct{})
|
||||
return
|
||||
ln.Log(ctx, ln.F{"action": "bot is running"})
|
||||
|
||||
for {
|
||||
select {}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,5 +33,15 @@ services:
|
|||
- mq
|
||||
- rqlite
|
||||
|
||||
logworker:
|
||||
restart: always
|
||||
image: xena/vyvanse
|
||||
env_file: ./.env
|
||||
depends_on:
|
||||
- zipkin
|
||||
- mq
|
||||
- rqlite
|
||||
command: /root/go/bin/logworker
|
||||
|
||||
volumes:
|
||||
rqlite:
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.xeserv.us/xena/gorqlite"
|
||||
"github.com/Xe/ln"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
splog "github.com/opentracing/opentracing-go/log"
|
||||
"google.golang.org/api/support/bundler"
|
||||
)
|
||||
|
||||
type Logs struct {
|
||||
conn gorqlite.Connection
|
||||
bdl *bundler.Bundler
|
||||
}
|
||||
|
||||
func NewLogs(conn gorqlite.Connection) *Logs {
|
||||
l := &Logs{conn: conn}
|
||||
l.bdl = bundler.NewBundler("string", l.writeLines)
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Logs) writeLines(sqlLines interface{}) {
|
||||
sp, ctx := opentracing.StartSpanFromContext(context.Background(), "logs.write.lines")
|
||||
defer sp.Finish()
|
||||
|
||||
_, err := l.conn.Write(sqlLines.([]string))
|
||||
if err != nil {
|
||||
ln.Error(ctx, err, ln.F{"action": "write lines that are batched"})
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logs) Migrate(ctx context.Context) error {
|
||||
sp, ctx := opentracing.StartSpanFromContext(ctx, "logs.migrate")
|
||||
defer sp.Finish()
|
||||
|
||||
migrationDDL := []string{
|
||||
`CREATE TABLE IF NOT EXISTS logs(id INTEGER PRIMARY KEY, discord_id TEXT UNIQUE, channel_id TEXT, content TEXT, timestamp INTEGER, mention_everyone INTEGER, author_id TEXT, author_username TEXT)`,
|
||||
}
|
||||
|
||||
res, err := l.conn.Write(migrationDDL)
|
||||
if err != nil {
|
||||
sp.LogFields(splog.Error(err))
|
||||
}
|
||||
|
||||
for i, re := range res {
|
||||
if re.Err != nil {
|
||||
sp.LogFields(splog.Error(re.Err))
|
||||
return re.Err
|
||||
}
|
||||
|
||||
sp.LogFields(splog.Int("migration.step", i), splog.Float64("timing", re.Timing), splog.Int64("rows.affected", re.RowsAffected))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Logs) Add(ctx context.Context, m *discordgo.Message) error {
|
||||
sp, ctx := opentracing.StartSpanFromContext(ctx, "logs.add")
|
||||
defer sp.Finish()
|
||||
|
||||
stmt := gorqlite.NewPreparedStatement(`INSERT INTO logs (discord_id, channel_id, content, timestamp, mention_everyone, author_id, author_username) VALUES (%s, %s, %s, %d, %d, %s, %s)`)
|
||||
ts, err := m.Timestamp.Parse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var me int
|
||||
if m.MentionEveryone {
|
||||
me = 1
|
||||
}
|
||||
|
||||
bd := stmt.Bind(m.ID, m.ChannelID, m.Content, ts, me, m.Author.ID, m.Author.Username)
|
||||
|
||||
l.bdl.Add(bd, len(bd))
|
||||
|
||||
return nil
|
||||
}
|
|
@ -28,7 +28,6 @@ func (u *Users) Migrate(ctx context.Context) error {
|
|||
res, err := u.conn.Write(migrationDDL)
|
||||
if err != nil {
|
||||
sp.LogFields(splog.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
for i, re := range res {
|
||||
|
|
|
@ -42,3 +42,11 @@ ae77be60afb1dcacde03767a8c37337fad28ac14 github.com/kardianos/osext
|
|||
40a5e952d22c3ef520c6ab7bdb9b1a010ec9a524 git.xeserv.us/xena/gorqlite
|
||||
97311d9f7767e3d6f422ea06661bc2c7a19e8a5d github.com/mattn/go-runewidth
|
||||
be5337e7b39e64e5f91445ce7e721888dbab7387 github.com/olekukonko/tablewriter
|
||||
280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/logger
|
||||
280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/stomp
|
||||
280af2a3b9c7d9ce90d625150dfff972c6c190b8 github.com/drone/mq/stomp/dialer
|
||||
f5079bd7f6f74e23c4d65efa0f4ce14cbd6a3c0f golang.org/x/net/context
|
||||
f5079bd7f6f74e23c4d65efa0f4ce14cbd6a3c0f golang.org/x/net/websocket
|
||||
66aacef3dd8a676686c7ae3716979581e8b03c47 golang.org/x/net/context
|
||||
f52d1811a62927559de87708c8913c1650ce4f26 golang.org/x/sync/semaphore
|
||||
e0e0e6e500066ff47335c7717e2a090ad127adec google.golang.org/api/support/bundler
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
package logger
|
||||
|
||||
var std Logger = new(none)
|
||||
|
||||
// Debugf writes a debug message to the standard logger.
|
||||
func Debugf(format string, args ...interface{}) {
|
||||
std.Debugf(format, args...)
|
||||
}
|
||||
|
||||
// Verbosef writes a verbose message to the standard logger.
|
||||
func Verbosef(format string, args ...interface{}) {
|
||||
std.Verbosef(format, args...)
|
||||
}
|
||||
|
||||
// Noticef writes a notice message to the standard logger.
|
||||
func Noticef(format string, args ...interface{}) {
|
||||
std.Noticef(format, args...)
|
||||
}
|
||||
|
||||
// Warningf writes a warning message to the standard logger.
|
||||
func Warningf(format string, args ...interface{}) {
|
||||
std.Warningf(format, args...)
|
||||
}
|
||||
|
||||
// Printf writes a default message to the standard logger.
|
||||
func Printf(format string, args ...interface{}) {
|
||||
std.Printf(format, args...)
|
||||
}
|
||||
|
||||
// SetLogger sets the standard logger.
|
||||
func SetLogger(logger Logger) {
|
||||
std = logger
|
||||
}
|
||||
|
||||
// Logger represents a logger.
|
||||
type Logger interface {
|
||||
|
||||
// Debugf writes a debug message.
|
||||
Debugf(string, ...interface{})
|
||||
|
||||
// Verbosef writes a verbose message.
|
||||
Verbosef(string, ...interface{})
|
||||
|
||||
// Noticef writes a notice message.
|
||||
Noticef(string, ...interface{})
|
||||
|
||||
// Warningf writes a warning message.
|
||||
Warningf(string, ...interface{})
|
||||
|
||||
// Printf writes a default message.
|
||||
Printf(string, ...interface{})
|
||||
}
|
||||
|
||||
// none is a logger that silently ignores all writes.
|
||||
type none struct{}
|
||||
|
||||
func (*none) Debugf(string, ...interface{}) {}
|
||||
func (*none) Verbosef(string, ...interface{}) {}
|
||||
func (*none) Noticef(string, ...interface{}) {}
|
||||
func (*none) Warningf(string, ...interface{}) {}
|
||||
func (*none) Printf(string, ...interface{}) {}
|
|
@ -0,0 +1,259 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/drone/mq/logger"
|
||||
"github.com/drone/mq/stomp/dialer"
|
||||
)
|
||||
|
||||
// Client defines a client connection to a STOMP server.
|
||||
type Client struct {
|
||||
mu sync.Mutex
|
||||
|
||||
peer Peer
|
||||
subs map[string]Handler
|
||||
wait map[string]chan struct{}
|
||||
done chan error
|
||||
|
||||
seq int64
|
||||
|
||||
skipVerify bool
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// New returns a new STOMP client using the given connection.
|
||||
func New(peer Peer) *Client {
|
||||
return &Client{
|
||||
peer: peer,
|
||||
subs: make(map[string]Handler),
|
||||
wait: make(map[string]chan struct{}),
|
||||
done: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Dial creates a client connection to the given target.
|
||||
func Dial(target string) (*Client, error) {
|
||||
conn, err := dialer.Dial(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return New(Conn(conn)), nil
|
||||
}
|
||||
|
||||
// Send sends the data to the given destination.
|
||||
func (c *Client) Send(dest string, data []byte, opts ...MessageOption) error {
|
||||
m := NewMessage()
|
||||
m.Method = MethodSend
|
||||
m.Dest = []byte(dest)
|
||||
m.Body = data
|
||||
m.Apply(opts...)
|
||||
return c.sendMessage(m)
|
||||
}
|
||||
|
||||
// SendJSON sends the JSON encoding of v to the given destination.
|
||||
func (c *Client) SendJSON(dest string, v interface{}, opts ...MessageOption) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts = append(opts,
|
||||
WithHeader("content-type", "application/json"),
|
||||
)
|
||||
return c.Send(dest, data, opts...)
|
||||
}
|
||||
|
||||
// Subscribe subscribes to the given destination.
|
||||
func (c *Client) Subscribe(dest string, handler Handler, opts ...MessageOption) (id []byte, err error) {
|
||||
id = c.incr()
|
||||
|
||||
m := NewMessage()
|
||||
m.Method = MethodSubscribe
|
||||
m.ID = id
|
||||
m.Dest = []byte(dest)
|
||||
m.Apply(opts...)
|
||||
|
||||
c.mu.Lock()
|
||||
c.subs[string(id)] = handler
|
||||
c.mu.Unlock()
|
||||
|
||||
err = c.sendMessage(m)
|
||||
if err != nil {
|
||||
c.mu.Lock()
|
||||
delete(c.subs, string(id))
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes to the destination.
|
||||
func (c *Client) Unsubscribe(id []byte, opts ...MessageOption) error {
|
||||
c.mu.Lock()
|
||||
delete(c.subs, string(id))
|
||||
c.mu.Unlock()
|
||||
|
||||
m := NewMessage()
|
||||
m.Method = MethodUnsubscribe
|
||||
m.ID = id
|
||||
m.Apply(opts...)
|
||||
|
||||
return c.sendMessage(m)
|
||||
}
|
||||
|
||||
// Ack acknowledges the messages with the given id.
|
||||
func (c *Client) Ack(id []byte, opts ...MessageOption) error {
|
||||
m := NewMessage()
|
||||
m.Method = MethodAck
|
||||
m.ID = id
|
||||
m.Apply(opts...)
|
||||
|
||||
return c.sendMessage(m)
|
||||
}
|
||||
|
||||
// Nack negative-acknowledges the messages with the given id.
|
||||
func (c *Client) Nack(id []byte, opts ...MessageOption) error {
|
||||
m := NewMessage()
|
||||
m.Method = MethodNack
|
||||
m.ID = id
|
||||
m.Apply(opts...)
|
||||
|
||||
return c.peer.Send(m)
|
||||
}
|
||||
|
||||
// Connect opens the connection and establishes the session.
|
||||
func (c *Client) Connect(opts ...MessageOption) error {
|
||||
m := NewMessage()
|
||||
m.Proto = STOMP
|
||||
m.Method = MethodStomp
|
||||
m.Apply(opts...)
|
||||
if err := c.sendMessage(m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, ok := <-c.peer.Receive()
|
||||
if !ok {
|
||||
return io.EOF
|
||||
}
|
||||
defer m.Release()
|
||||
|
||||
if !bytes.Equal(m.Method, MethodConnected) {
|
||||
return fmt.Errorf("stomp: inbound message: unexpected method, want connected")
|
||||
}
|
||||
go c.listen()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect terminates the session and closes the connection.
|
||||
func (c *Client) Disconnect() error {
|
||||
m := NewMessage()
|
||||
m.Method = MethodDisconnect
|
||||
c.sendMessage(m)
|
||||
return c.peer.Close()
|
||||
}
|
||||
|
||||
// Done returns a channel
|
||||
func (c *Client) Done() <-chan error {
|
||||
return c.done
|
||||
}
|
||||
|
||||
func (c *Client) incr() []byte {
|
||||
c.mu.Lock()
|
||||
i := c.seq
|
||||
c.seq++
|
||||
c.mu.Unlock()
|
||||
return strconv.AppendInt(nil, i, 10)
|
||||
}
|
||||
|
||||
func (c *Client) listen() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Warningf("stomp client: recover panic: %s", r)
|
||||
err, ok := r.(error)
|
||||
if !ok {
|
||||
logger.Warningf("%v: %s", r, debug.Stack())
|
||||
c.done <- fmt.Errorf("%v", r)
|
||||
} else {
|
||||
logger.Warningf("%s", err)
|
||||
c.done <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
m, ok := <-c.peer.Receive()
|
||||
if !ok {
|
||||
c.done <- io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case bytes.Equal(m.Method, MethodMessage):
|
||||
c.handleMessage(m)
|
||||
case bytes.Equal(m.Method, MethodRecipet):
|
||||
c.handleReceipt(m)
|
||||
default:
|
||||
logger.Noticef("stomp client: unknown message type: %s",
|
||||
string(m.Method),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleReceipt(m *Message) {
|
||||
c.mu.Lock()
|
||||
receiptc, ok := c.wait[string(m.Receipt)]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
logger.Noticef("stomp client: unknown read receipt: %s",
|
||||
string(m.Receipt),
|
||||
)
|
||||
return
|
||||
}
|
||||
receiptc <- struct{}{}
|
||||
}
|
||||
|
||||
func (c *Client) handleMessage(m *Message) {
|
||||
c.mu.Lock()
|
||||
handler, ok := c.subs[string(m.Subs)]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
logger.Noticef("stomp client: subscription not found: %s",
|
||||
string(m.Subs),
|
||||
)
|
||||
return
|
||||
}
|
||||
handler.Handle(m)
|
||||
}
|
||||
|
||||
func (c *Client) sendMessage(m *Message) error {
|
||||
if len(m.Receipt) == 0 {
|
||||
return c.peer.Send(m)
|
||||
}
|
||||
|
||||
receiptc := make(chan struct{}, 1)
|
||||
c.wait[string(m.Receipt)] = receiptc
|
||||
|
||||
defer func() {
|
||||
delete(c.wait, string(m.Receipt))
|
||||
}()
|
||||
|
||||
err := c.peer.Send(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiptc:
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/drone/mq/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 32 << 10 // default buffer size 32KB
|
||||
bufferLimit = 32 << 15 // default buffer limit 1MB
|
||||
)
|
||||
|
||||
var (
|
||||
never time.Time
|
||||
deadline = time.Second * 5
|
||||
|
||||
heartbeatTime = time.Second * 30
|
||||
heartbeatWait = time.Second * 60
|
||||
)
|
||||
|
||||
type connPeer struct {
|
||||
conn net.Conn
|
||||
done chan bool
|
||||
|
||||
reader *bufio.Reader
|
||||
writer *bufio.Writer
|
||||
incoming chan *Message
|
||||
outgoing chan *Message
|
||||
}
|
||||
|
||||
// Conn creates a network-connected peer that reads and writes
|
||||
// messages using net.Conn c.
|
||||
func Conn(c net.Conn) Peer {
|
||||
p := &connPeer{
|
||||
reader: bufio.NewReaderSize(c, bufferSize),
|
||||
writer: bufio.NewWriterSize(c, bufferSize),
|
||||
incoming: make(chan *Message),
|
||||
outgoing: make(chan *Message),
|
||||
done: make(chan bool),
|
||||
conn: c,
|
||||
}
|
||||
|
||||
go p.readInto(p.incoming)
|
||||
go p.writeFrom(p.outgoing)
|
||||
return p
|
||||
}
|
||||
|
||||
func (c *connPeer) Receive() <-chan *Message {
|
||||
return c.incoming
|
||||
}
|
||||
|
||||
func (c *connPeer) Send(message *Message) error {
|
||||
select {
|
||||
case <-c.done:
|
||||
return io.EOF
|
||||
default:
|
||||
c.outgoing <- message
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPeer) Addr() string {
|
||||
return c.conn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
func (c *connPeer) Close() error {
|
||||
return c.close()
|
||||
}
|
||||
|
||||
func (c *connPeer) close() error {
|
||||
select {
|
||||
case <-c.done:
|
||||
return io.EOF
|
||||
default:
|
||||
close(c.done)
|
||||
close(c.incoming)
|
||||
close(c.outgoing)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPeer) readInto(messages chan<- *Message) {
|
||||
defer c.close()
|
||||
|
||||
for {
|
||||
// lim := io.LimitReader(c.conn, bufferLimit)
|
||||
// buf := bufio.NewReaderSize(lim, bufferSize)
|
||||
|
||||
buf, err := c.reader.ReadBytes(0)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if len(buf) == 1 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(heartbeatWait))
|
||||
logger.Verbosef("stomp: received heart-beat")
|
||||
continue
|
||||
}
|
||||
|
||||
msg := NewMessage()
|
||||
msg.Parse(buf[:len(buf)-1])
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
break
|
||||
default:
|
||||
messages <- msg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPeer) writeFrom(messages <-chan *Message) {
|
||||
tick := time.NewTicker(time.Millisecond * 100).C
|
||||
heartbeat := time.NewTicker(heartbeatTime).C
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
break loop
|
||||
case <-heartbeat:
|
||||
logger.Verbosef("stomp: send heart-beat.")
|
||||
c.writer.WriteByte(0)
|
||||
case <-tick:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(deadline))
|
||||
if err := c.writer.Flush(); err != nil {
|
||||
break loop
|
||||
}
|
||||
c.conn.SetWriteDeadline(never)
|
||||
case msg, ok := <-messages:
|
||||
if !ok {
|
||||
break loop
|
||||
}
|
||||
writeTo(c.writer, msg)
|
||||
c.writer.WriteByte(0)
|
||||
msg.Release()
|
||||
}
|
||||
}
|
||||
|
||||
c.drain()
|
||||
}
|
||||
|
||||
func (c *connPeer) drain() error {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(deadline))
|
||||
for msg := range c.outgoing {
|
||||
writeTo(c.writer, msg)
|
||||
c.writer.WriteByte(0)
|
||||
msg.Release()
|
||||
}
|
||||
c.conn.SetWriteDeadline(never)
|
||||
c.writer.Flush()
|
||||
return c.conn.Close()
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package stomp
|
||||
|
||||
// STOMP protocol version.
|
||||
var STOMP = []byte("1.2")
|
||||
|
||||
// STOMP protocol methods.
|
||||
var (
|
||||
MethodStomp = []byte("STOMP")
|
||||
MethodConnect = []byte("CONNECT")
|
||||
MethodConnected = []byte("CONNECTED")
|
||||
MethodSend = []byte("SEND")
|
||||
MethodSubscribe = []byte("SUBSCRIBE")
|
||||
MethodUnsubscribe = []byte("UNSUBSCRIBE")
|
||||
MethodAck = []byte("ACK")
|
||||
MethodNack = []byte("NACK")
|
||||
MethodDisconnect = []byte("DISCONNECT")
|
||||
MethodMessage = []byte("MESSAGE")
|
||||
MethodRecipet = []byte("RECEIPT")
|
||||
MethodError = []byte("ERROR")
|
||||
)
|
||||
|
||||
// STOMP protocol headers.
|
||||
var (
|
||||
HeaderAccept = []byte("accept-version")
|
||||
HeaderAck = []byte("ack")
|
||||
HeaderExpires = []byte("expires")
|
||||
HeaderDest = []byte("destination")
|
||||
HeaderHost = []byte("host")
|
||||
HeaderLogin = []byte("login")
|
||||
HeaderPass = []byte("passcode")
|
||||
HeaderID = []byte("id")
|
||||
HeaderMessageID = []byte("message-id")
|
||||
HeaderPersist = []byte("persist")
|
||||
HeaderPrefetch = []byte("prefetch-count")
|
||||
HeaderReceipt = []byte("receipt")
|
||||
HeaderReceiptID = []byte("receipt-id")
|
||||
HeaderRetain = []byte("retain")
|
||||
HeaderSelector = []byte("selector")
|
||||
HeaderServer = []byte("server")
|
||||
HeaderSession = []byte("session")
|
||||
HeaderSubscription = []byte("subscription")
|
||||
HeaderVersion = []byte("version")
|
||||
)
|
||||
|
||||
// Common STOMP header values.
|
||||
var (
|
||||
AckAuto = []byte("auto")
|
||||
AckClient = []byte("client")
|
||||
PersistTrue = []byte("true")
|
||||
RetainTrue = []byte("true")
|
||||
RetainLast = []byte("last")
|
||||
RetainAll = []byte("all")
|
||||
RetainRemove = []byte("remove")
|
||||
)
|
||||
|
||||
var headerLookup = map[string]struct{}{
|
||||
"accept-version": struct{}{},
|
||||
"ack": struct{}{},
|
||||
"expires": struct{}{},
|
||||
"destination": struct{}{},
|
||||
"host": struct{}{},
|
||||
"login": struct{}{},
|
||||
"passcode": struct{}{},
|
||||
"id": struct{}{},
|
||||
"message-id": struct{}{},
|
||||
"persist": struct{}{},
|
||||
"prefetch-count": struct{}{},
|
||||
"receipt": struct{}{},
|
||||
"receipt-id": struct{}{},
|
||||
"retain": struct{}{},
|
||||
"selector": struct{}{},
|
||||
"server": struct{}{},
|
||||
"session": struct{}{},
|
||||
"subscription": struct{}{},
|
||||
"version": struct{}{},
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package stomp
|
||||
|
||||
import "golang.org/x/net/context"
|
||||
|
||||
const clientKey = "stomp.client"
|
||||
|
||||
// NewContext adds the client to the context.
|
||||
func (c *Client) NewContext(ctx context.Context, client *Client) context.Context {
|
||||
// HACK for use with gin and echo
|
||||
if s, ok := ctx.(setter); ok {
|
||||
s.Set(clientKey, clientKey)
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, clientKey, client)
|
||||
}
|
||||
|
||||
// FromContext retrieves the client from context
|
||||
func FromContext(ctx context.Context) (*Client, bool) {
|
||||
client, ok := ctx.Value(clientKey).(*Client)
|
||||
return client, ok
|
||||
}
|
||||
|
||||
// MustFromContext retrieves the client from context. Panics if not found
|
||||
func MustFromContext(ctx context.Context) *Client {
|
||||
client, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
panic("stomp.Client not found in context")
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// HACK setter defines a context that enables setting values. This is a
|
||||
// temporary workaround for use with gin and echo and will eventually
|
||||
// be removed. DO NOT depend on this.
|
||||
type setter interface {
|
||||
Set(string, interface{})
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package dialer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
protoHTTP = "http"
|
||||
protoHTTPS = "https"
|
||||
protoWS = "ws"
|
||||
protoWSS = "wss"
|
||||
protoTCP = "tcp"
|
||||
)
|
||||
|
||||
// Dial creates a client connection to the given target.
|
||||
func Dial(target string) (net.Conn, error) {
|
||||
u, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case protoHTTP, protoHTTPS, protoWS, protoWSS:
|
||||
return dialWebsocket(u)
|
||||
case protoTCP:
|
||||
return dialSocket(u)
|
||||
default:
|
||||
panic("stomp: invalid protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func dialWebsocket(target *url.URL) (net.Conn, error) {
|
||||
origin, err := target.Parse("/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch origin.Scheme {
|
||||
case protoWS:
|
||||
origin.Scheme = protoHTTP
|
||||
case protoWSS:
|
||||
origin.Scheme = protoHTTPS
|
||||
}
|
||||
return websocket.Dial(target.String(), "", origin.String())
|
||||
}
|
||||
|
||||
func dialSocket(target *url.URL) (net.Conn, error) {
|
||||
return net.Dial(protoTCP, target.Host)
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package stomp
|
||||
|
||||
// Handler handles a STOMP message.
|
||||
type Handler interface {
|
||||
Handle(*Message)
|
||||
}
|
||||
|
||||
// The HandlerFunc type is an adapter to allow the use of an ordinary
|
||||
// function as a STOMP message handler.
|
||||
type HandlerFunc func(*Message)
|
||||
|
||||
// Handle calls f(m).
|
||||
func (f HandlerFunc) Handle(m *Message) { f(m) }
|
|
@ -0,0 +1,109 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const defaultHeaderLen = 5
|
||||
|
||||
type item struct {
|
||||
name []byte
|
||||
data []byte
|
||||
}
|
||||
|
||||
// Header represents the header section of the STOMP message.
|
||||
type Header struct {
|
||||
items []item
|
||||
itemc int
|
||||
}
|
||||
|
||||
func newHeader() *Header {
|
||||
return &Header{
|
||||
items: make([]item, defaultHeaderLen),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the named header value.
|
||||
func (h *Header) Get(name []byte) (b []byte) {
|
||||
for i := 0; i < h.itemc; i++ {
|
||||
if v := h.items[i]; bytes.Equal(v.name, name) {
|
||||
return v.data
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetString returns the named header value.
|
||||
func (h *Header) GetString(name string) string {
|
||||
k := []byte(name)
|
||||
v := h.Get(k)
|
||||
return string(v)
|
||||
}
|
||||
|
||||
// GetBool returns the named header value.
|
||||
func (h *Header) GetBool(name string) bool {
|
||||
s := h.GetString(name)
|
||||
b, _ := strconv.ParseBool(s)
|
||||
return b
|
||||
}
|
||||
|
||||
// GetInt returns the named header value.
|
||||
func (h *Header) GetInt(name string) int {
|
||||
s := h.GetString(name)
|
||||
i, _ := strconv.Atoi(s)
|
||||
return i
|
||||
}
|
||||
|
||||
// GetInt64 returns the named header value.
|
||||
func (h *Header) GetInt64(name string) int64 {
|
||||
s := h.GetString(name)
|
||||
i, _ := strconv.ParseInt(s, 10, 64)
|
||||
return i
|
||||
}
|
||||
|
||||
// Field returns the named header value in string format. This is used to
|
||||
// provide compatibility with the SQL expression evaluation package.
|
||||
func (h *Header) Field(name []byte) []byte {
|
||||
return h.Get(name)
|
||||
}
|
||||
|
||||
// Add appens the key value pair to the header.
|
||||
func (h *Header) Add(name, data []byte) {
|
||||
h.grow()
|
||||
h.items[h.itemc].name = name
|
||||
h.items[h.itemc].data = data
|
||||
h.itemc++
|
||||
}
|
||||
|
||||
// Index returns the keypair at index i.
|
||||
func (h *Header) Index(i int) (k, v []byte) {
|
||||
if i > h.itemc {
|
||||
return
|
||||
}
|
||||
k = h.items[i].name
|
||||
v = h.items[i].data
|
||||
return
|
||||
}
|
||||
|
||||
// Len returns the header length.
|
||||
func (h *Header) Len() int {
|
||||
return h.itemc
|
||||
}
|
||||
|
||||
func (h *Header) grow() {
|
||||
if h.itemc > defaultHeaderLen-1 {
|
||||
h.items = append(h.items, item{})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Header) reset() {
|
||||
h.itemc = 0
|
||||
h.items = h.items[:defaultHeaderLen]
|
||||
for i := range h.items {
|
||||
h.items[i].name = zeroBytes
|
||||
h.items[i].data = zeroBytes
|
||||
}
|
||||
}
|
||||
|
||||
var zeroBytes []byte
|
|
@ -0,0 +1,146 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Message represents a parsed STOMP message.
|
||||
type Message struct {
|
||||
ID []byte // id header
|
||||
Proto []byte // stomp version
|
||||
Method []byte // stomp method
|
||||
User []byte // username header
|
||||
Pass []byte // password header
|
||||
Dest []byte // destination header
|
||||
Subs []byte // subscription id
|
||||
Ack []byte // ack id
|
||||
Msg []byte // message-id header
|
||||
Persist []byte // persist header
|
||||
Retain []byte // retain header
|
||||
Prefetch []byte // prefetch count
|
||||
Expires []byte // expires header
|
||||
Receipt []byte // receipt header
|
||||
Selector []byte // selector header
|
||||
Body []byte
|
||||
Header *Header // custom headers
|
||||
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Copy returns a copy of the Message.
|
||||
func (m *Message) Copy() *Message {
|
||||
c := NewMessage()
|
||||
c.ID = m.ID
|
||||
c.Proto = m.Proto
|
||||
c.Method = m.Method
|
||||
c.User = m.User
|
||||
c.Pass = m.Pass
|
||||
c.Dest = m.Dest
|
||||
c.Subs = m.Subs
|
||||
c.Ack = m.Ack
|
||||
c.Prefetch = m.Prefetch
|
||||
c.Selector = m.Selector
|
||||
c.Persist = m.Persist
|
||||
c.Retain = m.Retain
|
||||
c.Receipt = m.Receipt
|
||||
c.Expires = m.Expires
|
||||
c.Body = m.Body
|
||||
c.ctx = m.ctx
|
||||
c.Header.itemc = m.Header.itemc
|
||||
copy(c.Header.items, m.Header.items)
|
||||
return c
|
||||
}
|
||||
|
||||
// Apply applies the options to the message.
|
||||
func (m *Message) Apply(opts ...MessageOption) {
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse parses the raw bytes into the message.
|
||||
func (m *Message) Parse(b []byte) error {
|
||||
return read(b, m)
|
||||
}
|
||||
|
||||
// Bytes returns the Message in raw byte format.
|
||||
func (m *Message) Bytes() []byte {
|
||||
var buf bytes.Buffer
|
||||
writeTo(&buf, m)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// String returns the Message in string format.
|
||||
func (m *Message) String() string {
|
||||
return string(m.Bytes())
|
||||
}
|
||||
|
||||
// Release releases the message back to the message pool.
|
||||
func (m *Message) Release() {
|
||||
m.Reset()
|
||||
pool.Put(m)
|
||||
}
|
||||
|
||||
// Reset resets the meesage fields to their zero values.
|
||||
func (m *Message) Reset() {
|
||||
m.ID = m.ID[:0]
|
||||
m.Proto = m.Proto[:0]
|
||||
m.Method = m.Method[:0]
|
||||
m.User = m.User[:0]
|
||||
m.Pass = m.Pass[:0]
|
||||
m.Dest = m.Dest[:0]
|
||||
m.Subs = m.Subs[:0]
|
||||
m.Ack = m.Ack[:0]
|
||||
m.Prefetch = m.Prefetch[:0]
|
||||
m.Selector = m.Selector[:0]
|
||||
m.Persist = m.Persist[:0]
|
||||
m.Retain = m.Retain[:0]
|
||||
m.Receipt = m.Receipt[:0]
|
||||
m.Expires = m.Expires[:0]
|
||||
m.Body = m.Body[:0]
|
||||
m.ctx = nil
|
||||
m.Header.reset()
|
||||
}
|
||||
|
||||
// Context returns the request's context.
|
||||
func (m *Message) Context() context.Context {
|
||||
if m.ctx != nil {
|
||||
return m.ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// WithContext returns a shallow copy of m with its context changed
|
||||
// to ctx. The provided ctx must be non-nil.
|
||||
func (m *Message) WithContext(ctx context.Context) *Message {
|
||||
c := m.Copy()
|
||||
c.ctx = ctx
|
||||
return c
|
||||
}
|
||||
|
||||
// Unmarshal parses the JSON-encoded body of the message and
|
||||
// stores the result in the value pointed to by v.
|
||||
func (m *Message) Unmarshal(v interface{}) error {
|
||||
return json.Unmarshal(m.Body, v)
|
||||
}
|
||||
|
||||
// NewMessage returns an empty message from the message pool.
|
||||
func NewMessage() *Message {
|
||||
return pool.Get().(*Message)
|
||||
}
|
||||
|
||||
var pool = sync.Pool{New: func() interface{} {
|
||||
return &Message{Header: newHeader()}
|
||||
}}
|
||||
|
||||
// Rand returns a random int64 number as a []byte of
|
||||
// ascii characters.
|
||||
func Rand() []byte {
|
||||
return strconv.AppendInt(nil, rand.Int63(), 10)
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MessageOption configures message options.
|
||||
type MessageOption func(*Message)
|
||||
|
||||
// WithCredentials returns a MessageOption which sets credentials.
|
||||
func WithCredentials(username, password string) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.User = []byte(username)
|
||||
m.Pass = []byte(password)
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeader returns a MessageOption which sets a header.
|
||||
func WithHeader(key, value string) MessageOption {
|
||||
return func(m *Message) {
|
||||
_, ok := headerLookup[strings.ToLower(key)]
|
||||
if !ok {
|
||||
m.Header.Add(
|
||||
[]byte(key),
|
||||
[]byte(value),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeaders returns a MessageOption which sets headers.
|
||||
func WithHeaders(headers map[string]string) MessageOption {
|
||||
return func(m *Message) {
|
||||
for key, value := range headers {
|
||||
_, ok := headerLookup[strings.ToLower(key)]
|
||||
if !ok {
|
||||
m.Header.Add(
|
||||
[]byte(key),
|
||||
[]byte(value),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithExpires returns a MessageOption configured with an expiration.
|
||||
func WithExpires(exp int64) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Expires = strconv.AppendInt(nil, exp, 10)
|
||||
}
|
||||
}
|
||||
|
||||
// WithPrefetch returns a MessageOption configured with a prefetch count.
|
||||
func WithPrefetch(prefetch int) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Prefetch = strconv.AppendInt(nil, int64(prefetch), 10)
|
||||
}
|
||||
}
|
||||
|
||||
// WithReceipt returns a MessageOption configured with a receipt request.
|
||||
func WithReceipt() MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Receipt = strconv.AppendInt(nil, rand.Int63(), 10)
|
||||
}
|
||||
}
|
||||
|
||||
// WithPersistence returns a MessageOption configured to persist.
|
||||
func WithPersistence() MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Persist = PersistTrue
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetain returns a MessageOption configured to retain the message.
|
||||
func WithRetain(retain string) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Retain = []byte(retain)
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelector returns a MessageOption configured to filter messages
|
||||
// using a sql-like evaluation string.
|
||||
func WithSelector(selector string) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Selector = []byte(selector)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAck returns a MessageOption configured with an ack policy.
|
||||
func WithAck(ack string) MessageOption {
|
||||
return func(m *Message) {
|
||||
m.Ack = []byte(ack)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
package stomp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Peer defines a peer-to-peer connection.
|
||||
type Peer interface {
|
||||
// Send sends a message.
|
||||
Send(*Message) error
|
||||
|
||||
// Receive returns a channel of inbound messages.
|
||||
Receive() <-chan *Message
|
||||
|
||||
// Close closes the connection.
|
||||
Close() error
|
||||
|
||||
// Addr returns the peer address.
|
||||
Addr() string
|
||||
}
|
||||
|
||||
// Pipe creates a synchronous in-memory pipe, where reads on one end are
|
||||
// matched with writes on the other. This is useful for direct, in-memory
|
||||
// client-server communication.
|
||||
func Pipe() (Peer, Peer) {
|
||||
atob := make(chan *Message, 10)
|
||||
btoa := make(chan *Message, 10)
|
||||
|
||||
a := &localPeer{
|
||||
incoming: btoa,
|
||||
outgoing: atob,
|
||||
finished: make(chan bool),
|
||||
}
|
||||
b := &localPeer{
|
||||
incoming: atob,
|
||||
outgoing: btoa,
|
||||
finished: make(chan bool),
|
||||
}
|
||||
|
||||
return a, b
|
||||
}
|
||||
|
||||
type localPeer struct {
|
||||
finished chan bool
|
||||
outgoing chan<- *Message
|
||||
incoming <-chan *Message
|
||||
}
|
||||
|
||||
func (p *localPeer) Receive() <-chan *Message {
|
||||
return p.incoming
|
||||
}
|
||||
|
||||
func (p *localPeer) Send(m *Message) error {
|
||||
select {
|
||||
case <-p.finished:
|
||||
return io.EOF
|
||||
default:
|
||||
p.outgoing <- m
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *localPeer) Close() error {
|
||||
close(p.finished)
|
||||
close(p.outgoing)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *localPeer) Addr() string {
|
||||
peerAddrOnce.Do(func() {
|
||||
// get the local address list
|
||||
addr, _ := net.InterfaceAddrs()
|
||||
if len(addr) != 0 {
|
||||
// use the last address in the list
|
||||