route/vendor/github.com/lucas-clemente/quic-go/streams_map.go

330 lines
8.3 KiB
Go
Raw Normal View History

2017-12-12 02:51:45 +00:00
package quic
import (
"errors"
"fmt"
"sync"
2018-01-03 19:19:49 +00:00
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
2017-12-12 02:51:45 +00:00
"github.com/lucas-clemente/quic-go/qerr"
)
type streamsMap struct {
mutex sync.RWMutex
2018-01-03 19:19:49 +00:00
perspective protocol.Perspective
2017-12-12 02:51:45 +00:00
2018-01-03 19:19:49 +00:00
streams map[protocol.StreamID]streamI
2017-12-12 02:51:45 +00:00
// needed for round-robin scheduling
openStreams []protocol.StreamID
2018-01-03 19:19:49 +00:00
roundRobinIndex int
2017-12-12 02:51:45 +00:00
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond
openStreamOrErrCond sync.Cond
closeErr error
nextStreamToAccept protocol.StreamID
newStream newStreamLambda
numOutgoingStreams uint32
numIncomingStreams uint32
2018-01-03 19:19:49 +00:00
maxIncomingStreams uint32
maxOutgoingStreams uint32
2017-12-12 02:51:45 +00:00
}
2018-01-03 19:19:49 +00:00
type streamLambda func(streamI) (bool, error)
type newStreamLambda func(protocol.StreamID) streamI
2017-12-12 02:51:45 +00:00
2018-01-03 19:19:49 +00:00
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
2017-12-12 02:51:45 +00:00
2018-01-03 19:19:49 +00:00
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap {
// add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32(
maxStreams+protocol.MaxStreamsMinimumIncrement,
uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
)
2017-12-12 02:51:45 +00:00
sm := streamsMap{
2018-01-03 19:19:49 +00:00
perspective: pers,
streams: make(map[protocol.StreamID]streamI),
openStreams: make([]protocol.StreamID, 0),
newStream: newStream,
maxIncomingStreams: maxIncomingStreams,
2017-12-12 02:51:45 +00:00
}
sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamOrErrCond.L = &sm.mutex
2018-01-03 19:19:49 +00:00
nextOddStream := protocol.StreamID(1)
if ver.CryptoStreamID() == protocol.StreamID(1) {
nextOddStream = 3
}
2017-12-12 02:51:45 +00:00
if pers == protocol.PerspectiveClient {
2018-01-03 19:19:49 +00:00
sm.nextStream = nextOddStream
2017-12-12 02:51:45 +00:00
sm.nextStreamToAccept = 2
} else {
sm.nextStream = 2
2018-01-03 19:19:49 +00:00
sm.nextStreamToAccept = nextOddStream
2017-12-12 02:51:45 +00:00
}
return &sm
}
2018-01-03 19:19:49 +00:00
// getStreamPerspective says which side should initiate a stream
func (m *streamsMap) streamInitiatedBy(id protocol.StreamID) protocol.Perspective {
if id%2 == 0 {
return protocol.PerspectiveServer
}
return protocol.PerspectiveClient
}
2017-12-12 02:51:45 +00:00
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
2018-01-03 19:19:49 +00:00
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
2017-12-12 02:51:45 +00:00
m.mutex.RLock()
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil // s may be nil
}
// ... we don't have an existing stream
m.mutex.Lock()
defer m.mutex.Unlock()
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
s, ok = m.streams[id]
if ok {
return s, nil
}
2018-01-03 19:19:49 +00:00
if m.perspective == m.streamInitiatedBy(id) {
if id <= m.nextStream { // this is a stream opened by us. Must have been closed already
return nil, nil
}
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
2017-12-12 02:51:45 +00:00
}
2018-01-03 19:19:49 +00:00
if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already
return nil, nil
2017-12-12 02:51:45 +00:00
}
// sid is the next stream that will be opened
sid := m.highestStreamOpenedByPeer + 2
// if there is no stream opened yet, and this is the server, stream 1 should be openend
if sid == 2 && m.perspective == protocol.PerspectiveServer {
sid = 1
}
for ; sid <= id; sid += 2 {
2018-01-03 19:19:49 +00:00
if _, err := m.openRemoteStream(sid); err != nil {
2017-12-12 02:51:45 +00:00
return nil, err
}
}
m.nextStreamOrErrCond.Broadcast()
return m.streams[id], nil
}
2018-01-03 19:19:49 +00:00
func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) {
if m.numIncomingStreams >= m.maxIncomingStreams {
2017-12-12 02:51:45 +00:00
return nil, qerr.TooManyOpenStreams
}
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
}
2018-01-03 19:19:49 +00:00
m.numIncomingStreams++
2017-12-12 02:51:45 +00:00
if id > m.highestStreamOpenedByPeer {
m.highestStreamOpenedByPeer = id
}
2018-01-03 19:19:49 +00:00
s := m.newStream(id)
2017-12-12 02:51:45 +00:00
m.putStream(s)
return s, nil
}
2018-01-03 19:19:49 +00:00
func (m *streamsMap) openStreamImpl() (streamI, error) {
2017-12-12 02:51:45 +00:00
id := m.nextStream
2018-01-03 19:19:49 +00:00
if m.numOutgoingStreams >= m.maxOutgoingStreams {
2017-12-12 02:51:45 +00:00
return nil, qerr.TooManyOpenStreams
}
2018-01-03 19:19:49 +00:00
m.numOutgoingStreams++
2017-12-12 02:51:45 +00:00
m.nextStream += 2
2018-01-03 19:19:49 +00:00
s := m.newStream(id)
2017-12-12 02:51:45 +00:00
m.putStream(s)
return s, nil
}
// OpenStream opens the next available stream
2018-01-03 19:19:49 +00:00
func (m *streamsMap) OpenStream() (streamI, error) {
2017-12-12 02:51:45 +00:00
m.mutex.Lock()
defer m.mutex.Unlock()
2018-01-03 19:19:49 +00:00
if m.closeErr != nil {
return nil, m.closeErr
}
2017-12-12 02:51:45 +00:00
return m.openStreamImpl()
}
2018-01-03 19:19:49 +00:00
func (m *streamsMap) OpenStreamSync() (streamI, error) {
2017-12-12 02:51:45 +00:00
m.mutex.Lock()
defer m.mutex.Unlock()
for {
if m.closeErr != nil {
return nil, m.closeErr
}
str, err := m.openStreamImpl()
if err == nil {
return str, err
}
if err != nil && err != qerr.TooManyOpenStreams {
return nil, err
}
m.openStreamOrErrCond.Wait()
}
}
// AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened
2018-01-03 19:19:49 +00:00
func (m *streamsMap) AcceptStream() (streamI, error) {
2017-12-12 02:51:45 +00:00
m.mutex.Lock()
defer m.mutex.Unlock()
2018-01-03 19:19:49 +00:00
var str streamI
2017-12-12 02:51:45 +00:00
for {
var ok bool
if m.closeErr != nil {
return nil, m.closeErr
}
str, ok = m.streams[m.nextStreamToAccept]
if ok {
break
}
m.nextStreamOrErrCond.Wait()
}
m.nextStreamToAccept += 2
return str, nil
}
2018-01-03 19:19:49 +00:00
func (m *streamsMap) DeleteClosedStreams() error {
2017-12-12 02:51:45 +00:00
m.mutex.Lock()
defer m.mutex.Unlock()
2018-01-03 19:19:49 +00:00
var numDeletedStreams int
// for every closed stream, the streamID is replaced by 0 in the openStreams slice
for i, streamID := range m.openStreams {
str, ok := m.streams[streamID]
if !ok {
return errMapAccess
}
if !str.Finished() {
continue
}
numDeletedStreams++
m.openStreams[i] = 0
if m.streamInitiatedBy(streamID) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
delete(m.streams, streamID)
}
if numDeletedStreams == 0 {
return nil
2017-12-12 02:51:45 +00:00
}
2018-01-03 19:19:49 +00:00
// remove all 0s (representing closed streams) from the openStreams slice
// and adjust the roundRobinIndex
var j int
for i, id := range m.openStreams {
if i != j {
m.openStreams[j] = m.openStreams[i]
2017-12-12 02:51:45 +00:00
}
2018-01-03 19:19:49 +00:00
if id != 0 {
j++
} else if j < m.roundRobinIndex {
m.roundRobinIndex--
2017-12-12 02:51:45 +00:00
}
}
2018-01-03 19:19:49 +00:00
m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams]
m.openStreamOrErrCond.Signal()
2017-12-12 02:51:45 +00:00
return nil
}
// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false
// It uses a round-robin-like scheduling to ensure that every stream is considered fairly
2018-01-03 19:19:49 +00:00
// It prioritizes the the header-stream (StreamID 3)
2017-12-12 02:51:45 +00:00
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
m.mutex.Lock()
defer m.mutex.Unlock()
2018-01-03 19:19:49 +00:00
numStreams := len(m.streams)
2017-12-12 02:51:45 +00:00
startIndex := m.roundRobinIndex
2018-01-03 19:19:49 +00:00
for i := 0; i < numStreams; i++ {
2017-12-12 02:51:45 +00:00
streamID := m.openStreams[(i+startIndex)%numStreams]
cont, err := m.iterateFunc(streamID, fn)
if err != nil {
return err
}
m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams
if !cont {
break
}
}
return nil
}
2018-01-03 19:19:49 +00:00
// Range executes a callback for all streams, in pseudo-random order
func (m *streamsMap) Range(cb func(s streamI)) {
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, s := range m.streams {
if s != nil {
cb(s)
}
}
}
2017-12-12 02:51:45 +00:00
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
str, ok := m.streams[streamID]
if !ok {
return true, errMapAccess
}
return fn(str)
}
2018-01-03 19:19:49 +00:00
func (m *streamsMap) putStream(s streamI) error {
2017-12-12 02:51:45 +00:00
id := s.StreamID()
if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id)
}
m.streams[id] = s
m.openStreams = append(m.openStreams, id)
return nil
}
func (m *streamsMap) CloseWithError(err error) {
m.mutex.Lock()
2018-01-03 19:19:49 +00:00
defer m.mutex.Unlock()
2017-12-12 02:51:45 +00:00
m.closeErr = err
m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast()
2018-01-03 19:19:49 +00:00
for _, s := range m.openStreams {
m.streams[s].Cancel(err)
}
}
func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.maxOutgoingStreams = limit
m.openStreamOrErrCond.Broadcast()
2017-12-12 02:51:45 +00:00
}