package quic import ( "context" "fmt" "io" "net" "sync" "time" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) type streamI interface { Stream AddStreamFrame(*wire.StreamFrame) error RegisterRemoteError(error, protocol.ByteCount) error HasDataForWriting() bool GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool) GetWriteOffset() protocol.ByteCount Finished() bool Cancel(error) // methods needed for flow control GetWindowUpdate() protocol.ByteCount UpdateSendWindow(protocol.ByteCount) IsFlowControlBlocked() bool } type cryptoStream interface { streamI SetReadOffset(protocol.ByteCount) } // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. type stream struct { mutex sync.Mutex ctx context.Context ctxCancel context.CancelFunc streamID protocol.StreamID onData func() // onReset is a callback that should send a RST_STREAM onReset func(protocol.StreamID, protocol.ByteCount) readPosInFrame int writeOffset protocol.ByteCount readOffset protocol.ByteCount // Once set, the errors must not be changed! err error // cancelled is set when Cancel() is called cancelled utils.AtomicBool // finishedReading is set once we read a frame with a FinBit finishedReading utils.AtomicBool // finisedWriting is set once Close() is called finishedWriting utils.AtomicBool // resetLocally is set if Reset() is called resetLocally utils.AtomicBool // resetRemotely is set if RegisterRemoteError() is called resetRemotely utils.AtomicBool frameQueue *streamFrameSorter readChan chan struct{} readDeadline time.Time dataForWriting []byte finSent utils.AtomicBool rstSent utils.AtomicBool writeChan chan struct{} writeDeadline time.Time flowController flowcontrol.StreamFlowController version protocol.VersionNumber } var _ Stream = &stream{} var _ streamI = &stream{} type deadlineError struct{} func (deadlineError) Error() string { return "deadline exceeded" } func (deadlineError) Temporary() bool { return true } func (deadlineError) Timeout() bool { return true } var errDeadline net.Error = &deadlineError{} // newStream creates a new Stream func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), flowController flowcontrol.StreamFlowController, version protocol.VersionNumber, ) *stream { s := &stream{ onData: onData, onReset: onReset, streamID: StreamID, flowController: flowController, frameQueue: newStreamFrameSorter(), readChan: make(chan struct{}, 1), writeChan: make(chan struct{}, 1), version: version, } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) return s } // Read implements io.Reader. It is not thread safe! func (s *stream) Read(p []byte) (int, error) { s.mutex.Lock() err := s.err s.mutex.Unlock() if s.cancelled.Get() || s.resetLocally.Get() { return 0, err } if s.finishedReading.Get() { return 0, io.EOF } bytesRead := 0 for bytesRead < len(p) { s.mutex.Lock() frame := s.frameQueue.Head() if frame == nil && bytesRead > 0 { err = s.err s.mutex.Unlock() return bytesRead, err } var err error for { // Stop waiting on errors if s.resetLocally.Get() || s.cancelled.Get() { err = s.err break } deadline := s.readDeadline if !deadline.IsZero() && !time.Now().Before(deadline) { err = errDeadline break } if frame != nil { s.readPosInFrame = int(s.readOffset - frame.Offset) break } s.mutex.Unlock() if deadline.IsZero() { <-s.readChan } else { select { case <-s.readChan: case <-time.After(deadline.Sub(time.Now())): } } s.mutex.Lock() frame = s.frameQueue.Head() } s.mutex.Unlock() if err != nil { return bytesRead, err } m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame) if bytesRead > len(p) { return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) } if s.readPosInFrame > int(frame.DataLen()) { return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen()) } copy(p[bytesRead:], frame.Data[s.readPosInFrame:]) s.readPosInFrame += m bytesRead += m s.readOffset += protocol.ByteCount(m) // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream if !s.resetRemotely.Get() { s.flowController.AddBytesRead(protocol.ByteCount(m)) } s.onData() // so that a possible WINDOW_UPDATE is sent if s.readPosInFrame >= int(frame.DataLen()) { fin := frame.FinBit s.mutex.Lock() s.frameQueue.Pop() s.mutex.Unlock() if fin { s.finishedReading.Set(true) return bytesRead, io.EOF } } } return bytesRead, nil } func (s *stream) Write(p []byte) (int, error) { s.mutex.Lock() defer s.mutex.Unlock() if s.resetLocally.Get() || s.err != nil { return 0, s.err } if s.finishedWriting.Get() { return 0, fmt.Errorf("write on closed stream %d", s.streamID) } if len(p) == 0 { return 0, nil } s.dataForWriting = make([]byte, len(p)) copy(s.dataForWriting, p) s.onData() var err error for { deadline := s.writeDeadline if !deadline.IsZero() && !time.Now().Before(deadline) { err = errDeadline break } if s.dataForWriting == nil || s.err != nil { break } s.mutex.Unlock() if deadline.IsZero() { <-s.writeChan } else { select { case <-s.writeChan: case <-time.After(deadline.Sub(time.Now())): } } s.mutex.Lock() } if err != nil { return 0, err } if s.err != nil { return len(p) - len(s.dataForWriting), s.err } return len(p), nil } func (s *stream) GetWriteOffset() protocol.ByteCount { return s.writeOffset } // HasDataForWriting says if there's stream available to be dequeued for writing func (s *stream) HasDataForWriting() bool { s.mutex.Lock() hasData := s.err == nil && // nothing should be sent if an error occurred (len(s.dataForWriting) > 0 || // there is data queued for sending s.finishedWriting.Get() && !s.finSent.Get()) // if there is no data, but writing finished and the FIN hasn't been sent yet s.mutex.Unlock() return hasData } func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { data, shouldSendFin := s.getDataForWritingImpl(maxBytes) if shouldSendFin { s.finSent.Set(true) } return data, shouldSendFin } func (s *stream) getDataForWritingImpl(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { s.mutex.Lock() defer s.mutex.Unlock() if s.err != nil || s.dataForWriting == nil { return nil, s.finishedWriting.Get() && !s.finSent.Get() } // TODO(#657): Flow control for the crypto stream if s.streamID != s.version.CryptoStreamID() { maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) } if maxBytes == 0 { return nil, false } var ret []byte if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { ret = s.dataForWriting[:maxBytes] s.dataForWriting = s.dataForWriting[maxBytes:] } else { ret = s.dataForWriting s.dataForWriting = nil s.signalWrite() } s.writeOffset += protocol.ByteCount(len(ret)) s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) return ret, s.finishedWriting.Get() && s.dataForWriting == nil && !s.finSent.Get() } // Close implements io.Closer func (s *stream) Close() error { s.finishedWriting.Set(true) s.ctxCancel() s.onData() return nil } func (s *stream) shouldSendReset() bool { if s.rstSent.Get() { return false } return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() } // AddStreamFrame adds a new stream frame func (s *stream) AddStreamFrame(frame *wire.StreamFrame) error { maxOffset := frame.Offset + frame.DataLen() if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil { return err } s.mutex.Lock() defer s.mutex.Unlock() if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData { return err } s.signalRead() return nil } // signalRead performs a non-blocking send on the readChan func (s *stream) signalRead() { select { case s.readChan <- struct{}{}: default: } } // signalRead performs a non-blocking send on the writeChan func (s *stream) signalWrite() { select { case s.writeChan <- struct{}{}: default: } } func (s *stream) SetReadDeadline(t time.Time) error { s.mutex.Lock() oldDeadline := s.readDeadline s.readDeadline = t s.mutex.Unlock() // if the new deadline is before the currently set deadline, wake up Read() if t.Before(oldDeadline) { s.signalRead() } return nil } func (s *stream) SetWriteDeadline(t time.Time) error { s.mutex.Lock() oldDeadline := s.writeDeadline s.writeDeadline = t s.mutex.Unlock() if t.Before(oldDeadline) { s.signalWrite() } return nil } func (s *stream) SetDeadline(t time.Time) error { _ = s.SetReadDeadline(t) // SetReadDeadline never errors _ = s.SetWriteDeadline(t) // SetWriteDeadline never errors return nil } // CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset func (s *stream) CloseRemote(offset protocol.ByteCount) { s.AddStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset}) } // Cancel is called by session to indicate that an error occurred // The stream should will be closed immediately func (s *stream) Cancel(err error) { s.mutex.Lock() s.cancelled.Set(true) s.ctxCancel() // errors must not be changed! if s.err == nil { s.err = err s.signalRead() s.signalWrite() } s.mutex.Unlock() } // resets the stream locally func (s *stream) Reset(err error) { if s.resetLocally.Get() { return } s.mutex.Lock() s.resetLocally.Set(true) s.ctxCancel() // errors must not be changed! if s.err == nil { s.err = err s.signalRead() s.signalWrite() } if s.shouldSendReset() { s.onReset(s.streamID, s.writeOffset) s.rstSent.Set(true) } s.mutex.Unlock() } // resets the stream remotely func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error { if s.resetRemotely.Get() { return nil } s.mutex.Lock() s.resetRemotely.Set(true) s.ctxCancel() // errors must not be changed! if s.err == nil { s.err = err s.signalWrite() } if err := s.flowController.UpdateHighestReceived(offset, true); err != nil { return err } if s.shouldSendReset() { s.onReset(s.streamID, s.writeOffset) s.rstSent.Set(true) } s.mutex.Unlock() return nil } func (s *stream) finishedWriteAndSentFin() bool { return s.finishedWriting.Get() && s.finSent.Get() } func (s *stream) Finished() bool { return s.cancelled.Get() || (s.finishedReading.Get() && s.finishedWriteAndSentFin()) || (s.resetRemotely.Get() && s.rstSent.Get()) || (s.finishedReading.Get() && s.rstSent.Get()) || (s.finishedWriteAndSentFin() && s.resetRemotely.Get()) } func (s *stream) Context() context.Context { return s.ctx } func (s *stream) StreamID() protocol.StreamID { return s.streamID } func (s *stream) UpdateSendWindow(n protocol.ByteCount) { s.flowController.UpdateSendWindow(n) } func (s *stream) IsFlowControlBlocked() bool { return s.flowController.IsBlocked() } func (s *stream) GetWindowUpdate() protocol.ByteCount { return s.flowController.GetWindowUpdate() } // SetReadOffset sets the read offset. // It is only needed for the crypto stream. // It must not be called concurrently with any other stream methods, especially Read and Write. func (s *stream) SetReadOffset(offset protocol.ByteCount) { s.readOffset = offset s.frameQueue.readPosition = offset }