354 lines
7.6 KiB
Go
354 lines
7.6 KiB
Go
package smux
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"io"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
const (
	// defaultAcceptBacklog is the buffer capacity newSession gives to
	// Session.chAccepts, i.e. how many peer-opened streams can be queued
	// before recvLoop blocks while handing one over.
	defaultAcceptBacklog = 1024
)
|
|
|
|
// Message texts for the errors this file creates with errors.New.
const (
	// errBrokenPipe is reported when an operation hits a closed session.
	errBrokenPipe = "broken pipe"
	// errInvalidProtocol is reported by readFrame when the header's version
	// field does not match the expected version.
	errInvalidProtocol = "invalid protocol version"
	// errGoAway is reported by OpenStream once the stream id counter has
	// wrapped around.
	errGoAway = "stream id overflows, should start a new connection"
)
|
|
|
|
// writeRequest is one unit of work handed from writeFrame to sendLoop
// over Session.writes.
type writeRequest struct {
	frame  Frame            // frame to serialize and write
	result chan writeResult // receives exactly one writeResult, then is closed by sendLoop
}
|
|
|
|
// writeResult is sendLoop's answer to a writeRequest.
type writeResult struct {
	n   int   // bytes written excluding the header, clamped at zero
	err error // error returned by the underlying conn.Write, if any
}
|
|
|
|
// Session defines a multiplexed connection for streams.
//
// newSession starts three goroutines per session: recvLoop (reads frames and
// dispatches them to streams), sendLoop (serializes frames onto conn) and
// keepalive (periodic NOP frames plus an inactivity check).
type Session struct {
	conn io.ReadWriteCloser // underlying transport; closed by Close

	config           *Config
	nextStreamID     uint32     // next stream identifier
	nextStreamIDLock sync.Mutex // held by OpenStream while it touches nextStreamID and goAway

	bucket       int32         // token bucket; recvLoop stops reading while it is <= 0
	bucketNotify chan struct{} // used for waiting for tokens (capacity 1, see notifyBucket)

	streams    map[uint32]*Stream // all streams in this session
	streamLock sync.Mutex         // locks streams

	die       chan struct{} // flag session has died; closed exactly once by Close
	dieLock   sync.Mutex    // serializes the check-then-close of die in Close
	chAccepts chan *Stream  // streams opened by the peer, waiting for AcceptStream

	dataReady int32 // flag data has arrived; set by recvLoop, cleared by keepalive

	goAway int32 // flag id exhausted; once set, OpenStream always fails

	deadline atomic.Value // holds the time.Time stored by SetDeadline

	writes chan writeRequest // unbuffered queue from writeFrame to sendLoop
}
|
|
|
|
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
|
|
s := new(Session)
|
|
s.die = make(chan struct{})
|
|
s.conn = conn
|
|
s.config = config
|
|
s.streams = make(map[uint32]*Stream)
|
|
s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
|
|
s.bucket = int32(config.MaxReceiveBuffer)
|
|
s.bucketNotify = make(chan struct{}, 1)
|
|
s.writes = make(chan writeRequest)
|
|
|
|
if client {
|
|
s.nextStreamID = 1
|
|
} else {
|
|
s.nextStreamID = 0
|
|
}
|
|
go s.recvLoop()
|
|
go s.sendLoop()
|
|
go s.keepalive()
|
|
return s
|
|
}
|
|
|
|
// OpenStream is used to create a new stream
|
|
func (s *Session) OpenStream() (*Stream, error) {
|
|
if s.IsClosed() {
|
|
return nil, errors.New(errBrokenPipe)
|
|
}
|
|
|
|
// generate stream id
|
|
s.nextStreamIDLock.Lock()
|
|
if s.goAway > 0 {
|
|
s.nextStreamIDLock.Unlock()
|
|
return nil, errors.New(errGoAway)
|
|
}
|
|
|
|
s.nextStreamID += 2
|
|
sid := s.nextStreamID
|
|
if sid == sid%2 { // stream-id overflows
|
|
s.goAway = 1
|
|
s.nextStreamIDLock.Unlock()
|
|
return nil, errors.New(errGoAway)
|
|
}
|
|
s.nextStreamIDLock.Unlock()
|
|
|
|
stream := newStream(sid, s.config.MaxFrameSize, s)
|
|
|
|
if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
|
|
return nil, errors.Wrap(err, "writeFrame")
|
|
}
|
|
|
|
s.streamLock.Lock()
|
|
s.streams[sid] = stream
|
|
s.streamLock.Unlock()
|
|
return stream, nil
|
|
}
|
|
|
|
// AcceptStream is used to block until the next available stream
|
|
// is ready to be accepted.
|
|
func (s *Session) AcceptStream() (*Stream, error) {
|
|
var deadline <-chan time.Time
|
|
if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
|
|
timer := time.NewTimer(d.Sub(time.Now()))
|
|
defer timer.Stop()
|
|
deadline = timer.C
|
|
}
|
|
select {
|
|
case stream := <-s.chAccepts:
|
|
return stream, nil
|
|
case <-deadline:
|
|
return nil, errTimeout
|
|
case <-s.die:
|
|
return nil, errors.New(errBrokenPipe)
|
|
}
|
|
}
|
|
|
|
// Close is used to close the session and all streams.
|
|
func (s *Session) Close() (err error) {
|
|
s.dieLock.Lock()
|
|
|
|
select {
|
|
case <-s.die:
|
|
s.dieLock.Unlock()
|
|
return errors.New(errBrokenPipe)
|
|
default:
|
|
close(s.die)
|
|
s.dieLock.Unlock()
|
|
s.streamLock.Lock()
|
|
for k := range s.streams {
|
|
s.streams[k].sessionClose()
|
|
}
|
|
s.streamLock.Unlock()
|
|
s.notifyBucket()
|
|
return s.conn.Close()
|
|
}
|
|
}
|
|
|
|
// notifyBucket notifies recvLoop that bucket is available
|
|
func (s *Session) notifyBucket() {
|
|
select {
|
|
case s.bucketNotify <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
// IsClosed does a safe check to see if we have shutdown
|
|
func (s *Session) IsClosed() bool {
|
|
select {
|
|
case <-s.die:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// NumStreams returns the number of currently open streams
|
|
func (s *Session) NumStreams() int {
|
|
if s.IsClosed() {
|
|
return 0
|
|
}
|
|
s.streamLock.Lock()
|
|
defer s.streamLock.Unlock()
|
|
return len(s.streams)
|
|
}
|
|
|
|
// SetDeadline sets a deadline used by Accept* calls.
// A zero time value disables the deadline.
//
// The value is stored atomically and is read by AcceptStream when it is
// called; the returned error is always nil.
func (s *Session) SetDeadline(t time.Time) error {
	s.deadline.Store(t)
	return nil
}
|
|
|
|
// notify the session that a stream has closed
|
|
func (s *Session) streamClosed(sid uint32) {
|
|
s.streamLock.Lock()
|
|
if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
|
|
if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
|
|
s.notifyBucket()
|
|
}
|
|
}
|
|
delete(s.streams, sid)
|
|
s.streamLock.Unlock()
|
|
}
|
|
|
|
// returnTokens is called by stream to return token after read
|
|
func (s *Session) returnTokens(n int) {
|
|
if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
|
|
s.notifyBucket()
|
|
}
|
|
}
|
|
|
|
// session read a frame from underlying connection
|
|
// it's data is pointed to the input buffer
|
|
func (s *Session) readFrame(buffer []byte) (f Frame, err error) {
|
|
if _, err := io.ReadFull(s.conn, buffer[:headerSize]); err != nil {
|
|
return f, errors.Wrap(err, "readFrame")
|
|
}
|
|
|
|
dec := rawHeader(buffer)
|
|
if dec.Version() != version {
|
|
return f, errors.New(errInvalidProtocol)
|
|
}
|
|
|
|
f.ver = dec.Version()
|
|
f.cmd = dec.Cmd()
|
|
f.sid = dec.StreamID()
|
|
if length := dec.Length(); length > 0 {
|
|
if _, err := io.ReadFull(s.conn, buffer[headerSize:headerSize+length]); err != nil {
|
|
return f, errors.Wrap(err, "readFrame")
|
|
}
|
|
f.data = buffer[headerSize : headerSize+length]
|
|
}
|
|
return f, nil
|
|
}
|
|
|
|
// recvLoop keeps on reading from underlying connection if tokens are available
|
|
func (s *Session) recvLoop() {
|
|
buffer := make([]byte, (1<<16)+headerSize)
|
|
for {
|
|
for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
|
|
<-s.bucketNotify
|
|
}
|
|
|
|
if f, err := s.readFrame(buffer); err == nil {
|
|
atomic.StoreInt32(&s.dataReady, 1)
|
|
|
|
switch f.cmd {
|
|
case cmdNOP:
|
|
case cmdSYN:
|
|
s.streamLock.Lock()
|
|
if _, ok := s.streams[f.sid]; !ok {
|
|
stream := newStream(f.sid, s.config.MaxFrameSize, s)
|
|
s.streams[f.sid] = stream
|
|
select {
|
|
case s.chAccepts <- stream:
|
|
case <-s.die:
|
|
}
|
|
}
|
|
s.streamLock.Unlock()
|
|
case cmdFIN:
|
|
s.streamLock.Lock()
|
|
if stream, ok := s.streams[f.sid]; ok {
|
|
stream.markRST()
|
|
stream.notifyReadEvent()
|
|
}
|
|
s.streamLock.Unlock()
|
|
case cmdPSH:
|
|
s.streamLock.Lock()
|
|
if stream, ok := s.streams[f.sid]; ok {
|
|
atomic.AddInt32(&s.bucket, -int32(len(f.data)))
|
|
stream.pushBytes(f.data)
|
|
stream.notifyReadEvent()
|
|
}
|
|
s.streamLock.Unlock()
|
|
default:
|
|
s.Close()
|
|
return
|
|
}
|
|
} else {
|
|
s.Close()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Session) keepalive() {
|
|
tickerPing := time.NewTicker(s.config.KeepAliveInterval)
|
|
tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
|
|
defer tickerPing.Stop()
|
|
defer tickerTimeout.Stop()
|
|
for {
|
|
select {
|
|
case <-tickerPing.C:
|
|
s.writeFrame(newFrame(cmdNOP, 0))
|
|
s.notifyBucket() // force a signal to the recvLoop
|
|
case <-tickerTimeout.C:
|
|
if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
|
|
s.Close()
|
|
return
|
|
}
|
|
case <-s.die:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Session) sendLoop() {
|
|
buf := make([]byte, (1<<16)+headerSize)
|
|
for {
|
|
select {
|
|
case <-s.die:
|
|
return
|
|
case request, ok := <-s.writes:
|
|
if !ok {
|
|
continue
|
|
}
|
|
buf[0] = request.frame.ver
|
|
buf[1] = request.frame.cmd
|
|
binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
|
|
binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)
|
|
copy(buf[headerSize:], request.frame.data)
|
|
n, err := s.conn.Write(buf[:headerSize+len(request.frame.data)])
|
|
|
|
n -= headerSize
|
|
if n < 0 {
|
|
n = 0
|
|
}
|
|
|
|
result := writeResult{
|
|
n: n,
|
|
err: err,
|
|
}
|
|
|
|
request.result <- result
|
|
close(request.result)
|
|
}
|
|
}
|
|
}
|
|
|
|
// writeFrame writes the frame to the underlying connection
|
|
// and returns the number of bytes written if successful
|
|
func (s *Session) writeFrame(f Frame) (n int, err error) {
|
|
req := writeRequest{
|
|
frame: f,
|
|
result: make(chan writeResult, 1),
|
|
}
|
|
select {
|
|
case <-s.die:
|
|
return 0, errors.New(errBrokenPipe)
|
|
case s.writes <- req:
|
|
}
|
|
|
|
result := <-req.result
|
|
return result.n, result.err
|
|
}
|