package smux import ( "encoding/binary" "io" "sync" "sync/atomic" "time" "github.com/pkg/errors" ) const ( defaultAcceptBacklog = 1024 ) const ( errBrokenPipe = "broken pipe" errInvalidProtocol = "invalid protocol version" errGoAway = "stream id overflows, should start a new connection" ) type writeRequest struct { frame Frame result chan writeResult } type writeResult struct { n int err error } // Session defines a multiplexed connection for streams type Session struct { conn io.ReadWriteCloser config *Config nextStreamID uint32 // next stream identifier nextStreamIDLock sync.Mutex bucket int32 // token bucket bucketNotify chan struct{} // used for waiting for tokens streams map[uint32]*Stream // all streams in this session streamLock sync.Mutex // locks streams die chan struct{} // flag session has died dieLock sync.Mutex chAccepts chan *Stream dataReady int32 // flag data has arrived goAway int32 // flag id exhausted deadline atomic.Value writes chan writeRequest } func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { s := new(Session) s.die = make(chan struct{}) s.conn = conn s.config = config s.streams = make(map[uint32]*Stream) s.chAccepts = make(chan *Stream, defaultAcceptBacklog) s.bucket = int32(config.MaxReceiveBuffer) s.bucketNotify = make(chan struct{}, 1) s.writes = make(chan writeRequest) if client { s.nextStreamID = 1 } else { s.nextStreamID = 0 } go s.recvLoop() go s.sendLoop() go s.keepalive() return s } // OpenStream is used to create a new stream func (s *Session) OpenStream() (*Stream, error) { if s.IsClosed() { return nil, errors.New(errBrokenPipe) } // generate stream id s.nextStreamIDLock.Lock() if s.goAway > 0 { s.nextStreamIDLock.Unlock() return nil, errors.New(errGoAway) } s.nextStreamID += 2 sid := s.nextStreamID if sid == sid%2 { // stream-id overflows s.goAway = 1 s.nextStreamIDLock.Unlock() return nil, errors.New(errGoAway) } s.nextStreamIDLock.Unlock() stream := newStream(sid, s.config.MaxFrameSize, s) if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil { return nil, errors.Wrap(err, "writeFrame") } s.streamLock.Lock() s.streams[sid] = stream s.streamLock.Unlock() return stream, nil } // AcceptStream is used to block until the next available stream // is ready to be accepted. func (s *Session) AcceptStream() (*Stream, error) { var deadline <-chan time.Time if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { timer := time.NewTimer(d.Sub(time.Now())) defer timer.Stop() deadline = timer.C } select { case stream := <-s.chAccepts: return stream, nil case <-deadline: return nil, errTimeout case <-s.die: return nil, errors.New(errBrokenPipe) } } // Close is used to close the session and all streams. func (s *Session) Close() (err error) { s.dieLock.Lock() select { case <-s.die: s.dieLock.Unlock() return errors.New(errBrokenPipe) default: close(s.die) s.dieLock.Unlock() s.streamLock.Lock() for k := range s.streams { s.streams[k].sessionClose() } s.streamLock.Unlock() s.notifyBucket() return s.conn.Close() } } // notifyBucket notifies recvLoop that bucket is available func (s *Session) notifyBucket() { select { case s.bucketNotify <- struct{}{}: default: } } // IsClosed does a safe check to see if we have shutdown func (s *Session) IsClosed() bool { select { case <-s.die: return true default: return false } } // NumStreams returns the number of currently open streams func (s *Session) NumStreams() int { if s.IsClosed() { return 0 } s.streamLock.Lock() defer s.streamLock.Unlock() return len(s.streams) } // SetDeadline sets a deadline used by Accept* calls. // A zero time value disables the deadline. func (s *Session) SetDeadline(t time.Time) error { s.deadline.Store(t) return nil } // notify the session that a stream has closed func (s *Session) streamClosed(sid uint32) { s.streamLock.Lock() if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket if atomic.AddInt32(&s.bucket, int32(n)) > 0 { s.notifyBucket() } } delete(s.streams, sid) s.streamLock.Unlock() } // returnTokens is called by stream to return token after read func (s *Session) returnTokens(n int) { if atomic.AddInt32(&s.bucket, int32(n)) > 0 { s.notifyBucket() } } // session read a frame from underlying connection // it's data is pointed to the input buffer func (s *Session) readFrame(buffer []byte) (f Frame, err error) { if _, err := io.ReadFull(s.conn, buffer[:headerSize]); err != nil { return f, errors.Wrap(err, "readFrame") } dec := rawHeader(buffer) if dec.Version() != version { return f, errors.New(errInvalidProtocol) } f.ver = dec.Version() f.cmd = dec.Cmd() f.sid = dec.StreamID() if length := dec.Length(); length > 0 { if _, err := io.ReadFull(s.conn, buffer[headerSize:headerSize+length]); err != nil { return f, errors.Wrap(err, "readFrame") } f.data = buffer[headerSize : headerSize+length] } return f, nil } // recvLoop keeps on reading from underlying connection if tokens are available func (s *Session) recvLoop() { buffer := make([]byte, (1<<16)+headerSize) for { for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { <-s.bucketNotify } if f, err := s.readFrame(buffer); err == nil { atomic.StoreInt32(&s.dataReady, 1) switch f.cmd { case cmdNOP: case cmdSYN: s.streamLock.Lock() if _, ok := s.streams[f.sid]; !ok { stream := newStream(f.sid, s.config.MaxFrameSize, s) s.streams[f.sid] = stream select { case s.chAccepts <- stream: case <-s.die: } } s.streamLock.Unlock() case cmdFIN: s.streamLock.Lock() if stream, ok := s.streams[f.sid]; ok { stream.markRST() stream.notifyReadEvent() } s.streamLock.Unlock() case cmdPSH: s.streamLock.Lock() if stream, ok := s.streams[f.sid]; ok { atomic.AddInt32(&s.bucket, -int32(len(f.data))) stream.pushBytes(f.data) stream.notifyReadEvent() } s.streamLock.Unlock() default: s.Close() return } } else { s.Close() return } } } func (s *Session) keepalive() { tickerPing := time.NewTicker(s.config.KeepAliveInterval) tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout) defer tickerPing.Stop() defer tickerTimeout.Stop() for { select { case <-tickerPing.C: s.writeFrame(newFrame(cmdNOP, 0)) s.notifyBucket() // force a signal to the recvLoop case <-tickerTimeout.C: if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { s.Close() return } case <-s.die: return } } } func (s *Session) sendLoop() { buf := make([]byte, (1<<16)+headerSize) for { select { case <-s.die: return case request, ok := <-s.writes: if !ok { continue } buf[0] = request.frame.ver buf[1] = request.frame.cmd binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) copy(buf[headerSize:], request.frame.data) n, err := s.conn.Write(buf[:headerSize+len(request.frame.data)]) n -= headerSize if n < 0 { n = 0 } result := writeResult{ n: n, err: err, } request.result <- result close(request.result) } } } // writeFrame writes the frame to the underlying connection // and returns the number of bytes written if successful func (s *Session) writeFrame(f Frame) (n int, err error) { req := writeRequest{ frame: f, result: make(chan writeResult, 1), } select { case <-s.die: return 0, errors.New(errBrokenPipe) case s.writes <- req: } result := <-req.result return result.n, result.err }