update quic

This commit is contained in:
Cadey Ratio 2018-01-20 10:07:01 -08:00
parent 3e711e63dc
commit d0695adfb6
104 changed files with 7622 additions and 4515 deletions

6
Gopkg.lock generated
View File

@ -421,6 +421,7 @@
revision = "393af48d391698c6ae4219566bfbdfef67269997" revision = "393af48d391698c6ae4219566bfbdfef67269997"
[[projects]] [[projects]]
branch = "master"
name = "github.com/lucas-clemente/quic-go" name = "github.com/lucas-clemente/quic-go"
packages = [ packages = [
".", ".",
@ -435,8 +436,7 @@
"internal/wire", "internal/wire",
"qerr" "qerr"
] ]
revision = "ded0eb4f6f30a8049bd334a26ff7ff728648fe13" revision = "15bcc2579f7dab14c84f438741f2b535cf474df9"
version = "v0.6.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -753,6 +753,6 @@
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
inputs-digest = "a11e1692755a705514dbd401ba4795821d1ac221d6f9100124c38a29db98c568" inputs-digest = "97c8282ef9b3abed71907d17ccf38379134714596610880b02d5ca03be634678"
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

View File

@ -0,0 +1,137 @@
# Gopkg.toml example
#
# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md
# for detailed Gopkg.toml documentation.
#
# required = ["github.com/user/thing/cmd/thing"]
# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
#
# [[constraint]]
# name = "github.com/user/project"
# version = "1.0.0"
#
# [[constraint]]
# name = "github.com/user/project2"
# branch = "dev"
# source = "github.com/myfork/project2"
#
# [[override]]
# name = "github.com/x/y"
# version = "2.4.0"
[[constraint]]
branch = "master"
name = "github.com/Xe/gopreload"
[[constraint]]
name = "github.com/Xe/ln"
version = "0.1.0"
[[constraint]]
branch = "master"
name = "github.com/Xe/uuid"
[[constraint]]
branch = "master"
name = "github.com/Xe/x"
[[constraint]]
name = "github.com/asdine/storm"
version = "2.0.2"
[[constraint]]
branch = "master"
name = "github.com/brandur/simplebox"
[[constraint]]
name = "github.com/caarlos0/env"
version = "3.2.0"
[[constraint]]
branch = "master"
name = "github.com/dgryski/go-failure"
[[constraint]]
branch = "master"
name = "github.com/dickeyxxx/netrc"
[[constraint]]
branch = "master"
name = "github.com/facebookgo/flagenv"
[[constraint]]
branch = "master"
name = "github.com/golang/protobuf"
[[constraint]]
name = "github.com/google/gops"
version = "0.3.2"
[[constraint]]
name = "github.com/hashicorp/terraform"
version = "0.11.2"
[[constraint]]
name = "github.com/joho/godotenv"
version = "1.2.0"
[[constraint]]
branch = "master"
name = "github.com/jtolds/qod"
[[constraint]]
branch = "master"
name = "github.com/kr/pretty"
[[constraint]]
name = "github.com/lucas-clemente/quic-go"
branch = "master"
[[constraint]]
name = "github.com/magefile/mage"
version = "2.0.1"
[[constraint]]
branch = "master"
name = "github.com/mtneug/pkg"
[[constraint]]
branch = "master"
name = "github.com/olekukonko/tablewriter"
[[constraint]]
name = "github.com/pkg/errors"
version = "0.8.0"
[[constraint]]
branch = "master"
name = "github.com/streamrail/concurrent-map"
[[constraint]]
name = "github.com/xtaci/kcp-go"
version = "3.23.0"
[[constraint]]
name = "github.com/xtaci/smux"
version = "1.0.6"
[[constraint]]
name = "go.uber.org/atomic"
version = "1.3.1"
[[constraint]]
branch = "master"
name = "golang.org/x/crypto"
[[constraint]]
branch = "master"
name = "golang.org/x/net"
[[constraint]]
name = "google.golang.org/grpc"
version = "1.9.2"
[[constraint]]
name = "gopkg.in/alecthomas/kingpin.v2"
version = "2.2.6"

View File

@ -1,4 +1,5 @@
dist: trusty dist: trusty
group: travis_latest
addons: addons:
hosts: hosts:
@ -8,6 +9,7 @@ language: go
go: go:
- 1.9.2 - 1.9.2
- 1.10beta1
# first part of the GOARCH workaround # first part of the GOARCH workaround
# setting the GOARCH directly doesn't work, since the value will be overwritten later # setting the GOARCH directly doesn't work, since the value will be overwritten later
@ -30,6 +32,7 @@ before_install:
- export GOARCH=$TRAVIS_GOARCH - export GOARCH=$TRAVIS_GOARCH
- go env # for debugging - go env # for debugging
- google-chrome --version - google-chrome --version
- "printf \"quic.clemente.io certificate valid until: \" && openssl x509 -in example/fullchain.pem -enddate -noout | cut -d = -f 2"
- "export DISPLAY=:99.0" - "export DISPLAY=:99.0"
- "Xvfb $DISPLAY &> /dev/null &" - "Xvfb $DISPLAY &> /dev/null &"

View File

@ -1,6 +1,10 @@
# Changelog # Changelog
## v0.6.1 (unreleased) ## v0.7 (unreleased)
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
- Expose the `ConnectionState` in the `Session` (experimental API).
## v0.6.0 (2017-12-12) ## v0.6.0 (2017-12-12)

View File

@ -1,6 +1,7 @@
package ackhandler package ackhandler
import ( import (
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -11,3 +12,13 @@ func TestCrypto(t *testing.T) {
RegisterFailHandler(Fail) RegisterFailHandler(Fail)
RunSpecs(t, "AckHandler Suite") RunSpecs(t, "AckHandler Suite")
} }
var mockCtrl *gomock.Controller
var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())
})
var _ = AfterEach(func() {
mockCtrl.Finish()
})

View File

@ -16,6 +16,7 @@ type SentPacketHandler interface {
SendingAllowed() bool SendingAllowed() bool
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
ShouldSendRetransmittablePacket() bool ShouldSendRetransmittablePacket() bool
DequeuePacketForRetransmission() (packet *Packet) DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber GetLeastUnacked() protocol.PacketNumber
@ -26,7 +27,7 @@ type SentPacketHandler interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets // ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface { type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber) IgnoreBelow(protocol.PacketNumber)
GetAlarmTimeout() time.Time GetAlarmTimeout() time.Time

View File

@ -15,7 +15,8 @@ type Packet struct {
Length protocol.ByteCount Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel EncryptionLevel protocol.EncryptionLevel
SendTime time.Time largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
sendTime time.Time
} }
// GetFramesForRetransmission gets all the frames for retransmission // GetFramesForRetransmission gets all the frames for retransmission

View File

@ -34,10 +34,10 @@ func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHand
} }
} }
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
h.largestObserved = packetNumber h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now() h.largestObservedReceivedTime = rcvTime
} }
if packetNumber < h.ignoreBelow { if packetNumber < h.ignoreBelow {
@ -47,7 +47,7 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err return err
} }
h.maybeQueueAck(packetNumber, shouldInstigateAck) h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck)
return nil return nil
} }
@ -58,7 +58,7 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
h.packetHistory.DeleteBelow(p) h.packetHistory.DeleteBelow(p)
} }
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) {
h.packetsReceivedSinceLastAck++ h.packetsReceivedSinceLastAck++
if shouldInstigateAck { if shouldInstigateAck {
@ -86,7 +86,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
h.ackQueued = true h.ackQueued = true
} else { } else {
if h.ackAlarm.IsZero() { if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay) h.ackAlarm = rcvTime.Add(h.ackSendDelay)
} }
} }
} }

View File

@ -21,34 +21,36 @@ var _ = Describe("receivedPacketHandler", func() {
Context("accepting packets", func() { Context("accepting packets", func() {
It("handles a packet that arrives late", func() { It("handles a packet that arrives late", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1), true) err := handler.ReceivedPacket(protocol.PacketNumber(1), time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(3), true) err = handler.ReceivedPacket(protocol.PacketNumber(3), time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2), true) err = handler.ReceivedPacket(protocol.PacketNumber(2), time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("saves the time when each packet arrived", func() { It("saves the time when each packet arrived", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(3), true) err := handler.ReceivedPacket(protocol.PacketNumber(3), time.Now(), true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
}) })
It("updates the largestObserved and the largestObservedReceivedTime", func() { It("updates the largestObserved and the largestObservedReceivedTime", func() {
now := time.Now()
handler.largestObserved = 3 handler.largestObserved = 3
handler.largestObservedReceivedTime = time.Now().Add(-1 * time.Second) handler.largestObservedReceivedTime = now.Add(-1 * time.Second)
err := handler.ReceivedPacket(5, true) err := handler.ReceivedPacket(5, now, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5))) Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) Expect(handler.largestObservedReceivedTime).To(Equal(now))
}) })
It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() { It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() {
timestamp := time.Now().Add(-1 * time.Second) now := time.Now()
timestamp := now.Add(-1 * time.Second)
handler.largestObserved = 5 handler.largestObserved = 5
handler.largestObservedReceivedTime = timestamp handler.largestObservedReceivedTime = timestamp
err := handler.ReceivedPacket(4, true) err := handler.ReceivedPacket(4, now, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5))) Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(Equal(timestamp)) Expect(handler.largestObservedReceivedTime).To(Equal(timestamp))
@ -57,7 +59,7 @@ var _ = Describe("receivedPacketHandler", func() {
It("passes on errors from receivedPacketHistory", func() { It("passes on errors from receivedPacketHistory", func() {
var err error var err error
for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ { for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ {
err = handler.ReceivedPacket(2*i+1, true) err = handler.ReceivedPacket(2*i+1, time.Time{}, true)
// this will eventually return an error // this will eventually return an error
// details about when exactly the receivedPacketHistory errors are tested there // details about when exactly the receivedPacketHistory errors are tested there
if err != nil { if err != nil {
@ -72,7 +74,7 @@ var _ = Describe("receivedPacketHandler", func() {
Context("queueing ACKs", func() { Context("queueing ACKs", func() {
receiveAndAck10Packets := func() { receiveAndAck10Packets := func() {
for i := 1; i <= 10; i++ { for i := 1; i <= 10; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true) err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
Expect(handler.GetAckFrame()).ToNot(BeNil()) Expect(handler.GetAckFrame()).ToNot(BeNil())
@ -80,14 +82,14 @@ var _ = Describe("receivedPacketHandler", func() {
} }
It("always queues an ACK for the first packet", func() { It("always queues an ACK for the first packet", func() {
err := handler.ReceivedPacket(1, false) err := handler.ReceivedPacket(1, time.Time{}, false)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue()) Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero()) Expect(handler.GetAlarmTimeout()).To(BeZero())
}) })
It("works with packet number 0", func() { It("works with packet number 0", func() {
err := handler.ReceivedPacket(0, false) err := handler.ReceivedPacket(0, time.Time{}, false)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue()) Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero()) Expect(handler.GetAlarmTimeout()).To(BeZero())
@ -95,11 +97,11 @@ var _ = Describe("receivedPacketHandler", func() {
It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() { It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() {
receiveAndAck10Packets() receiveAndAck10Packets()
err := handler.ReceivedPacket(11, true) err := handler.ReceivedPacket(11, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse()) Expect(handler.ackQueued).To(BeFalse())
Expect(handler.GetAlarmTimeout()).NotTo(BeZero()) Expect(handler.GetAlarmTimeout()).NotTo(BeZero())
err = handler.ReceivedPacket(12, true) err = handler.ReceivedPacket(12, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue()) Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero()) Expect(handler.GetAlarmTimeout()).To(BeZero())
@ -107,11 +109,11 @@ var _ = Describe("receivedPacketHandler", func() {
It("only sets the timer when receiving a retransmittable packets", func() { It("only sets the timer when receiving a retransmittable packets", func() {
receiveAndAck10Packets() receiveAndAck10Packets()
err := handler.ReceivedPacket(11, false) err := handler.ReceivedPacket(11, time.Time{}, false)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse()) Expect(handler.ackQueued).To(BeFalse())
Expect(handler.ackAlarm).To(BeZero()) Expect(handler.ackAlarm).To(BeZero())
err = handler.ReceivedPacket(12, true) err = handler.ReceivedPacket(12, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse()) Expect(handler.ackQueued).To(BeFalse())
Expect(handler.ackAlarm).ToNot(BeZero()) Expect(handler.ackAlarm).ToNot(BeZero())
@ -120,15 +122,15 @@ var _ = Describe("receivedPacketHandler", func() {
It("queues an ACK if it was reported missing before", func() { It("queues an ACK if it was reported missing before", func() {
receiveAndAck10Packets() receiveAndAck10Packets()
err := handler.ReceivedPacket(11, true) err := handler.ReceivedPacket(11, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(13, true) err = handler.ReceivedPacket(13, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() // ACK: 1 and 3, missing: 2 ack := handler.GetAckFrame() // ACK: 1 and 3, missing: 2
Expect(ack).ToNot(BeNil()) Expect(ack).ToNot(BeNil())
Expect(ack.HasMissingRanges()).To(BeTrue()) Expect(ack.HasMissingRanges()).To(BeTrue())
Expect(handler.ackQueued).To(BeFalse()) Expect(handler.ackQueued).To(BeFalse())
err = handler.ReceivedPacket(12, false) err = handler.ReceivedPacket(12, time.Time{}, false)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue()) Expect(handler.ackQueued).To(BeTrue())
}) })
@ -136,10 +138,10 @@ var _ = Describe("receivedPacketHandler", func() {
It("queues an ACK if it creates a new missing range", func() { It("queues an ACK if it creates a new missing range", func() {
receiveAndAck10Packets() receiveAndAck10Packets()
for i := 11; i < 16; i++ { for i := 11; i < 16; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true) err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
err := handler.ReceivedPacket(20, true) // we now know that packets 16 to 19 are missing err := handler.ReceivedPacket(20, time.Time{}, true) // we now know that packets 16 to 19 are missing
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue()) Expect(handler.ackQueued).To(BeTrue())
ack := handler.GetAckFrame() ack := handler.GetAckFrame()
@ -154,9 +156,9 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
It("generates a simple ACK frame", func() { It("generates a simple ACK frame", func() {
err := handler.ReceivedPacket(1, true) err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(2, true) err = handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil()) Expect(ack).ToNot(BeNil())
@ -166,7 +168,7 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
It("generates an ACK for packet number 0", func() { It("generates an ACK for packet number 0", func() {
err := handler.ReceivedPacket(0, true) err := handler.ReceivedPacket(0, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil()) Expect(ack).ToNot(BeNil())
@ -176,12 +178,12 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
It("saves the last sent ACK", func() { It("saves the last sent ACK", func() {
err := handler.ReceivedPacket(1, true) err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil()) Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack)) Expect(handler.lastAck).To(Equal(ack))
err = handler.ReceivedPacket(2, true) err = handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
handler.ackQueued = true handler.ackQueued = true
ack = handler.GetAckFrame() ack = handler.GetAckFrame()
@ -190,9 +192,9 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
It("generates an ACK frame with missing packets", func() { It("generates an ACK frame with missing packets", func() {
err := handler.ReceivedPacket(1, true) err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(4, true) err = handler.ReceivedPacket(4, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil()) Expect(ack).ToNot(BeNil())
@ -205,11 +207,11 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
It("generates an ACK for packet number 0 and other packets", func() { It("generates an ACK for packet number 0 and other packets", func() {
err := handler.ReceivedPacket(0, true) err := handler.ReceivedPacket(0, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(1, true) err = handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(3, true) err = handler.ReceivedPacket(3, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil()) Expect(ack).ToNot(BeNil())
@ -223,15 +225,15 @@ var _ = Describe("receivedPacketHandler", func() {
It("accepts packets below the lower limit", func() { It("accepts packets below the lower limit", func() {
handler.IgnoreBelow(6) handler.IgnoreBelow(6)
err := handler.ReceivedPacket(2, true) err := handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("doesn't add delayed packets to the packetHistory", func() { It("doesn't add delayed packets to the packetHistory", func() {
handler.IgnoreBelow(7) handler.IgnoreBelow(7)
err := handler.ReceivedPacket(4, true) err := handler.ReceivedPacket(4, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(10, true) err = handler.ReceivedPacket(10, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil()) Expect(ack).ToNot(BeNil())
@ -241,7 +243,7 @@ var _ = Describe("receivedPacketHandler", func() {
It("deletes packets from the packetHistory when a lower limit is set", func() { It("deletes packets from the packetHistory when a lower limit is set", func() {
for i := 1; i <= 12; i++ { for i := 1; i <= 12; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true) err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
handler.IgnoreBelow(7) handler.IgnoreBelow(7)
@ -256,7 +258,7 @@ var _ = Describe("receivedPacketHandler", func() {
// TODO: remove this test when dropping support for STOP_WAITINGs // TODO: remove this test when dropping support for STOP_WAITINGs
It("handles a lower limit of 0", func() { It("handles a lower limit of 0", func() {
handler.IgnoreBelow(0) handler.IgnoreBelow(0)
err := handler.ReceivedPacket(1337, true) err := handler.ReceivedPacket(1337, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil()) Expect(ack).ToNot(BeNil())
@ -264,7 +266,7 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
It("resets all counters needed for the ACK queueing decision when sending an ACK", func() { It("resets all counters needed for the ACK queueing decision when sending an ACK", func() {
err := handler.ReceivedPacket(1, true) err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
handler.ackAlarm = time.Now().Add(-time.Minute) handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil()) Expect(handler.GetAckFrame()).ToNot(BeNil())
@ -275,7 +277,7 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
It("doesn't generate an ACK when none is queued and the timer is not set", func() { It("doesn't generate an ACK when none is queued and the timer is not set", func() {
err := handler.ReceivedPacket(1, true) err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false handler.ackQueued = false
handler.ackAlarm = time.Time{} handler.ackAlarm = time.Time{}
@ -283,7 +285,7 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() { It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() {
err := handler.ReceivedPacket(1, true) err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false handler.ackQueued = false
handler.ackAlarm = time.Now().Add(time.Minute) handler.ackAlarm = time.Now().Add(time.Minute)
@ -291,7 +293,7 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
It("generates an ACK when the timer has expired", func() { It("generates an ACK when the timer has expired", func() {
err := handler.ReceivedPacket(1, true) err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false handler.ackQueued = false
handler.ackAlarm = time.Now().Add(-time.Minute) handler.ackAlarm = time.Now().Add(-time.Minute)

View File

@ -40,6 +40,10 @@ type sentPacketHandler struct {
largestAcked protocol.PacketNumber largestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber largestReceivedPacketWithAck protocol.PacketNumber
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
lowestPacketNotConfirmedAcked protocol.PacketNumber
packetHistory *PacketList packetHistory *PacketList
stopWaitingManager stopWaitingManager stopWaitingManager stopWaitingManager
@ -95,6 +99,13 @@ func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool {
} }
func (h *sentPacketHandler) SetHandshakeComplete() { func (h *sentPacketHandler) SetHandshakeComplete() {
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
queue = append(queue, packet)
}
}
h.retransmissionQueue = queue
h.handshakeComplete = true h.handshakeComplete = true
} }
@ -114,11 +125,19 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
h.lastSentPacketNumber = packet.PacketNumber h.lastSentPacketNumber = packet.PacketNumber
now := time.Now() now := time.Now()
var largestAcked protocol.PacketNumber
if len(packet.Frames) > 0 {
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
largestAcked = ackFrame.LargestAcked
}
}
packet.Frames = stripNonRetransmittableFrames(packet.Frames) packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0 isRetransmittable := len(packet.Frames) != 0
if isRetransmittable { if isRetransmittable {
packet.SendTime = now packet.sendTime = now
packet.largestAcked = largestAcked
h.bytesInFlight += packet.Length h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet) h.packetHistory.PushBack(*packet)
h.numNonRetransmittablePackets = 0 h.numNonRetransmittablePackets = 0
@ -134,7 +153,7 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
isRetransmittable, isRetransmittable,
) )
h.updateLossDetectionAlarm() h.updateLossDetectionAlarm(now)
return nil return nil
} }
@ -146,14 +165,12 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
// duplicate or out-of-order ACK // duplicate or out-of-order ACK
// if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 { // if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 {
if withPacketNumber <= h.largestReceivedPacketWithAck { if withPacketNumber <= h.largestReceivedPacketWithAck {
utils.Debugf("ignoring ack because duplicate")
return ErrDuplicateOrOutOfOrderAck return ErrDuplicateOrOutOfOrderAck
} }
h.largestReceivedPacketWithAck = withPacketNumber h.largestReceivedPacketWithAck = withPacketNumber
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK) // ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
if ackFrame.LargestAcked < h.lowestUnacked() { if ackFrame.LargestAcked < h.lowestUnacked() {
utils.Debugf("ignoring ack because repeated")
return nil return nil
} }
h.largestAcked = ackFrame.LargestAcked h.largestAcked = ackFrame.LargestAcked
@ -178,13 +195,19 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
if encLevel < p.Value.EncryptionLevel { if encLevel < p.Value.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel) return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
} }
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
if p.Value.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1)
}
h.onPacketAcked(p) h.onPacketAcked(p)
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
} }
} }
h.detectLostPackets() h.detectLostPackets(rcvTime)
h.updateLossDetectionAlarm() h.updateLossDetectionAlarm(rcvTime)
h.garbageCollectSkippedPackets() h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame) h.stopWaitingManager.ReceivedAck(ackFrame)
@ -192,6 +215,10 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
return nil return nil
} }
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestPacketNotConfirmedAcked
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) { func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
var ackedPackets []*PacketElement var ackedPackets []*PacketElement
ackRangeIndex := 0 ackRangeIndex := 0
@ -233,7 +260,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
for el := h.packetHistory.Front(); el != nil; el = el.Next() { for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value packet := el.Value
if packet.PacketNumber == largestAcked { if packet.PacketNumber == largestAcked {
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now()) h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime)
return true return true
} }
// Packets are sorted by number, so we can stop searching // Packets are sorted by number, so we can stop searching
@ -244,7 +271,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
return false return false
} }
func (h *sentPacketHandler) updateLossDetectionAlarm() { func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
// Cancel the alarm if no packets are outstanding // Cancel the alarm if no packets are outstanding
if h.packetHistory.Len() == 0 { if h.packetHistory.Len() == 0 {
h.alarm = time.Time{} h.alarm = time.Time{}
@ -253,19 +280,18 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() {
// TODO(#497): TLP // TODO(#497): TLP
if !h.handshakeComplete { if !h.handshakeComplete {
h.alarm = time.Now().Add(h.computeHandshakeTimeout()) h.alarm = now.Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() { } else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection. // Early retransmit timer or time loss detection.
h.alarm = h.lossTime h.alarm = h.lossTime
} else { } else {
// RTO // RTO
h.alarm = time.Now().Add(h.computeRTOTimeout()) h.alarm = now.Add(h.computeRTOTimeout())
} }
} }
func (h *sentPacketHandler) detectLostPackets() { func (h *sentPacketHandler) detectLostPackets(now time.Time) {
h.lossTime = time.Time{} h.lossTime = time.Time{}
now := time.Now()
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT) delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
@ -278,7 +304,7 @@ func (h *sentPacketHandler) detectLostPackets() {
break break
} }
timeSinceSent := now.Sub(packet.SendTime) timeSinceSent := now.Sub(packet.sendTime)
if timeSinceSent > delayUntilLost { if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, el) lostPackets = append(lostPackets, el)
} else if h.lossTime.IsZero() { } else if h.lossTime.IsZero() {
@ -296,20 +322,22 @@ func (h *sentPacketHandler) detectLostPackets() {
} }
func (h *sentPacketHandler) OnAlarm() { func (h *sentPacketHandler) OnAlarm() {
now := time.Now()
// TODO(#497): TLP // TODO(#497): TLP
if !h.handshakeComplete { if !h.handshakeComplete {
h.queueHandshakePacketsForRetransmission() h.queueHandshakePacketsForRetransmission()
h.handshakeCount++ h.handshakeCount++
} else if !h.lossTime.IsZero() { } else if !h.lossTime.IsZero() {
// Early retransmit or time loss detection // Early retransmit or time loss detection
h.detectLostPackets() h.detectLostPackets(now)
} else { } else {
// RTO // RTO
h.retransmitOldestTwoPackets() h.retransmitOldestTwoPackets()
h.rtoCount++ h.rtoCount++
} }
h.updateLossDetectionAlarm() h.updateLossDetectionAlarm(now)
} }
func (h *sentPacketHandler) GetAlarmTimeout() time.Time { func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
@ -345,12 +373,11 @@ func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFra
} }
func (h *sentPacketHandler) SendingAllowed() bool { func (h *sentPacketHandler) SendingAllowed() bool {
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow() cwnd := h.congestion.GetCongestionWindow()
congestionLimited := h.bytesInFlight > cwnd
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
if congestionLimited { if congestionLimited {
utils.Debugf("Congestion limited: bytes in flight %d, window %d", utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
h.bytesInFlight,
h.congestion.GetCongestionWindow())
} }
// Workaround for #555: // Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited // Always allow sending of retransmissions. This should probably be limited

View File

@ -3,60 +3,15 @@ package ackhandler
import ( import (
"time" "time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
type mockCongestion struct {
argsOnPacketSent []interface{}
maybeExitSlowStart bool
onRetransmissionTimeout bool
getCongestionWindow bool
packetsAcked [][]interface{}
packetsLost [][]interface{}
}
func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
panic("not implemented")
}
func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable}
return false
}
func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount {
m.getCongestionWindow = true
return protocol.DefaultTCPMSS
}
func (m *mockCongestion) MaybeExitSlowStart() {
m.maybeExitSlowStart = true
}
func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) {
m.onRetransmissionTimeout = true
}
func (m *mockCongestion) RetransmissionDelay() time.Duration {
return defaultRTOTimeout
}
func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not implemented") }
func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") }
func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") }
func (m *mockCongestion) OnPacketAcked(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
m.packetsAcked = append(m.packetsAcked, []interface{}{n, l, bif})
}
func (m *mockCongestion) OnPacketLost(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
m.packetsLost = append(m.packetsLost, []interface{}{n, l, bif})
}
func retransmittablePacket(num protocol.PacketNumber) *Packet { func retransmittablePacket(num protocol.PacketNumber) *Packet {
return &Packet{ return &Packet{
PacketNumber: num, PacketNumber: num,
@ -143,7 +98,7 @@ var _ = Describe("SentPacketHandler", func() {
packet := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1}
err := handler.SentPacket(&packet) err := handler.SentPacket(&packet)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.packetHistory.Front().Value.SendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) Expect(handler.packetHistory.Front().Value.sendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1))
}) })
It("does not store non-retransmittable packets", func() { It("does not store non-retransmittable packets", func() {
@ -553,9 +508,9 @@ var _ = Describe("SentPacketHandler", func() {
It("computes the RTT", func() { It("computes the RTT", func() {
now := time.Now() now := time.Now()
// First, fake the sent times of the first, second and last packet // First, fake the sent times of the first, second and last packet
getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute) getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute)
getPacketElement(2).Value.SendTime = now.Add(-5 * time.Minute) getPacketElement(2).Value.sendTime = now.Add(-5 * time.Minute)
getPacketElement(6).Value.SendTime = now.Add(-1 * time.Minute) getPacketElement(6).Value.sendTime = now.Add(-1 * time.Minute)
// Now, check that the proper times are used when calculating the deltas // Now, check that the proper times are used when calculating the deltas
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now()) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -570,12 +525,50 @@ var _ = Describe("SentPacketHandler", func() {
It("uses the DelayTime in the ack frame", func() { It("uses the DelayTime in the ack frame", func() {
now := time.Now() now := time.Now()
getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute) getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute)
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, DelayTime: 5 * time.Minute}, 1, protocol.EncryptionUnencrypted, time.Now()) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, DelayTime: 5 * time.Minute}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
}) })
}) })
Context("determinining, which ACKs we have received an ACK for", func() {
BeforeEach(func() {
morePackets := []*Packet{
&Packet{PacketNumber: 13, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 80, LargestAcked: 100}, &streamFrame}, Length: 1},
&Packet{PacketNumber: 14, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 50, LargestAcked: 200}, &streamFrame}, Length: 1},
&Packet{PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1},
}
for _, packet := range morePackets {
err := handler.SentPacket(packet)
Expect(err).NotTo(HaveOccurred())
}
})
It("determines which ACK we have received an ACK for", func() {
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 12}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
})
It("doesn't do anything when the acked packet didn't contain an ACK", func() {
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 15}, 2, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
})
It("doesn't decrease the value", func() {
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 14, LowestAcked: 14}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 2, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
})
})
}) })
Context("Retransmission handling", func() { Context("Retransmission handling", func() {
@ -583,13 +576,13 @@ var _ = Describe("SentPacketHandler", func() {
BeforeEach(func() { BeforeEach(func() {
packets = []*Packet{ packets = []*Packet{
{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1}, {PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted},
{PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1}, {PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted},
{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1}, {PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted},
{PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1}, {PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure},
{PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1}, {PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure},
{PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1}, {PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure},
{PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1}, {PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure},
} }
for _, packet := range packets { for _, packet := range packets {
handler.SentPacket(packet) handler.SentPacket(packet)
@ -597,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() {
// Increase RTT, because the tests would be flaky otherwise // Increase RTT, because the tests would be flaky otherwise
handler.rttStats.UpdateRTT(time.Minute, 0, time.Now()) handler.rttStats.UpdateRTT(time.Minute, 0, time.Now())
// Ack a single packet so that we have non-RTO timings // Ack a single packet so that we have non-RTO timings
handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionUnencrypted, time.Now()) handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now())
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6)))
}) })
@ -606,7 +599,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("dequeues a packet for retransmission", func() { It("dequeues a packet for retransmission", func() {
getPacketElement(1).Value.SendTime = time.Now().Add(-time.Hour) getPacketElement(1).Value.sendTime = time.Now().Add(-time.Hour)
handler.OnAlarm() handler.OnAlarm()
Expect(getPacketElement(1)).To(BeNil()) Expect(getPacketElement(1)).To(BeNil())
Expect(handler.retransmissionQueue).To(HaveLen(1)) Expect(handler.retransmissionQueue).To(HaveLen(1))
@ -617,15 +610,33 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
}) })
Context("StopWaitings", func() { It("deletes non forward-secure packets when the handshake completes", func() {
It("gets a StopWaitingFrame", func() { for i := protocol.PacketNumber(1); i <= 7; i++ {
if i == 2 { // packet 2 was already acked in BeforeEach
continue
}
handler.queuePacketForRetransmission(getPacketElement(i))
}
Expect(handler.retransmissionQueue).To(HaveLen(6))
handler.SetHandshakeComplete()
packet := handler.DequeuePacketForRetransmission()
Expect(packet).ToNot(BeNil())
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(6)))
packet = handler.DequeuePacketForRetransmission()
Expect(packet).ToNot(BeNil())
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(7)))
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
})
Context("STOP_WAITINGs", func() {
It("gets a STOP_WAITING frame", func() {
ack := wire.AckFrame{LargestAcked: 5, LowestAcked: 5} ack := wire.AckFrame{LargestAcked: 5, LowestAcked: 5}
err := handler.ReceivedAck(&ack, 2, protocol.EncryptionUnencrypted, time.Now()) err := handler.ReceivedAck(&ack, 2, protocol.EncryptionForwardSecure, time.Now())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6})) Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6}))
}) })
It("gets a StopWaitingFrame after queueing a retransmission", func() { It("gets a STOP_WAITING frame after queueing a retransmission", func() {
handler.queuePacketForRetransmission(getPacketElement(5)) handler.queuePacketForRetransmission(getPacketElement(5))
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6})) Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6}))
}) })
@ -662,7 +673,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
handler.packetHistory.Front().Value.SendTime = time.Now().Add(-time.Hour) handler.packetHistory.Front().Value.sendTime = time.Now().Add(-time.Hour)
handler.OnAlarm() handler.OnAlarm()
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(0))) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(0)))
@ -670,15 +681,23 @@ var _ = Describe("SentPacketHandler", func() {
Context("congestion", func() { Context("congestion", func() {
var ( var (
cong *mockCongestion cong *mocks.MockSendAlgorithm
) )
BeforeEach(func() { BeforeEach(func() {
cong = &mockCongestion{} cong = mocks.NewMockSendAlgorithm(mockCtrl)
cong.EXPECT().RetransmissionDelay().AnyTimes()
handler.congestion = cong handler.congestion = cong
}) })
It("should call OnSent", func() { It("should call OnSent", func() {
cong.EXPECT().OnPacketSent(
gomock.Any(),
protocol.ByteCount(42),
protocol.PacketNumber(1),
protocol.ByteCount(42),
true,
)
p := &Packet{ p := &Packet{
PacketNumber: 1, PacketNumber: 1,
Length: 42, Length: 42,
@ -686,62 +705,60 @@ var _ = Describe("SentPacketHandler", func() {
} }
err := handler.SentPacket(p) err := handler.SentPacket(p)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(42)))
Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1)))
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(42)))
Expect(cong.argsOnPacketSent[4]).To(BeTrue())
}) })
It("should call MaybeExitSlowStart and OnPacketAcked", func() { It("should call MaybeExitSlowStart and OnPacketAcked", func() {
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
cong.EXPECT().MaybeExitSlowStart()
cong.EXPECT().OnPacketAcked(
protocol.PacketNumber(1),
protocol.ByteCount(1),
protocol.ByteCount(1),
)
handler.SentPacket(retransmittablePacket(1)) handler.SentPacket(retransmittablePacket(1))
handler.SentPacket(retransmittablePacket(2)) handler.SentPacket(retransmittablePacket(2))
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionForwardSecure, time.Now()) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionForwardSecure, time.Now())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cong.maybeExitSlowStart).To(BeTrue())
Expect(cong.packetsAcked).To(BeEquivalentTo([][]interface{}{
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(1)},
}))
Expect(cong.packetsLost).To(BeEmpty())
}) })
It("should call MaybeExitSlowStart and OnPacketLost", func() { It("should call MaybeExitSlowStart and OnPacketLost", func() {
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
cong.EXPECT().OnRetransmissionTimeout(true).Times(2)
cong.EXPECT().OnPacketLost(
protocol.PacketNumber(1),
protocol.ByteCount(1),
protocol.ByteCount(2),
)
cong.EXPECT().OnPacketLost(
protocol.PacketNumber(2),
protocol.ByteCount(1),
protocol.ByteCount(1),
)
handler.SentPacket(retransmittablePacket(1)) handler.SentPacket(retransmittablePacket(1))
handler.SentPacket(retransmittablePacket(2)) handler.SentPacket(retransmittablePacket(2))
handler.SentPacket(retransmittablePacket(3)) handler.SentPacket(retransmittablePacket(3))
handler.OnAlarm() // RTO, meaning 2 lost packets handler.OnAlarm() // RTO, meaning 2 lost packets
Expect(cong.maybeExitSlowStart).To(BeFalse())
Expect(cong.onRetransmissionTimeout).To(BeTrue())
Expect(cong.packetsAcked).To(BeEmpty())
Expect(cong.packetsLost).To(BeEquivalentTo([][]interface{}{
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)},
{protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(1)},
}))
}) })
It("allows or denies sending based on congestion", func() { It("allows or denies sending based on congestion", func() {
Expect(handler.retransmissionQueue).To(BeEmpty())
handler.bytesInFlight = 100
cong.EXPECT().GetCongestionWindow().Return(protocol.MaxByteCount)
Expect(handler.SendingAllowed()).To(BeTrue()) Expect(handler.SendingAllowed()).To(BeTrue())
err := handler.SentPacket(&Packet{ cong.EXPECT().GetCongestionWindow().Return(protocol.ByteCount(0))
PacketNumber: 1,
Frames: []wire.Frame{&wire.PingFrame{}},
Length: protocol.DefaultTCPMSS + 1,
})
Expect(err).NotTo(HaveOccurred())
Expect(handler.SendingAllowed()).To(BeFalse()) Expect(handler.SendingAllowed()).To(BeFalse())
}) })
It("allows or denies sending based on the number of tracked packets", func() { It("allows or denies sending based on the number of tracked packets", func() {
cong.EXPECT().GetCongestionWindow().Return(protocol.MaxByteCount).AnyTimes()
Expect(handler.SendingAllowed()).To(BeTrue()) Expect(handler.SendingAllowed()).To(BeTrue())
handler.retransmissionQueue = make([]*Packet, protocol.MaxTrackedSentPackets) handler.retransmissionQueue = make([]*Packet, protocol.MaxTrackedSentPackets)
Expect(handler.SendingAllowed()).To(BeFalse()) Expect(handler.SendingAllowed()).To(BeFalse())
}) })
It("allows sending if there are retransmisisons outstanding", func() { It("allows sending if there are retransmisisons outstanding", func() {
err := handler.SentPacket(&Packet{ handler.bytesInFlight = 100
PacketNumber: 1, cong.EXPECT().GetCongestionWindow().Return(protocol.ByteCount(0)).AnyTimes()
Frames: []wire.Frame{&wire.PingFrame{}},
Length: protocol.DefaultTCPMSS + 1,
})
Expect(err).NotTo(HaveOccurred())
Expect(handler.SendingAllowed()).To(BeFalse()) Expect(handler.SendingAllowed()).To(BeFalse())
handler.retransmissionQueue = []*Packet{nil} handler.retransmissionQueue = []*Packet{nil}
Expect(handler.SendingAllowed()).To(BeTrue()) Expect(handler.SendingAllowed()).To(BeTrue())
@ -799,7 +816,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.lossTime.Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute)) Expect(handler.lossTime.Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute))
Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute)) Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute))
handler.packetHistory.Front().Value.SendTime = time.Now().Add(-2 * time.Hour) handler.packetHistory.Front().Value.sendTime = time.Now().Add(-2 * time.Hour)
handler.OnAlarm() handler.OnAlarm()
Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil()) Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil())
}) })
@ -843,7 +860,7 @@ var _ = Describe("SentPacketHandler", func() {
err = handler.SentPacket(handshakePacket(4)) err = handler.SentPacket(handshakePacket(4))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now().Add(time.Hour)) err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(handler.lossTime.IsZero()).To(BeTrue()) Expect(handler.lossTime.IsZero()).To(BeTrue())
handshakeTimeout := handler.computeHandshakeTimeout() handshakeTimeout := handler.computeHandshakeTimeout()

View File

@ -60,33 +60,15 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
return Dial(udpConn, udpAddr, addr, tlsConf, config) return Dial(udpConn, udpAddr, addr, tlsConf, config)
} }
// DialAddrNonFWSecure establishes a new QUIC connection to a server. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func DialNonFWSecure( func Dial(
pconn net.PacketConn, pconn net.PacketConn,
remoteAddr net.Addr, remoteAddr net.Addr,
host string, host string,
tlsConf *tls.Config, tlsConf *tls.Config,
config *Config, config *Config,
) (NonFWSession, error) { ) (Session, error) {
connID, err := generateConnectionID() connID, err := generateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
@ -115,31 +97,11 @@ func DialNonFWSecure(
} }
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
go c.listen()
if err := c.dial(); err != nil { if err := c.dial(); err != nil {
return nil, err return nil, err
} }
return c.session.(NonFWSession), nil return c.session, nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
if err != nil {
return nil, err
}
if err := sess.WaitUntilHandshakeComplete(); err != nil {
return nil, err
}
return sess, nil
} }
// populateClientConfig populates fields in the quic.Config with their default values, if none are set // populateClientConfig populates fields in the quic.Config with their default values, if none are set
@ -199,6 +161,7 @@ func (c *client) dialGQUIC() error {
if err := c.createNewGQUICSession(); err != nil { if err := c.createNewGQUICSession(); err != nil {
return err return err
} }
go c.listen()
return c.establishSecureConnection() return c.establishSecureConnection()
} }
@ -224,6 +187,7 @@ func (c *client) dialTLS() error {
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil { if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
return err return err
} }
go c.listen()
if err := c.establishSecureConnection(); err != nil { if err := c.establishSecureConnection(); err != nil {
if err != handshake.ErrCloseSessionForRetry { if err != handshake.ErrCloseSessionForRetry {
return err return err
@ -267,14 +231,8 @@ func (c *client) establishSecureConnection() error {
select { select {
case <-errorChan: case <-errorChan:
return runErr return runErr
case ev := <-c.session.handshakeStatus(): case err := <-c.session.handshakeStatus():
if ev.err != nil { return err
return ev.err
}
if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure {
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
}
return nil
} }
} }

View File

@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"net" "net"
"os"
"sync/atomic" "sync/atomic"
"time" "time"
@ -100,57 +101,7 @@ var _ = Describe("Client", func() {
generateConnectionID = origGenerateConnectionID generateConnectionID = origGenerateConnectionID
}) })
It("dials non-forward-secure", func() { It("returns after the handshake is complete", func() {
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(dialed).Should(BeClosed())
})
It("dials a non-forward-secure address", func() {
serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
server, err := net.ListenUDP("udp", serverAddr)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
for {
_, clientAddr, err := server.ReadFromUDP(make([]byte, 200))
if err != nil {
return
}
_, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr)
Expect(err).ToNot(HaveOccurred())
}
}()
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(dialed).Should(BeClosed())
server.Close()
Eventually(done).Should(BeClosed())
})
It("Dial only returns after the handshake is complete", func() {
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{}) dialed := make(chan struct{})
go func() { go func() {
@ -160,13 +111,14 @@ var _ = Describe("Client", func() {
Expect(s).ToNot(BeNil()) Expect(s).ToNot(BeNil())
close(dialed) close(dialed)
}() }()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} close(sess.handshakeChan)
Consistently(dialed).ShouldNot(BeClosed())
close(sess.handshakeComplete)
Eventually(dialed).Should(BeClosed()) Eventually(dialed).Should(BeClosed())
}) })
It("resolves the address", func() { It("resolves the address", func() {
if os.Getenv("APPVEYOR") == "True" {
Skip("This test is flaky on AppVeyor.")
}
closeErr := errors.New("peer doesn't reply") closeErr := errors.New("peer doesn't reply")
remoteAddrChan := make(chan string) remoteAddrChan := make(chan string)
newClientSession = func( newClientSession = func(
@ -245,22 +197,7 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done) close(done)
}() }()
sess.handshakeChan <- handshakeEvent{err: testErr} sess.handshakeChan <- testErr
Eventually(done).Should(BeClosed())
})
It("returns an error that occurs while waiting for the handshake to complete", func() {
testErr := errors.New("late handshake error")
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr))
close(done)
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
sess.handshakeComplete <- testErr
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -305,7 +242,7 @@ var _ = Describe("Client", func() {
) (packetHandler, error) { ) (packetHandler, error) {
return nil, testErr return nil, testErr
} }
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config) _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -331,7 +268,7 @@ var _ = Describe("Client", func() {
Expect(newVersion).ToNot(Equal(cl.version)) Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion)) Expect(config.Versions).To(ContainElement(newVersion))
sessionChan := make(chan *mockSession) sessionChan := make(chan *mockSession)
handshakeChan := make(chan handshakeEvent) handshakeChan := make(chan error)
newClientSession = func( newClientSession = func(
_ connection, _ connection,
_ string, _ string,
@ -382,7 +319,7 @@ var _ = Describe("Client", func() {
Expect(negotiatedVersions).To(ContainElement(newVersion)) Expect(negotiatedVersions).To(ContainElement(newVersion))
Expect(initialVersion).To(Equal(actualInitialVersion)) Expect(initialVersion).To(Equal(actualInitialVersion))
handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} close(handshakeChan)
Eventually(established).Should(BeClosed()) Eventually(established).Should(BeClosed())
}) })

View File

@ -0,0 +1,41 @@
package quic
import (
"io"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStreamI interface {
StreamID() protocol.StreamID
io.Reader
io.Writer
handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
setReadOffset(protocol.ByteCount)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStream struct {
*stream
}
var _ cryptoStreamI = &cryptoStream{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStream{str}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPosition = offset
}

View File

@ -0,0 +1,26 @@
package quic
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Crypto Stream", func() {
var (
str *cryptoStream
mockSender *MockStreamSender
)
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStream)
})
It("sets the read offset", func() {
str.setReadOffset(0x42)
Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42)))
Expect(str.receiveStream.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42)))
})
})

View File

@ -97,45 +97,44 @@ func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream) h2framer := http2.NewFramer(nil, c.headerStream)
var lastStream protocol.StreamID var err error
for err == nil {
err = c.readResponse(h2framer, decoder)
}
utils.Debugf("Error handling header stream: %s", err)
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request
close(c.headerErrored)
}
for { func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
frame, err := h2framer.ReadFrame() frame, err := h2framer.ReadFrame()
if err != nil { if err != nil {
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") return err
break
} }
lastStream = protocol.StreamID(frame.Header().StreamID)
hframe, ok := frame.(*http2.HeadersFrame) hframe, ok := frame.(*http2.HeadersFrame)
if !ok { if !ok {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") return errors.New("not a headers frame")
break
} }
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil { if err != nil {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") return fmt.Errorf("cannot read header fields: %s", err.Error())
break
} }
c.mutex.RLock() c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock() c.mutex.RUnlock()
if !ok { if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
break
} }
rsp, err := responseFromHeaders(mhframe) rsp, err := responseFromHeaders(mhframe)
if err != nil { if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error()) return err
} }
responseChan <- rsp responseChan <- rsp
} return nil
// stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
close(c.headerErrored)
} }
// Roundtrip executes a request and returns a response // Roundtrip executes a request and returns a response

View File

@ -188,41 +188,31 @@ var _ = Describe("Client", func() {
close(done) close(done)
}) })
It("closes the quic client when encountering an error on the header stream", func(done Done) { It("closes the quic client when encountering an error on the header stream", func() {
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
var doReturned bool done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
var err error
rsp, err := client.RoundTrip(request) rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(client.headerErr)) Expect(err).To(MatchError(client.headerErr))
Expect(rsp).To(BeNil()) Expect(rsp).To(BeNil())
doReturned = true close(done)
}() }()
Eventually(func() bool { return doReturned }).Should(BeTrue()) Eventually(done).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame"))) Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr)) Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
close(done) })
}, 2)
It("returns subsequent request if there was an error on the header stream before", func(done Done) { It("returns subsequent request if there was an error on the header stream before", func() {
expectedErr := qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)} session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
var firstReqReturned bool
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request) _, err := client.RoundTrip(request)
Expect(err).To(MatchError(expectedErr)) Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
firstReqReturned = true Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
}()
Eventually(func() bool { return firstReqReturned }).Should(BeTrue())
// now that the first request failed due to an error on the header stream, try another request // now that the first request failed due to an error on the header stream, try another request
_, err := client.RoundTrip(request) _, nextErr := client.RoundTrip(request)
Expect(err).To(MatchError(expectedErr)) Expect(nextErr).To(MatchError(err))
close(done)
}) })
It("blocks if no stream is available", func() { It("blocks if no stream is available", func() {
@ -479,16 +469,9 @@ var _ = Describe("Client", func() {
It("errors if the H2 frame is not a HeadersFrame", func() { It("errors if the H2 frame is not a HeadersFrame", func() {
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0}) h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
var handlerReturned bool
go func() {
client.handleHeaderStream() client.handleHeaderStream()
handlerReturned = true
}()
Eventually(client.headerErrored).Should(BeClosed()) Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame"))) Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
}) })
It("errors if it can't read the HPACK encoded header fields", func() { It("errors if it can't read the HPACK encoded header fields", func() {
@ -497,16 +480,26 @@ var _ = Describe("Client", func() {
EndHeaders: true, EndHeaders: true,
BlockFragment: []byte("invalid HPACK data"), BlockFragment: []byte("invalid HPACK data"),
}) })
var handlerReturned bool
go func() {
client.handleHeaderStream() client.handleHeaderStream()
handlerReturned = true
}()
Eventually(client.headerErrored).Should(BeClosed()) Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields"))) Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Eventually(func() bool { return handlerReturned }).Should(BeTrue()) Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
})
It("errors if the stream cannot be found", func() {
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1337,
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
Expect(err).ToNot(HaveOccurred())
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
}) })
}) })
}) })

View File

@ -11,6 +11,7 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -29,6 +30,8 @@ type mockStream struct {
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
} }
var _ quic.Stream = &mockStream{}
func newMockStream(id protocol.StreamID) *mockStream { func newMockStream(id protocol.StreamID) *mockStream {
s := &mockStream{ s := &mockStream{
id: id, id: id,
@ -39,7 +42,8 @@ func newMockStream(id protocol.StreamID) *mockStream {
} }
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil } func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
func (s *mockStream) Reset(error) { s.reset = true } func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil }
func (s *mockStream) CancelWrite(quic.ErrorCode) error { panic("not implemented") }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() } func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
func (s mockStream) StreamID() protocol.StreamID { return s.id } func (s mockStream) StreamID() protocol.StreamID { return s.id }
func (s *mockStream) Context() context.Context { return s.ctx } func (s *mockStream) Context() context.Context { return s.ctx }

View File

@ -50,6 +50,7 @@ type Server struct {
listenerMutex sync.Mutex listenerMutex sync.Mutex
listener quic.Listener listener quic.Listener
closed bool
supportedVersionsAsString string supportedVersionsAsString string
} }
@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of h2quic.Server without http.Server")
} }
s.listenerMutex.Lock() s.listenerMutex.Lock()
if s.closed {
s.listenerMutex.Unlock()
return errors.New("Server is already closed")
}
if s.listener != nil { if s.listener != nil {
s.listenerMutex.Unlock() s.listenerMutex.Unlock()
return errors.New("ListenAndServe may only be called once") return errors.New("ListenAndServe may only be called once")
@ -223,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
} }
if responseWriter.dataStream != nil { if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead { if !streamEnded && !reqBody.requestRead {
responseWriter.dataStream.Reset(nil) // in gQUIC, the error code doesn't matter, so just use 0 here
responseWriter.dataStream.CancelRead(0)
} }
responseWriter.dataStream.Close() responseWriter.dataStream.Close()
} }
@ -241,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
func (s *Server) Close() error { func (s *Server) Close() error {
s.listenerMutex.Lock() s.listenerMutex.Lock()
defer s.listenerMutex.Unlock() defer s.listenerMutex.Unlock()
s.closed = true
if s.listener != nil { if s.listener != nil {
err := s.listener.Close() err := s.listener.Close()
s.listener = nil s.listener = nil

View File

@ -70,6 +70,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
func (s *mockSession) Context() context.Context { func (s *mockSession) Context() context.Context {
return s.ctx return s.ctx
} }
func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") }
var _ = Describe("H2 server", func() { var _ = Describe("H2 server", func() {
var ( var (
@ -410,6 +411,13 @@ var _ = Describe("H2 server", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("errors when ListenAndServer is called after Close", func() {
serv := &Server{Server: &http.Server{}}
Expect(serv.Close()).To(Succeed())
err := serv.ListenAndServe()
Expect(err).To(MatchError("Server is already closed"))
})
Context("ListenAndServe", func() { Context("ListenAndServe", func() {
BeforeEach(func() { BeforeEach(func() {
s.Server.Addr = "localhost:0" s.Server.Addr = "localhost:0"

View File

@ -19,20 +19,42 @@ type VersionNumber = protocol.VersionNumber
// A Cookie can be used to verify the ownership of the client address. // A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie type Cookie = handshake.Cookie
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
// An ErrorCode is an application-defined error code.
type ErrorCode = protocol.ApplicationErrorCode
// Stream is the interface implemented by QUIC streams // Stream is the interface implemented by QUIC streams
type Stream interface { type Stream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream. // Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true // Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline. // after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Reader io.Reader
// Write writes data to the stream. // Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true // Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline. // after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Writer io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
io.Closer io.Closer
StreamID() StreamID // CancelWrite aborts sending on this stream.
// Reset closes the stream with an error. // It must not be called after Close.
Reset(error) // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
CancelWrite(ErrorCode) error
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
CancelRead(ErrorCode) error
// The context is canceled as soon as the write-side of the stream is closed. // The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely). // This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon. // Warning: This API should not be considered stable and might change soon.
@ -53,6 +75,41 @@ type Stream interface {
SetDeadline(t time.Time) error SetDeadline(t time.Time) error
} }
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Read
io.Reader
// see Stream.CancelRead
CancelRead(ErrorCode) error
// see Stream.SetReadDealine
SetReadDeadline(t time.Time) error
}
// A SendStream is a unidirectional Send Stream.
type SendStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Write
io.Writer
// see Stream.Close
io.Closer
// see Stream.CancelWrite
CancelWrite(ErrorCode) error
// see Stream.Context
Context() context.Context
// see Stream.SetWriteDeadline
SetWriteDeadline(t time.Time) error
}
// StreamError is returned by Read and Write when the peer cancels the stream.
type StreamError interface {
error
Canceled() bool
ErrorCode() ErrorCode
}
// A Session is a QUIC connection between two peers. // A Session is a QUIC connection between two peers.
type Session interface { type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available. // AcceptStream returns the next stream opened by the peer, blocking until one is available.
@ -74,13 +131,9 @@ type Session interface {
// The context is cancelled when the session is closed. // The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon. // Warning: This API should not be considered stable and might change soon.
Context() context.Context Context() context.Context
} // ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
// A NonFWSession is a QUIC connection between two peers half-way through the handshake. ConnectionState() ConnectionState
// The communication is encrypted, but not yet forward secure.
type NonFWSession interface {
Session
WaitUntilHandshakeComplete() error
} }
// Config contains all configuration data needed for a QUIC server or client. // Config contains all configuration data needed for a QUIC server or client.

View File

@ -18,6 +18,7 @@ type CertManager interface {
GetLeafCertHash() (uint64, error) GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error Verify(hostname string) error
GetChain() []*x509.Certificate
} }
type certManager struct { type certManager struct {
@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error {
return nil return nil
} }
func (c *certManager) GetChain() []*x509.Certificate {
return c.chain
}
func (c *certManager) GetCommonCertificateHashes() []byte { func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes() return getCommonCertificateHashes()
} }

View File

@ -10,35 +10,30 @@ import (
) )
type baseFlowController struct { type baseFlowController struct {
mutex sync.RWMutex // for sending data
rttStats *congestion.RTTStats
bytesSent protocol.ByteCount bytesSent protocol.ByteCount
sendWindow protocol.ByteCount sendWindow protocol.ByteCount
lastWindowUpdateTime time.Time // for receiving data
mutex sync.RWMutex
bytesRead protocol.ByteCount bytesRead protocol.ByteCount
highestReceived protocol.ByteCount highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount receiveWindow protocol.ByteCount
receiveWindowIncrement protocol.ByteCount receiveWindowSize protocol.ByteCount
maxReceiveWindowIncrement protocol.ByteCount maxReceiveWindowSize protocol.ByteCount
epochStartTime time.Time
epochStartOffset protocol.ByteCount
rttStats *congestion.RTTStats
} }
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.bytesSent += n c.bytesSent += n
} }
// UpdateSendWindow should be called after receiving a WindowUpdateFrame // UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated // it returns true if the window was actually updated
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
c.mutex.Lock()
defer c.mutex.Unlock()
if offset > c.sendWindow { if offset > c.sendWindow {
c.sendWindow = offset c.sendWindow = offset
} }
@ -57,52 +52,55 @@ func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
defer c.mutex.Unlock() defer c.mutex.Unlock()
// pretend we sent a WindowUpdate when reading the first byte // pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window increment already works for the first WindowUpdate // this way auto-tuning of the window size already works for the first WindowUpdate
if c.bytesRead == 0 { if c.bytesRead == 0 {
c.lastWindowUpdateTime = time.Now() c.startNewAutoTuningEpoch()
} }
c.bytesRead += n c.bytesRead += n
} }
func (c *baseFlowController) hasWindowUpdate() bool {
bytesRemaining := c.receiveWindow - c.bytesRead
// update the window when more than the threshold was consumed
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
}
// getWindowUpdate updates the receive window, if necessary // getWindowUpdate updates the receive window, if necessary
// it returns the new offset // it returns the new offset
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
diff := c.receiveWindow - c.bytesRead if !c.hasWindowUpdate() {
// update the window when more than half of it was already consumed
if diff >= (c.receiveWindowIncrement / 2) {
return 0 return 0
} }
c.maybeAdjustWindowIncrement() c.maybeAdjustWindowSize()
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement c.receiveWindow = c.bytesRead + c.receiveWindowSize
c.lastWindowUpdateTime = time.Now()
return c.receiveWindow return c.receiveWindow
} }
func (c *baseFlowController) IsBlocked() bool { // maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
c.mutex.RLock() // For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
defer c.mutex.RUnlock() func (c *baseFlowController) maybeAdjustWindowSize() {
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
return c.sendWindowSize() == 0 // don't do anything if less than half the window has been consumed
} if bytesReadInEpoch <= c.receiveWindowSize/2 {
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
func (c *baseFlowController) maybeAdjustWindowIncrement() {
if c.lastWindowUpdateTime.IsZero() {
return return
} }
rtt := c.rttStats.SmoothedRTT() rtt := c.rttStats.SmoothedRTT()
if rtt == 0 { if rtt == 0 {
return return
} }
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime) fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
// interval between the window updates is sufficiently large, no need to increase the increment if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
if timeSinceLastWindowUpdate >= 2*rtt { // window is consumed too fast, try to increase the window size
return c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
} }
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement) c.startNewAutoTuningEpoch()
}
func (c *baseFlowController) startNewAutoTuningEpoch() {
c.epochStartTime = time.Now()
c.epochStartOffset = c.bytesRead
} }
func (c *baseFlowController) checkFlowControlViolation() bool { func (c *baseFlowController) checkFlowControlViolation() bool {

View File

@ -1,6 +1,8 @@
package flowcontrol package flowcontrol
import ( import (
"os"
"strconv"
"time" "time"
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
@ -9,6 +11,16 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
// on the CIs, the timing is a lot less precise, so scale every duration by this factor
func scaleDuration(t time.Duration) time.Duration {
scaleFactor := 1
if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
scaleFactor = f
}
Expect(scaleFactor).ToNot(BeZero())
return time.Duration(scaleFactor) * t
}
var _ = Describe("Base Flow controller", func() { var _ = Describe("Base Flow controller", func() {
var controller *baseFlowController var controller *baseFlowController
@ -49,22 +61,18 @@ var _ = Describe("Base Flow controller", func() {
controller.UpdateSendWindow(10) controller.UpdateSendWindow(10)
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20)))
}) })
It("says when it's blocked", func() {
controller.UpdateSendWindow(100)
Expect(controller.IsBlocked()).To(BeFalse())
controller.AddBytesSent(100)
Expect(controller.IsBlocked()).To(BeTrue())
})
}) })
Context("receive flow control", func() { Context("receive flow control", func() {
var receiveWindow protocol.ByteCount = 10000 var (
var receiveWindowIncrement protocol.ByteCount = 600 receiveWindow protocol.ByteCount = 10000
receiveWindowSize protocol.ByteCount = 1000
)
BeforeEach(func() { BeforeEach(func() {
controller.bytesRead = receiveWindow - receiveWindowSize
controller.receiveWindow = receiveWindow controller.receiveWindow = receiveWindow
controller.receiveWindowIncrement = receiveWindowIncrement controller.receiveWindowSize = receiveWindowSize
}) })
It("adds bytes read", func() { It("adds bytes read", func() {
@ -74,31 +82,30 @@ var _ = Describe("Base Flow controller", func() {
}) })
It("triggers a window update when necessary", func() { It("triggers a window update when necessary", func() {
controller.lastWindowUpdateTime = time.Now().Add(-time.Hour) bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold + 1 // consumed 1 byte more than the threshold
readPosition := receiveWindow - receiveWindowIncrement/2 + 1 bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
readPosition := receiveWindow - bytesRemaining
controller.bytesRead = readPosition controller.bytesRead = readPosition
offset := controller.getWindowUpdate() offset := controller.getWindowUpdate()
Expect(offset).To(Equal(readPosition + receiveWindowIncrement)) Expect(offset).To(Equal(readPosition + receiveWindowSize))
Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowIncrement)) Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowSize))
Expect(controller.lastWindowUpdateTime).To(BeTemporally("~", time.Now(), 20*time.Millisecond))
}) })
It("doesn't trigger a window update when not necessary", func() { It("doesn't trigger a window update when not necessary", func() {
lastWindowUpdateTime := time.Now().Add(-time.Hour) bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold - 1 // consumed 1 byte less than the threshold
controller.lastWindowUpdateTime = lastWindowUpdateTime bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
readPosition := receiveWindow - receiveWindow/2 - 1 readPosition := receiveWindow - bytesRemaining
controller.bytesRead = readPosition controller.bytesRead = readPosition
offset := controller.getWindowUpdate() offset := controller.getWindowUpdate()
Expect(offset).To(BeZero()) Expect(offset).To(BeZero())
Expect(controller.lastWindowUpdateTime).To(Equal(lastWindowUpdateTime))
}) })
Context("receive window increment auto-tuning", func() { Context("receive window size auto-tuning", func() {
var oldIncrement protocol.ByteCount var oldWindowSize protocol.ByteCount
BeforeEach(func() { BeforeEach(func() {
oldIncrement = controller.receiveWindowIncrement oldWindowSize = controller.receiveWindowSize
controller.maxReceiveWindowIncrement = 3000 controller.maxReceiveWindowSize = 5000
}) })
// update the congestion such that it returns a given value for the smoothed RTT // update the congestion such that it returns a given value for the smoothed RTT
@ -107,72 +114,98 @@ var _ = Describe("Base Flow controller", func() {
Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked
} }
It("doesn't increase the increment for a new stream", func() { It("doesn't increase the window size for a new stream", func() {
controller.maybeAdjustWindowIncrement() controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
}) })
It("doesn't increase the increment when no RTT estimate is available", func() { It("doesn't increase the window size when no RTT estimate is available", func() {
setRtt(0) setRtt(0)
controller.lastWindowUpdateTime = time.Now() controller.startNewAutoTuningEpoch()
controller.maybeAdjustWindowIncrement() controller.AddBytesRead(400)
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero()) // make sure a window update is sent
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
}) })
It("increases the increment when the last WindowUpdate was sent less than two RTTs ago", func() { It("increases the window size if read so fast that the window would be consumed in less than 4 RTTs", func() {
setRtt(20 * time.Millisecond) bytesRead := controller.bytesRead
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) rtt := scaleDuration(20 * time.Millisecond)
controller.maybeAdjustWindowIncrement() setRtt(rtt)
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) // consume more than 2/3 of the window...
}) dataRead := receiveWindowSize*2/3 + 1
// ... in 4*2/3 of the RTT
It("doesn't increase the increase increment when the last WindowUpdate was sent more than two RTTs ago", func() { controller.epochStartOffset = controller.bytesRead
setRtt(20 * time.Millisecond) controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
controller.lastWindowUpdateTime = time.Now().Add(-45 * time.Millisecond) controller.AddBytesRead(dataRead)
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
})
It("doesn't increase the increment to a value higher than the maxReceiveWindowIncrement", func() {
setRtt(20 * time.Millisecond)
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) // 1200
// because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowIncrement() multiple times and get an increase of the increment every time
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(2 * 2 * oldIncrement)) // 2400
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(controller.maxReceiveWindowIncrement)) // 3000
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(controller.maxReceiveWindowIncrement)) // 3000
})
It("returns the new increment when updating the window", func() {
setRtt(20 * time.Millisecond)
controller.AddBytesRead(9900) // receive window is 10000
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
offset := controller.getWindowUpdate() offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero()) Expect(offset).ToNot(BeZero())
newIncrement := controller.receiveWindowIncrement // check that the window size was increased
Expect(newIncrement).To(Equal(2 * oldIncrement)) newWindowSize := controller.receiveWindowSize
Expect(offset).To(Equal(protocol.ByteCount(9900 + newIncrement))) Expect(newWindowSize).To(Equal(2 * oldWindowSize))
// check that the new window size was used to increase the offset
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
}) })
It("increases the increment sent in the first WindowUpdate, if data is read fast enough", func() { It("doesn't increase the window size if data is read so fast that the window would be consumed in less than 4 RTTs, but less than half the window has been read", func() {
setRtt(20 * time.Millisecond) // this test only makes sense if a window update is triggered before half of the window has been consumed
controller.AddBytesRead(9900) Expect(protocol.WindowUpdateThreshold).To(BeNumerically(">", 1/3))
bytesRead := controller.bytesRead
rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
// consume more than 2/3 of the window...
dataRead := receiveWindowSize*1/3 + 1
// ... in 4*2/3 of the RTT
controller.epochStartOffset = controller.bytesRead
controller.epochStartTime = time.Now().Add(-rtt * 4 * 1 / 3)
controller.AddBytesRead(dataRead)
offset := controller.getWindowUpdate() offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero()) Expect(offset).ToNot(BeZero())
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) // check that the window size was not increased
newWindowSize := controller.receiveWindowSize
Expect(newWindowSize).To(Equal(oldWindowSize))
// check that the new window size was used to increase the offset
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
}) })
It("doesn't increamse the increment sent in the first WindowUpdate, if data is read slowly", func() { It("doesn't increase the window size if read too slowly", func() {
setRtt(5 * time.Millisecond) bytesRead := controller.bytesRead
controller.AddBytesRead(9900) rtt := scaleDuration(20 * time.Millisecond)
time.Sleep(15 * time.Millisecond) // more than 2x RTT setRtt(rtt)
// consume less than 2/3 of the window...
dataRead := receiveWindowSize*2/3 - 1
// ... in 4*2/3 of the RTT
controller.epochStartOffset = controller.bytesRead
controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
controller.AddBytesRead(dataRead)
offset := controller.getWindowUpdate() offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero()) Expect(offset).ToNot(BeZero())
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) // check that the window size was not increased
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
// check that the new window size was used to increase the offset
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + oldWindowSize)))
})
It("doesn't increase the window size to a value higher than the maxReceiveWindowSize", func() {
resetEpoch := func() {
// make sure the next call to maybeAdjustWindowSize will increase the window
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.epochStartOffset = controller.bytesRead
controller.AddBytesRead(controller.receiveWindowSize/2 + 1)
}
setRtt(scaleDuration(20 * time.Millisecond))
resetEpoch()
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) // 2000
// because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowSize() multiple times and get an increase of the window size every time
resetEpoch()
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(2 * 2 * oldWindowSize)) // 4000
resetEpoch()
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
}) })
}) })
}) })

View File

@ -2,7 +2,6 @@ package flowcontrol
import ( import (
"fmt" "fmt"
"time"
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
@ -11,6 +10,7 @@ import (
) )
type connectionFlowController struct { type connectionFlowController struct {
lastBlockedAt protocol.ByteCount
baseFlowController baseFlowController
} }
@ -27,19 +27,27 @@ func NewConnectionFlowController(
baseFlowController: baseFlowController{ baseFlowController: baseFlowController{
rttStats: rttStats, rttStats: rttStats,
receiveWindow: receiveWindow, receiveWindow: receiveWindow,
receiveWindowIncrement: receiveWindow, receiveWindowSize: receiveWindow,
maxReceiveWindowIncrement: maxReceiveWindow, maxReceiveWindowSize: maxReceiveWindow,
}, },
} }
} }
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
c.mutex.RLock()
defer c.mutex.RUnlock()
return c.baseFlowController.sendWindowSize() return c.baseFlowController.sendWindowSize()
} }
// IsNewlyBlocked says if it is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
// IncrementHighestReceived adds an increment to the highestReceived value // IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
c.mutex.Lock() c.mutex.Lock()
@ -54,24 +62,22 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() oldWindowSize := c.receiveWindowSize
oldWindowIncrement := c.receiveWindowIncrement
offset := c.baseFlowController.getWindowUpdate() offset := c.baseFlowController.getWindowUpdate()
if oldWindowIncrement < c.receiveWindowIncrement { if oldWindowSize < c.receiveWindowSize {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
} }
c.mutex.Unlock()
return offset return offset
} }
// EnsureMinimumWindowIncrement sets a minimum window increment // EnsureMinimumWindowSize sets a minimum window size
// it should make sure that the connection-level window is increased when a stream-level window grows // it should make sure that the connection-level window is increased when a stream-level window grows
func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) { func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() if inc > c.receiveWindowSize {
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
if inc > c.receiveWindowIncrement { c.startNewAutoTuningEpoch()
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
} }
c.mutex.Unlock()
} }

View File

@ -32,12 +32,12 @@ var _ = Describe("Connection Flow controller", func() {
fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, rttStats).(*connectionFlowController) fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, rttStats).(*connectionFlowController)
Expect(fc.receiveWindow).To(Equal(receiveWindow)) Expect(fc.receiveWindow).To(Equal(receiveWindow))
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow)) Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
}) })
}) })
Context("receive flow control", func() { Context("receive flow control", func() {
It("increases the highestReceived by a given increment", func() { It("increases the highestReceived by a given window size", func() {
controller.highestReceived = 1337 controller.highestReceived = 1337
controller.IncrementHighestReceived(123) controller.IncrementHighestReceived(123)
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123))) Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123)))
@ -46,64 +46,96 @@ var _ = Describe("Connection Flow controller", func() {
Context("getting window updates", func() { Context("getting window updates", func() {
BeforeEach(func() { BeforeEach(func() {
controller.receiveWindow = 100 controller.receiveWindow = 100
controller.receiveWindowIncrement = 60 controller.receiveWindowSize = 60
controller.maxReceiveWindowIncrement = 1000 controller.maxReceiveWindowSize = 1000
controller.bytesRead = 100 - 60
}) })
It("gets a window update", func() { It("gets a window update", func() {
controller.AddBytesRead(80) windowSize := controller.receiveWindowSize
oldOffset := controller.bytesRead
dataRead := windowSize/2 - 1 // make sure not to trigger auto-tuning
controller.AddBytesRead(dataRead)
offset := controller.GetWindowUpdate() offset := controller.GetWindowUpdate()
Expect(offset).To(Equal(protocol.ByteCount(80 + 60))) Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + 60)))
}) })
It("autotunes the window", func() { It("autotunes the window", func() {
controller.AddBytesRead(80) oldOffset := controller.bytesRead
setRtt(20 * time.Millisecond) oldWindowSize := controller.receiveWindowSize
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.epochStartOffset = oldOffset
dataRead := oldWindowSize/2 + 1
controller.AddBytesRead(dataRead)
offset := controller.GetWindowUpdate() offset := controller.GetWindowUpdate()
Expect(offset).To(Equal(protocol.ByteCount(80 + 2*60))) newWindowSize := controller.receiveWindowSize
Expect(newWindowSize).To(Equal(2 * oldWindowSize))
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + newWindowSize)))
}) })
}) })
}) })
Context("setting the minimum increment", func() { Context("send flow control", func() {
It("says when it's blocked", func() {
controller.UpdateSendWindow(100)
Expect(controller.IsNewlyBlocked()).To(BeFalse())
controller.AddBytesSent(100)
blocked, offset := controller.IsNewlyBlocked()
Expect(blocked).To(BeTrue())
Expect(offset).To(Equal(protocol.ByteCount(100)))
})
It("doesn't say that it's newly blocked multiple times for the same offset", func() {
controller.UpdateSendWindow(100)
controller.AddBytesSent(100)
newlyBlocked, offset := controller.IsNewlyBlocked()
Expect(newlyBlocked).To(BeTrue())
Expect(offset).To(Equal(protocol.ByteCount(100)))
newlyBlocked, _ = controller.IsNewlyBlocked()
Expect(newlyBlocked).To(BeFalse())
controller.UpdateSendWindow(150)
controller.AddBytesSent(150)
newlyBlocked, offset = controller.IsNewlyBlocked()
Expect(newlyBlocked).To(BeTrue())
})
})
Context("setting the minimum window size", func() {
var ( var (
oldIncrement protocol.ByteCount oldWindowSize protocol.ByteCount
receiveWindow protocol.ByteCount = 10000 receiveWindow protocol.ByteCount = 10000
receiveWindowIncrement protocol.ByteCount = 600 receiveWindowSize protocol.ByteCount = 1000
) )
BeforeEach(func() { BeforeEach(func() {
controller.bytesRead = receiveWindowSize - receiveWindowSize
controller.receiveWindow = receiveWindow controller.receiveWindow = receiveWindow
controller.receiveWindowIncrement = receiveWindowIncrement controller.receiveWindowSize = receiveWindowSize
oldIncrement = controller.receiveWindowIncrement oldWindowSize = controller.receiveWindowSize
controller.maxReceiveWindowIncrement = 3000 controller.maxReceiveWindowSize = 3000
}) })
It("sets the minimum window increment", func() { It("sets the minimum window window size", func() {
controller.EnsureMinimumWindowIncrement(1000) controller.EnsureMinimumWindowSize(1800)
Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(1000))) Expect(controller.receiveWindowSize).To(Equal(protocol.ByteCount(1800)))
}) })
It("doesn't reduce the window increment", func() { It("doesn't reduce the window window size", func() {
controller.EnsureMinimumWindowIncrement(1) controller.EnsureMinimumWindowSize(1)
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
}) })
It("doens't increase the increment beyond the maxReceiveWindowIncrement", func() { It("doens't increase the window size beyond the maxReceiveWindowSize", func() {
max := controller.maxReceiveWindowIncrement max := controller.maxReceiveWindowSize
controller.EnsureMinimumWindowIncrement(2 * max) controller.EnsureMinimumWindowSize(2 * max)
Expect(controller.receiveWindowIncrement).To(Equal(max)) Expect(controller.receiveWindowSize).To(Equal(max))
}) })
It("doesn't auto-tune the window after the increment was increased", func() { It("starts a new epoch after the window size was increased", func() {
setRtt(20 * time.Millisecond) controller.EnsureMinimumWindowSize(1912)
controller.bytesRead = 9900 // receive window is 10000 Expect(controller.epochStartTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
controller.lastWindowUpdateTime = time.Now().Add(-20 * time.Millisecond)
controller.EnsureMinimumWindowIncrement(912)
offset := controller.getWindowUpdate()
Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(912))) // no auto-tuning
Expect(offset).To(Equal(protocol.ByteCount(9900 + 912)))
}) })
}) })
}) })

View File

@ -5,7 +5,6 @@ import "github.com/lucas-clemente/quic-go/internal/protocol"
type flowController interface { type flowController interface {
// for sending // for sending
SendWindowSize() protocol.ByteCount SendWindowSize() protocol.ByteCount
IsBlocked() bool
UpdateSendWindow(protocol.ByteCount) UpdateSendWindow(protocol.ByteCount)
AddBytesSent(protocol.ByteCount) AddBytesSent(protocol.ByteCount)
// for receiving // for receiving
@ -16,22 +15,28 @@ type flowController interface {
// A StreamFlowController is a flow controller for a QUIC stream. // A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface { type StreamFlowController interface {
flowController flowController
// for sending
IsBlocked() (bool, protocol.ByteCount)
// for receiving // for receiving
// UpdateHighestReceived should be called when a new highest offset is received // UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error UpdateHighestReceived(offset protocol.ByteCount, final bool) error
// HasWindowUpdate says if it is necessary to update the window
HasWindowUpdate() bool
} }
// The ConnectionFlowController is the flow controller for the connection. // The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface { type ConnectionFlowController interface {
flowController flowController
// for sending
IsNewlyBlocked() (bool, protocol.ByteCount)
} }
type connectionFlowControllerI interface { type connectionFlowControllerI interface {
ConnectionFlowController ConnectionFlowController
// The following two methods are not supposed to be called from outside this packet, but are needed internally // The following two methods are not supposed to be called from outside this packet, but are needed internally
// for sending // for sending
EnsureMinimumWindowIncrement(protocol.ByteCount) EnsureMinimumWindowSize(protocol.ByteCount)
// for receiving // for receiving
IncrementHighestReceived(protocol.ByteCount) error IncrementHighestReceived(protocol.ByteCount) error
} }

View File

@ -39,8 +39,8 @@ func NewStreamFlowController(
baseFlowController: baseFlowController{ baseFlowController: baseFlowController{
rttStats: rttStats, rttStats: rttStats,
receiveWindow: receiveWindow, receiveWindow: receiveWindow,
receiveWindowIncrement: receiveWindow, receiveWindowSize: receiveWindow,
maxReceiveWindowIncrement: maxReceiveWindow, maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow, sendWindow: initialSendWindow,
}, },
} }
@ -102,9 +102,6 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
} }
func (c *streamFlowController) SendWindowSize() protocol.ByteCount { func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
c.mutex.Lock()
defer c.mutex.Unlock()
window := c.baseFlowController.sendWindowSize() window := c.baseFlowController.sendWindowSize()
if c.contributesToConnection { if c.contributesToConnection {
window = utils.MinByteCount(window, c.connection.SendWindowSize()) window = utils.MinByteCount(window, c.connection.SendWindowSize())
@ -112,22 +109,39 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
return window return window
} }
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { // IsBlocked says if it is blocked by stream-level flow control.
c.mutex.Lock() // If it is blocked, the offset is returned.
defer c.mutex.Unlock() func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 {
return false, 0
}
return true, c.sendWindow
}
func (c *streamFlowController) HasWindowUpdate() bool {
c.mutex.Lock()
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
c.mutex.Unlock()
return hasWindowUpdate
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
c.mutex.Lock()
// if we already received the final offset for this stream, the peer won't need any additional flow control credit // if we already received the final offset for this stream, the peer won't need any additional flow control credit
if c.receivedFinalOffset { if c.receivedFinalOffset {
c.mutex.Unlock()
return 0 return 0
} }
oldWindowIncrement := c.receiveWindowIncrement oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate() offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
if c.contributesToConnection { if c.contributesToConnection {
c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier)) c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
} }
} }
c.mutex.Unlock()
return offset return offset
} }

View File

@ -19,7 +19,7 @@ var _ = Describe("Stream Flow controller", func() {
streamID: 10, streamID: 10,
connection: NewConnectionFlowController(1000, 1000, rttStats).(*connectionFlowController), connection: NewConnectionFlowController(1000, 1000, rttStats).(*connectionFlowController),
} }
controller.maxReceiveWindowIncrement = 10000 controller.maxReceiveWindowSize = 10000
controller.rttStats = rttStats controller.rttStats = rttStats
}) })
@ -35,7 +35,7 @@ var _ = Describe("Stream Flow controller", func() {
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats).(*streamFlowController) fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats).(*streamFlowController)
Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
Expect(fc.receiveWindow).To(Equal(receiveWindow)) Expect(fc.receiveWindow).To(Equal(receiveWindow))
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow)) Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
Expect(fc.sendWindow).To(Equal(sendWindow)) Expect(fc.sendWindow).To(Equal(sendWindow))
Expect(fc.contributesToConnection).To(BeTrue()) Expect(fc.contributesToConnection).To(BeTrue())
}) })
@ -44,11 +44,11 @@ var _ = Describe("Stream Flow controller", func() {
Context("receiving data", func() { Context("receiving data", func() {
Context("registering received offsets", func() { Context("registering received offsets", func() {
var receiveWindow protocol.ByteCount = 10000 var receiveWindow protocol.ByteCount = 10000
var receiveWindowIncrement protocol.ByteCount = 600 var receiveWindowSize protocol.ByteCount = 600
BeforeEach(func() { BeforeEach(func() {
controller.receiveWindow = receiveWindow controller.receiveWindow = receiveWindow
controller.receiveWindowIncrement = receiveWindowIncrement controller.receiveWindowSize = receiveWindowSize
}) })
It("updates the highestReceived", func() { It("updates the highestReceived", func() {
@ -157,7 +157,7 @@ var _ = Describe("Stream Flow controller", func() {
}) })
Context("generating window updates", func() { Context("generating window updates", func() {
var oldIncrement protocol.ByteCount var oldWindowSize protocol.ByteCount
// update the congestion such that it returns a given value for the smoothed RTT // update the congestion such that it returns a given value for the smoothed RTT
setRtt := func(t time.Duration) { setRtt := func(t time.Duration) {
@ -167,37 +167,51 @@ var _ = Describe("Stream Flow controller", func() {
BeforeEach(func() { BeforeEach(func() {
controller.receiveWindow = 100 controller.receiveWindow = 100
controller.receiveWindowIncrement = 60 controller.receiveWindowSize = 60
controller.connection.(*connectionFlowController).receiveWindowIncrement = 120 controller.bytesRead = 100 - 60
oldIncrement = controller.receiveWindowIncrement controller.connection.(*connectionFlowController).receiveWindowSize = 120
oldWindowSize = controller.receiveWindowSize
})
It("tells if it has window updates", func() {
Expect(controller.HasWindowUpdate()).To(BeFalse())
controller.AddBytesRead(30)
Expect(controller.HasWindowUpdate()).To(BeTrue())
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
Expect(controller.HasWindowUpdate()).To(BeFalse())
}) })
It("tells the connection flow controller when the window was autotuned", func() { It("tells the connection flow controller when the window was autotuned", func() {
oldOffset := controller.bytesRead
controller.contributesToConnection = true controller.contributesToConnection = true
controller.AddBytesRead(75) setRtt(scaleDuration(20 * time.Millisecond))
setRtt(20 * time.Millisecond) controller.epochStartOffset = oldOffset
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.AddBytesRead(55)
offset := controller.GetWindowUpdate() offset := controller.GetWindowUpdate()
Expect(offset).To(Equal(protocol.ByteCount(75 + 2*60))) Expect(offset).To(Equal(protocol.ByteCount(oldOffset + 55 + 2*oldWindowSize)))
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(float64(controller.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier))) Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)))
}) })
It("doesn't tell the connection flow controller if it doesn't contribute", func() { It("doesn't tell the connection flow controller if it doesn't contribute", func() {
oldOffset := controller.bytesRead
controller.contributesToConnection = false controller.contributesToConnection = false
controller.AddBytesRead(75) setRtt(scaleDuration(20 * time.Millisecond))
setRtt(20 * time.Millisecond) controller.epochStartOffset = oldOffset
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.AddBytesRead(55)
offset := controller.GetWindowUpdate() offset := controller.GetWindowUpdate()
Expect(offset).ToNot(BeZero()) Expect(offset).ToNot(BeZero())
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(120))) // unchanged Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(2 * oldWindowSize))) // unchanged
}) })
It("doesn't increase the window after a final offset was already received", func() { It("doesn't increase the window after a final offset was already received", func() {
controller.AddBytesRead(80) controller.AddBytesRead(30)
err := controller.UpdateHighestReceived(90, true) err := controller.UpdateHighestReceived(90, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(controller.HasWindowUpdate()).To(BeFalse())
offset := controller.GetWindowUpdate() offset := controller.GetWindowUpdate()
Expect(offset).To(BeZero()) Expect(offset).To(BeZero())
}) })
@ -231,7 +245,8 @@ var _ = Describe("Stream Flow controller", func() {
controller.connection.UpdateSendWindow(50) controller.connection.UpdateSendWindow(50)
controller.UpdateSendWindow(100) controller.UpdateSendWindow(100)
controller.AddBytesSent(50) controller.AddBytesSent(50)
Expect(controller.connection.IsBlocked()).To(BeTrue()) blocked, _ := controller.connection.IsNewlyBlocked()
Expect(blocked).To(BeTrue())
Expect(controller.IsBlocked()).To(BeFalse()) Expect(controller.IsBlocked()).To(BeFalse())
}) })
}) })

View File

@ -52,7 +52,7 @@ type cryptoSetupClient struct {
forwardSecureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel handshakeEvent chan<- struct{}
params *TransportParameters params *TransportParameters
} }
@ -74,7 +74,7 @@ func NewCryptoSetupClient(
tlsConfig *tls.Config, tlsConfig *tls.Config,
params *TransportParameters, params *TransportParameters,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
@ -93,7 +93,7 @@ func NewCryptoSetupClient(
keyExchange: getEphermalKEX, keyExchange: getEphermalKEX,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
paramsChan: paramsChan, paramsChan: paramsChan,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
initialVersion: initialVersion, initialVersion: initialVersion,
negotiatedVersions: negotiatedVersions, negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte), divNonceChan: make(chan []byte),
@ -159,8 +159,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
} }
// blocks until the session has received the parameters // blocks until the session has received the parameters
h.paramsChan <- *params h.paramsChan <- *params
h.aeadChanged <- protocol.EncryptionForwardSecure h.handshakeEvent <- struct{}{}
close(h.aeadChanged) close(h.handshakeEvent)
default: default:
return qerr.InvalidCryptoMessageType return qerr.InvalidCryptoMessageType
} }
@ -381,6 +381,15 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
h.divNonceChan <- data h.divNonceChan <- data
} }
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
HandshakeComplete: h.forwardSecureAEAD != nil,
PeerCertificates: h.certManager.GetChain(),
}
}
func (h *cryptoSetupClient) sendCHLO() error { func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++ h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos { if h.clientHelloCounter > protocol.MaxClientHellos {
@ -496,10 +505,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if err != nil { if err != nil {
return err return err
} }
h.handshakeEvent <- struct{}{}
h.aeadChanged <- protocol.EncryptionSecure
} }
return nil return nil
} }

View File

@ -2,6 +2,7 @@ package handshake
import ( import (
"bytes" "bytes"
"crypto/x509"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -10,6 +11,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/crypto" "github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -34,6 +36,8 @@ type mockCertManager struct {
commonCertificateHashes []byte commonCertificateHashes []byte
chain []*x509.Certificate
leafCert []byte leafCert []byte
leafCertHash uint64 leafCertHash uint64
leafCertHashError error leafCertHashError error
@ -45,6 +49,8 @@ type mockCertManager struct {
verifyCalled bool verifyCalled bool
} }
var _ crypto.CertManager = &mockCertManager{}
func (m *mockCertManager) SetData(data []byte) error { func (m *mockCertManager) SetData(data []byte) error {
m.setDataCalledWith = data m.setDataCalledWith = data
return m.setDataError return m.setDataError
@ -72,6 +78,10 @@ func (m *mockCertManager) Verify(hostname string) error {
return m.verifyError return m.verifyError
} }
func (m *mockCertManager) GetChain() []*x509.Certificate {
return m.chain
}
var _ = Describe("Client Crypto Setup", func() { var _ = Describe("Client Crypto Setup", func() {
var ( var (
cs *cryptoSetupClient cs *cryptoSetupClient
@ -79,7 +89,7 @@ var _ = Describe("Client Crypto Setup", func() {
stream *mockStream stream *mockStream
keyDerivationCalledWith *keyDerivationValues keyDerivationCalledWith *keyDerivationValues
shloMap map[Tag][]byte shloMap map[Tag][]byte
aeadChanged chan protocol.EncryptionLevel handshakeEvent chan struct{}
paramsChan chan TransportParameters paramsChan chan TransportParameters
) )
@ -108,7 +118,7 @@ var _ = Describe("Client Crypto Setup", func() {
version := protocol.Version39 version := protocol.Version39
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking // use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1) paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2) handshakeEvent = make(chan struct{}, 2)
csInt, err := NewCryptoSetupClient( csInt, err := NewCryptoSetupClient(
stream, stream,
"hostname", "hostname",
@ -117,7 +127,7 @@ var _ = Describe("Client Crypto Setup", func() {
nil, nil,
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
paramsChan, paramsChan,
aeadChanged, handshakeEvent,
protocol.Version39, protocol.Version39,
nil, nil,
) )
@ -130,10 +140,6 @@ var _ = Describe("Client Crypto Setup", func() {
cs.cryptoStream = stream cs.cryptoStream = stream
}) })
AfterEach(func() {
close(stream.unblockRead)
})
Context("Reading REJ", func() { Context("Reading REJ", func() {
var tagMap map[Tag][]byte var tagMap map[Tag][]byte
@ -158,8 +164,17 @@ var _ = Describe("Client Crypto Setup", func() {
stk := []byte("foobar") stk := []byte("foobar")
tagMap[TagSTK] = stk tagMap[TagSTK] = stk
HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead)
go cs.HandleCryptoStream() done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
Eventually(func() []byte { return cs.stk }).Should(Equal(stk)) Eventually(func() []byte { return cs.stk }).Should(Equal(stk))
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
}) })
It("saves the proof", func() { It("saves the proof", func() {
@ -380,22 +395,22 @@ var _ = Describe("Client Crypto Setup", func() {
cs.receivedSecurePacket = false cs.receivedSecurePacket = false
_, err := cs.handleSHLOMessage(shloMap) _, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")))
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("rejects SHLOs without a PUBS", func() { It("rejects SHLOs without a PUBS", func() {
delete(shloMap, TagPUBS) delete(shloMap, TagPUBS)
_, err := cs.handleSHLOMessage(shloMap) _, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")))
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("rejects SHLOs without a version list", func() { It("rejects SHLOs without a version list", func() {
delete(shloMap, TagVER) delete(shloMap, TagVER)
_, err := cs.handleSHLOMessage(shloMap) _, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")))
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("accepts a SHLO after a version negotiation", func() { It("accepts a SHLO after a version negotiation", func() {
@ -430,28 +445,38 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(params.IdleTimeout).To(Equal(13 * time.Second)) Expect(params.IdleTimeout).To(Equal(13 * time.Second))
}) })
It("closes the aeadChanged when receiving an SHLO", func() { It("closes the handshakeEvent chan when receiving an SHLO", func() {
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}() }()
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure))) Eventually(handshakeEvent).Should(Receive())
Eventually(aeadChanged).Should(BeClosed()) Eventually(handshakeEvent).Should(BeClosed())
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
}) })
It("passes the transport parameters on the channel", func() { It("passes the transport parameters on the channel", func() {
shloMap[TagSFCW] = []byte{0x0d, 0x00, 0xdf, 0xba} shloMap[TagSFCW] = []byte{0x0d, 0x00, 0xdf, 0xba}
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}() }()
var params TransportParameters var params TransportParameters
Eventually(paramsChan).Should(Receive(&params)) Eventually(paramsChan).Should(Receive(&params))
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xbadf000d))) Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xbadf000d)))
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
}) })
It("errors if it can't read a connection parameter", func() { It("errors if it can't read a connection parameter", func() {
@ -637,9 +662,9 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert)) Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert))
Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce)) Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce))
Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient)) Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("uses the server nonce, if the server sent one", func() { It("uses the server nonce, if the server sent one", func() {
@ -649,51 +674,64 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.secureAEAD).ToNot(BeNil())
Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...))) Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...)))
Expect(aeadChanged).To(Receive()) Expect(handshakeEvent).To(Receive())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() { It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() {
err := cs.maybeUpgradeCrypto() err := cs.maybeUpgradeCrypto()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).To(BeNil()) Expect(cs.secureAEAD).To(BeNil())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
cs.serverVerified = true cs.serverVerified = true
// make sure we really had all necessary values before, and only serverVerified was missing // make sure we really had all necessary values before, and only serverVerified was missing
err = cs.maybeUpgradeCrypto() err = cs.maybeUpgradeCrypto()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("tries to escalate before reading a handshake message", func() { It("tries to escalate before reading a handshake message", func() {
Expect(cs.secureAEAD).To(BeNil()) Expect(cs.secureAEAD).To(BeNil())
cs.serverVerified = true cs.serverVerified = true
go cs.HandleCryptoStream() done := make(chan struct{})
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
})
It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Fail("HandleCryptoStream should not have returned") Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
Eventually(handshakeEvent).Should(Receive())
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(handshakeEvent).ToNot(Receive())
Expect(handshakeEvent).ToNot(BeClosed())
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
})
It("tries to escalate the crypto after receiving a diversification nonce", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}() }()
cs.diversificationNonce = nil cs.diversificationNonce = nil
cs.serverVerified = true cs.serverVerified = true
Expect(cs.secureAEAD).To(BeNil()) Expect(cs.secureAEAD).To(BeNil())
cs.SetDiversificationNonce([]byte("div")) cs.SetDiversificationNonce([]byte("div"))
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) Eventually(handshakeEvent).Should(Receive())
Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
close(done) // make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
}) })
Context("null encryption", func() { Context("null encryption", func() {
@ -813,6 +851,22 @@ var _ = Describe("Client Crypto Setup", func() {
}) })
}) })
Context("reporting the connection state", func() {
It("reports the connection state before the handshake completes", func() {
chain := []*x509.Certificate{testdata.GetCertificate().Leaf}
certManager.chain = chain
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.PeerCertificates).To(Equal(chain))
})
It("reports the connection state after the handshake completes", func() {
doSHLO()
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeTrue())
})
})
Context("forcing encryption levels", func() { Context("forcing encryption levels", func() {
It("forces null encryption", func() { It("forces null encryption", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(4), []byte{}).Return([]byte("foobar unencrypted")) cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(4), []byte{}).Return([]byte("foobar unencrypted"))
@ -862,32 +916,51 @@ var _ = Describe("Client Crypto Setup", func() {
Context("Diversification Nonces", func() { Context("Diversification Nonces", func() {
It("sets a diversification nonce", func() { It("sets a diversification nonce", func() {
go cs.HandleCryptoStream() done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
nonce := []byte("foobar") nonce := []byte("foobar")
cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce)
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce)) Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
}) })
It("doesn't do anything when called multiple times with the same nonce", func(done Done) { It("doesn't do anything when called multiple times with the same nonce", func() {
go cs.HandleCryptoStream() done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
nonce := []byte("foobar") nonce := []byte("foobar")
cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce)
cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce)
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce)) Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
close(done) // make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
}) })
It("rejects a different diversification nonce", func() { It("rejects a different diversification nonce", func() {
var err error done := make(chan struct{})
go func() { go func() {
err = cs.HandleCryptoStream() defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(errConflictingDiversificationNonces))
close(done)
}() }()
nonce1 := []byte("foobar") nonce1 := []byte("foobar")
nonce2 := []byte("raboof") nonce2 := []byte("raboof")
cs.SetDiversificationNonce(nonce1) cs.SetDiversificationNonce(nonce1)
cs.SetDiversificationNonce(nonce2) cs.SetDiversificationNonce(nonce2)
Eventually(func() error { return err }).Should(MatchError(errConflictingDiversificationNonces)) Eventually(done).Should(BeClosed())
}) })
}) })

View File

@ -23,6 +23,8 @@ type KeyExchangeFunction func() crypto.KeyExchange
// The CryptoSetupServer handles all things crypto for the Session // The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct { type cryptoSetupServer struct {
mutex sync.RWMutex
connID protocol.ConnectionID connID protocol.ConnectionID
remoteAddr net.Addr remoteAddr net.Addr
scfg *ServerConfig scfg *ServerConfig
@ -42,7 +44,7 @@ type cryptoSetupServer struct {
receivedParams bool receivedParams bool
paramsChan chan<- TransportParameters paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel handshakeEvent chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction keyExchange KeyExchangeFunction
@ -51,7 +53,7 @@ type cryptoSetupServer struct {
params *TransportParameters params *TransportParameters
mutex sync.RWMutex sni string // need to fill out the ConnectionState
} }
var _ CryptoSetup = &cryptoSetupServer{} var _ CryptoSetup = &cryptoSetupServer{}
@ -76,7 +78,7 @@ func NewCryptoSetup(
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool, acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
) (CryptoSetup, error) { ) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil { if err != nil {
@ -96,7 +98,7 @@ func NewCryptoSetup(
acceptSTKCallback: acceptSTK, acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}), sentSHLO: make(chan struct{}),
paramsChan: paramsChan, paramsChan: paramsChan,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
}, nil }, nil
} }
@ -139,6 +141,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
if sni == "" { if sni == "" {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
} }
h.sni = sni
// prevent version downgrade attacks // prevent version downgrade attacks
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples // see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
@ -182,7 +185,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
if _, err := h.cryptoStream.Write(reply); err != nil { if _, err := h.cryptoStream.Write(reply); err != nil {
return false, err return false, err
} }
h.aeadChanged <- protocol.EncryptionForwardSecure h.handshakeEvent <- struct{}{}
close(h.sentSHLO) close(h.sentSHLO)
return true, nil return true, nil
} }
@ -206,9 +209,9 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
if err == nil { if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.receivedForwardSecurePacket = true h.receivedForwardSecurePacket = true
// wait until protocol.EncryptionForwardSecure was sent on the aeadChan // wait for the send on the handshakeEvent chan
<-h.sentSHLO <-h.sentSHLO
close(h.aeadChanged) close(h.handshakeEvent)
} }
return res, protocol.EncryptionForwardSecure, nil return res, protocol.EncryptionForwardSecure, nil
} }
@ -396,8 +399,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
if err != nil { if err != nil {
return nil, err return nil, err
} }
h.handshakeEvent <- struct{}{}
h.aeadChanged <- protocol.EncryptionSecure
// Generate a new curve instance to derive the forward secure key // Generate a new curve instance to derive the forward secure key
var fsNonce bytes.Buffer var fsNonce bytes.Buffer
@ -454,6 +456,15 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer") panic("not needed for cryptoSetupServer")
} }
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
ServerName: h.sni,
HandshakeComplete: h.receivedForwardSecurePacket,
}
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 { if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io"
"net" "net"
"time" "time"
@ -63,35 +64,36 @@ func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte
} }
type mockStream struct { type mockStream struct {
unblockRead chan struct{} // close this chan to unblock Read unblockRead chan struct{}
dataToRead bytes.Buffer dataToRead bytes.Buffer
dataWritten bytes.Buffer dataWritten bytes.Buffer
} }
var _ io.ReadWriter = &mockStream{}
var errMockStreamClosing = errors.New("mock stream closing")
func newMockStream() *mockStream { func newMockStream() *mockStream {
return &mockStream{unblockRead: make(chan struct{})} return &mockStream{unblockRead: make(chan struct{})}
} }
// call Close to make Read return
func (s *mockStream) Read(p []byte) (int, error) { func (s *mockStream) Read(p []byte) (int, error) {
n, _ := s.dataToRead.Read(p) n, _ := s.dataToRead.Read(p)
if n == 0 { // block if there's no data if n == 0 { // block if there's no data
<-s.unblockRead <-s.unblockRead
return 0, errMockStreamClosing
} }
return n, nil // never return an EOF return n, nil // never return an EOF
} }
func (s *mockStream) ReadByte() (byte, error) {
return s.dataToRead.ReadByte()
}
func (s *mockStream) Write(p []byte) (int, error) { func (s *mockStream) Write(p []byte) (int, error) {
return s.dataWritten.Write(p) return s.dataWritten.Write(p)
} }
func (s *mockStream) Close() error { panic("not implemented") } func (s *mockStream) close() {
func (s *mockStream) Reset(error) { panic("not implemented") } close(s.unblockRead)
func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") } }
func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") }
type mockCookieProtector struct { type mockCookieProtector struct {
data []byte data []byte
@ -122,7 +124,7 @@ var _ = Describe("Server Crypto Setup", func() {
cs *cryptoSetupServer cs *cryptoSetupServer
stream *mockStream stream *mockStream
paramsChan chan TransportParameters paramsChan chan TransportParameters
aeadChanged chan protocol.EncryptionLevel handshakeEvent chan struct{}
nonce32 []byte nonce32 []byte
versionTag []byte versionTag []byte
validSTK []byte validSTK []byte
@ -144,7 +146,7 @@ var _ = Describe("Server Crypto Setup", func() {
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking // use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1) paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2) handshakeEvent = make(chan struct{}, 2)
stream = newMockStream() stream = newMockStream()
kex = &mockKEX{} kex = &mockKEX{}
signer = &mockSigner{} signer = &mockSigner{}
@ -168,7 +170,7 @@ var _ = Describe("Server Crypto Setup", func() {
supportedVersions, supportedVersions,
nil, nil,
paramsChan, paramsChan,
aeadChanged, handshakeEvent,
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer) cs = csInt.(*cryptoSetupServer)
@ -183,10 +185,6 @@ var _ = Describe("Server Crypto Setup", func() {
cs.cryptoStream = stream cs.cryptoStream = stream
}) })
AfterEach(func() {
close(stream.unblockRead)
})
Context("diversification nonce", func() { Context("diversification nonce", func() {
BeforeEach(func() { BeforeEach(func() {
cs.secureAEAD = mockcrypto.NewMockAEAD(mockCtrl) cs.secureAEAD = mockcrypto.NewMockAEAD(mockCtrl)
@ -345,10 +343,10 @@ var _ = Describe("Server Crypto Setup", func() {
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO")) Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("rejects client nonces that have the wrong length", func() { It("rejects client nonces that have the wrong length", func() {
@ -379,9 +377,9 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("recognizes inchoate CHLOs missing SCID", func() { It("recognizes inchoate CHLOs missing SCID", func() {
@ -537,7 +535,7 @@ var _ = Describe("Server Crypto Setup", func() {
TagKEXS: kexs, TagKEXS: kexs,
}) })
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to secure
close(cs.sentSHLO) close(cs.sentSHLO)
} }
@ -659,7 +657,26 @@ var _ = Describe("Server Crypto Setup", func() {
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{}) cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{}) _, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(BeClosed()) Expect(handshakeEvent).To(BeClosed())
})
})
Context("reporting the connection state", func() {
It("reports before the handshake completes", func() {
cs.sni = "server name"
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.ServerName).To(Equal("server name"))
})
It("reports after the handshake completes", func() {
doCHLO()
// receive a forward secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
Expect(err).ToNot(HaveOccurred())
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeTrue())
}) })
}) })
@ -723,6 +740,7 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse()) Expect(done).To(BeFalse())
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK))) Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
Expect(cs.sni).To(Equal("foo"))
}) })
It("works with proper STK", func() { It("works with proper STK", func() {

View File

@ -28,7 +28,7 @@ type cryptoSetupTLS struct {
tls MintTLS tls MintTLS
cryptoStream *CryptoStreamConn cryptoStream *CryptoStreamConn
aeadChanged chan<- protocol.EncryptionLevel handshakeEvent chan<- struct{}
} }
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
@ -36,7 +36,7 @@ func NewCryptoSetupTLSServer(
tls MintTLS, tls MintTLS,
cryptoStream *CryptoStreamConn, cryptoStream *CryptoStreamConn,
nullAEAD crypto.AEAD, nullAEAD crypto.AEAD,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
version protocol.VersionNumber, version protocol.VersionNumber,
) CryptoSetup { ) CryptoSetup {
return &cryptoSetupTLS{ return &cryptoSetupTLS{
@ -45,7 +45,7 @@ func NewCryptoSetupTLSServer(
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
} }
} }
@ -54,7 +54,7 @@ func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter, cryptoStream io.ReadWriter,
connID protocol.ConnectionID, connID protocol.ConnectionID,
hostname string, hostname string,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
tls MintTLS, tls MintTLS,
version protocol.VersionNumber, version protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
@ -68,7 +68,7 @@ func NewCryptoSetupTLSClient(
tls: tls, tls: tls,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
}, nil }, nil
} }
@ -102,9 +102,8 @@ handshakeLoop:
h.aead = aead h.aead = aead
h.mutex.Unlock() h.mutex.Unlock()
// signal to the outside world that the handshake completed h.handshakeEvent <- struct{}{}
h.aeadChanged <- protocol.EncryptionForwardSecure close(h.handshakeEvent)
close(h.aeadChanged)
return nil return nil
} }
@ -165,3 +164,14 @@ func (h *cryptoSetupTLS) DiversificationNonce() []byte {
func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) { func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) {
panic("diversification nonce not needed for TLS") panic("diversification nonce not needed for TLS")
} }
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
mintConnState := h.tls.ConnectionState()
return ConnectionState{
// TODO: set the ServerName, once mint exports it
HandshakeComplete: h.aead != nil,
PeerCertificates: mintConnState.PeerCertificates,
}
}

View File

@ -21,16 +21,16 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
var _ = Describe("TLS Crypto Setup", func() { var _ = Describe("TLS Crypto Setup", func() {
var ( var (
cs *cryptoSetupTLS cs *cryptoSetupTLS
aeadChanged chan protocol.EncryptionLevel handshakeEvent chan struct{}
) )
BeforeEach(func() { BeforeEach(func() {
aeadChanged = make(chan protocol.EncryptionLevel, 2) handshakeEvent = make(chan struct{}, 2)
cs = NewCryptoSetupTLSServer( cs = NewCryptoSetupTLSServer(
nil, nil,
NewCryptoStreamConn(nil), NewCryptoStreamConn(nil),
nil, // AEAD nil, // AEAD
aeadChanged, handshakeEvent,
protocol.VersionTLS, protocol.VersionTLS,
).(*cryptoSetupTLS) ).(*cryptoSetupTLS)
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl) cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
@ -51,8 +51,8 @@ var _ = Describe("TLS Crypto Setup", func() {
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) Expect(handshakeEvent).To(Receive())
Expect(aeadChanged).To(BeClosed()) Expect(handshakeEvent).To(BeClosed())
}) })
It("handshakes until it is connected", func() { It("handshakes until it is connected", func() {
@ -63,7 +63,30 @@ var _ = Describe("TLS Crypto Setup", func() {
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive()) Expect(handshakeEvent).To(Receive())
})
Context("reporting the handshake state", func() {
It("reports before the handshake compeletes", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.PeerCertificates).To(BeNil())
})
It("reports after the handshake completes", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeTrue())
Expect(state.PeerCertificates).To(BeNil())
})
}) })
Context("escalating crypto", func() { Context("escalating crypto", func() {
@ -181,16 +204,16 @@ var _ = Describe("TLS Crypto Setup", func() {
var _ = Describe("TLS Crypto Setup, for the client", func() { var _ = Describe("TLS Crypto Setup, for the client", func() {
var ( var (
cs *cryptoSetupTLS cs *cryptoSetupTLS
aeadChanged chan protocol.EncryptionLevel handshakeEvent chan struct{}
) )
BeforeEach(func() { BeforeEach(func() {
aeadChanged = make(chan protocol.EncryptionLevel, 2) handshakeEvent = make(chan struct{})
csInt, err := NewCryptoSetupTLSClient( csInt, err := NewCryptoSetupTLSClient(
nil, nil,
0, 0,
"quic.clemente.io", "quic.clemente.io",
aeadChanged, handshakeEvent,
nil, // mintTLS nil, // mintTLS
protocol.VersionTLS, protocol.VersionTLS,
) )

View File

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"crypto/x509"
"io" "io"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
@ -29,6 +30,7 @@ type MintTLS interface {
// additional methods // additional methods
Handshake() mint.Alert Handshake() mint.Alert
State() mint.State State() mint.State
ConnectionState() mint.ConnectionState
SetCryptoStream(io.ReadWriter) SetCryptoStream(io.ReadWriter)
SetExtensionHandler(mint.AppExtensionHandler) error SetExtensionHandler(mint.AppExtensionHandler) error
@ -41,8 +43,17 @@ type CryptoSetup interface {
// TODO: clean up this interface // TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
ConnectionState() ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
} }
// ConnectionState records basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
type ConnectionState struct {
HandshakeComplete bool // handshake is complete
ServerName string // server name requested by client, if any (server side only)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
}

View File

@ -24,6 +24,7 @@ type extensionHandlerClient struct {
var _ mint.AppExtensionHandler = &extensionHandlerClient{} var _ mint.AppExtensionHandler = &extensionHandlerClient{}
var _ TLSExtensionHandler = &extensionHandlerClient{} var _ TLSExtensionHandler = &extensionHandlerClient{}
// NewExtensionHandlerClient creates a new extension handler for the client.
func NewExtensionHandlerClient( func NewExtensionHandlerClient(
params *TransportParameters, params *TransportParameters,
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
@ -57,7 +58,10 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{} ext := &tlsExtensionBody{}
found, _ := el.Find(ext) found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket { if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket {
if found { if found {

View File

@ -39,7 +39,8 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(el).To(HaveLen(1)) Expect(el).To(HaveLen(1))
ext := &tlsExtensionBody{} ext := &tlsExtensionBody{}
found := el.Find(ext) found, err := el.Find(ext)
Expect(err).ToNot(HaveOccurred())
Expect(found).To(BeTrue()) Expect(found).To(BeTrue())
chtp := &clientHelloTransportParameters{} chtp := &clientHelloTransportParameters{}
_, err = syntax.Unmarshal(ext.data, chtp) _, err = syntax.Unmarshal(ext.data, chtp)

View File

@ -24,6 +24,7 @@ type extensionHandlerServer struct {
var _ mint.AppExtensionHandler = &extensionHandlerServer{} var _ mint.AppExtensionHandler = &extensionHandlerServer{}
var _ TLSExtensionHandler = &extensionHandlerServer{} var _ TLSExtensionHandler = &extensionHandlerServer{}
// NewExtensionHandlerServer creates a new extension handler for the server
func NewExtensionHandlerServer( func NewExtensionHandlerServer(
params *TransportParameters, params *TransportParameters,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
@ -66,7 +67,10 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{} ext := &tlsExtensionBody{}
found, _ := el.Find(ext) found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeClientHello { if hType != mint.HandshakeTypeClientHello {
if found { if found {

View File

@ -48,7 +48,8 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(el).To(HaveLen(1)) Expect(el).To(HaveLen(1))
ext := &tlsExtensionBody{} ext := &tlsExtensionBody{}
found := el.Find(ext) found, err := el.Find(ext)
Expect(err).ToNot(HaveOccurred())
Expect(found).To(BeTrue()) Expect(found).To(BeTrue())
eetp := &encryptedExtensionsTransportParameters{} eetp := &encryptedExtensionsTransportParameters{}
_, err = syntax.Unmarshal(ext.data, eetp) _, err = syntax.Unmarshal(ext.data, eetp)

View File

@ -64,6 +64,9 @@ type ByteCount uint64
// MaxByteCount is the maximum value of a ByteCount // MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = ByteCount(1<<62 - 1) const MaxByteCount = ByteCount(1<<62 - 1)
// An ApplicationErrorCode is an application-defined error code.
type ApplicationErrorCode uint16
// MaxReceivePacketSize maximum packet size of any QUIC packet, based on // MaxReceivePacketSize maximum packet size of any QUIC packet, based on
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. // UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.

View File

@ -56,6 +56,9 @@ const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15
// This is the value that Chromium is using // This is the value that Chromium is using
const ConnectionFlowControlMultiplier = 1.5 const ConnectionFlowControlMultiplier = 1.5
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
const WindowUpdateThreshold = 0.25
// MaxIncomingStreams is the maximum number of streams that a peer may open // MaxIncomingStreams is the maximum number of streams that a peer may open
const MaxIncomingStreams = 100 const MaxIncomingStreams = 100
@ -122,3 +125,9 @@ const ClosedSessionDeleteTimeout = time.Minute
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space // NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
const NumCachedCertificates = 128 const NumCachedCertificates = 128
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
// 1. it reduces the framing overhead
// 2. it reduces the head-of-line blocking, when a packet is lost
const MinStreamFrameSize ByteCount = 128

View File

@ -139,7 +139,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
} }
// MinLength of a written frame // MinLength of a written frame
func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *AckFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() { if !version.UsesIETFFrameFormat() {
return f.minLengthLegacy(version) return f.minLengthLegacy(version)
} }
@ -157,7 +157,7 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount
length += utils.VarIntLen(uint64(f.LargestAcked - lowestInFirstRange)) length += utils.VarIntLen(uint64(f.LargestAcked - lowestInFirstRange))
if !f.HasMissingRanges() { if !f.HasMissingRanges() {
return length, nil return length
} }
var lowest protocol.PacketNumber var lowest protocol.PacketNumber
for i, ackRange := range f.AckRanges { for i, ackRange := range f.AckRanges {
@ -169,7 +169,7 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount
length += utils.VarIntLen(uint64(ackRange.Last - ackRange.First)) length += utils.VarIntLen(uint64(ackRange.Last - ackRange.First))
lowest = ackRange.First lowest = ackRange.First
} }
return length, nil return length
} }
// HasMissingRanges returns if this frame reports any missing packets // HasMissingRanges returns if this frame reports any missing packets

View File

@ -308,7 +308,7 @@ func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error
return nil return nil
} }
func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount, error) { func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount {
length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp
length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked)) length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked))
@ -320,7 +320,7 @@ func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount
length += missingSequenceNumberDeltaLen length += missingSequenceNumberDeltaLen
} }
// we don't write // we don't write
return length, nil return length
} }
// numWritableNackRanges calculates the number of ACK blocks that are about to be written // numWritableNackRanges calculates the number of ACK blocks that are about to be written

View File

@ -4,17 +4,26 @@ import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
// A BlockedFrame is a BLOCKED frame // A BlockedFrame is a BLOCKED frame
type BlockedFrame struct{} type BlockedFrame struct {
Offset protocol.ByteCount
}
// ParseBlockedFrame parses a BLOCKED frame // ParseBlockedFrame parses a BLOCKED frame
func ParseBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*BlockedFrame, error) { func ParseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) {
if _, err := r.ReadByte(); err != nil { if _, err := r.ReadByte(); err != nil {
return nil, err return nil, err
} }
return &BlockedFrame{}, nil offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &BlockedFrame{
Offset: protocol.ByteCount(offset),
}, nil
} }
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
@ -23,13 +32,14 @@ func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er
} }
typeByte := uint8(0x08) typeByte := uint8(0x08)
b.WriteByte(typeByte) b.WriteByte(typeByte)
utils.WriteVarInt(b, uint64(f.Offset))
return nil return nil
} }
// MinLength of a written frame // MinLength of a written frame
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *BlockedFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() { // writing this frame would result in a legacy BLOCKED being written, which is longer if !version.UsesIETFFrameFormat() {
return 1 + 4, nil return 1 + 4
} }
return 1, nil return 1 + utils.VarIntLen(uint64(f.Offset))
} }

View File

@ -2,8 +2,10 @@ package wire
import ( import (
"bytes" "bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -12,30 +14,41 @@ import (
var _ = Describe("BLOCKED frame", func() { var _ = Describe("BLOCKED frame", func() {
Context("when parsing", func() { Context("when parsing", func() {
It("accepts sample frame", func() { It("accepts sample frame", func() {
b := bytes.NewReader([]byte{0x08}) data := []byte{0x08}
_, err := ParseBlockedFrame(b, protocol.VersionWhatever) data = append(data, encodeVarInt(0x12345678)...)
b := bytes.NewReader(data)
frame, err := ParseBlockedFrame(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame.Offset).To(Equal(protocol.ByteCount(0x12345678)))
Expect(b.Len()).To(BeZero()) Expect(b.Len()).To(BeZero())
}) })
It("errors on EOFs", func() { It("errors on EOFs", func() {
_, err := ParseBlockedFrame(bytes.NewReader(nil), protocol.VersionWhatever) data := []byte{0x08}
Expect(err).To(HaveOccurred()) data = append(data, encodeVarInt(0x12345678)...)
_, err := ParseBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
for i := range data {
_, err := ParseBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
}
}) })
}) })
Context("when writing", func() { Context("when writing", func() {
It("writes a sample frame", func() { It("writes a sample frame", func() {
b := &bytes.Buffer{} b := &bytes.Buffer{}
frame := BlockedFrame{} frame := BlockedFrame{Offset: 0xdeadbeef}
err := frame.Write(b, protocol.VersionWhatever) err := frame.Write(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(b.Bytes()).To(Equal([]byte{0x08})) expected := []byte{0x08}
expected = append(expected, encodeVarInt(0xdeadbeef)...)
Expect(b.Bytes()).To(Equal(expected))
}) })
It("has the correct min length", func() { It("has the correct min length", func() {
frame := BlockedFrame{} frame := BlockedFrame{Offset: 0x12345}
Expect(frame.MinLength(versionIETFFrames)).To(Equal(protocol.ByteCount(1))) Expect(frame.MinLength(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0x12345)))
}) })
}) })
}) })

View File

@ -68,11 +68,11 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber)
} }
// MinLength of a written frame // MinLength of a written frame
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if version.UsesIETFFrameFormat() { if version.UsesIETFFrameFormat() {
return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)), nil return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
} }
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase))
} }
// Write writes an CONNECTION_CLOSE frame. // Write writes an CONNECTION_CLOSE frame.

View File

@ -9,5 +9,5 @@ import (
// A Frame in QUIC // A Frame in QUIC
type Frame interface { type Frame interface {
Write(b *bytes.Buffer, version protocol.VersionNumber) error Write(b *bytes.Buffer, version protocol.VersionNumber) error
MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) MinLength(version protocol.VersionNumber) protocol.ByteCount
} }

View File

@ -63,6 +63,6 @@ func (f *GoawayFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
} }
// MinLength of a written frame // MinLength of a written frame
func (f *GoawayFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *GoawayFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)), nil return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase))
} }

View File

@ -43,9 +43,9 @@ func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er
} }
// MinLength of a written frame // MinLength of a written frame
func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer if !version.UsesIETFFrameFormat() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer
return 1 + 4 + 8, nil return 1 + 4 + 8
} }
return 1 + utils.VarIntLen(uint64(f.ByteOffset)), nil return 1 + utils.VarIntLen(uint64(f.ByteOffset))
} }

View File

@ -51,10 +51,10 @@ func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumb
} }
// MinLength of a written frame // MinLength of a written frame
func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
// writing this frame would result in a gQUIC WINDOW_UPDATE being written, which has a different length // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which has a different length
if !version.UsesIETFFrameFormat() { if !version.UsesIETFFrameFormat() {
return 1 + 4 + 8, nil return 1 + 4 + 8
} }
return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset)), nil return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset))
} }

View File

@ -0,0 +1,37 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A MaxStreamIDFrame is a MAX_STREAM_ID frame
type MaxStreamIDFrame struct {
StreamID protocol.StreamID
}
// ParseMaxStreamIDFrame parses a MAX_STREAM_ID frame
func ParseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) {
// read the Type byte
if _, err := r.ReadByte(); err != nil {
return nil, err
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil
}
func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x6)
utils.WriteVarInt(b, uint64(f.StreamID))
return nil
}
// MinLength of a written frame
func (f *MaxStreamIDFrame) MinLength(protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID))
}

View File

@ -0,0 +1,51 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("MAX_STREAM_ID frame", func() {
Context("parsing", func() {
It("accepts sample frame", func() {
data := []byte{0x6}
data = append(data, encodeVarInt(0xdecafbad)...)
b := bytes.NewReader(data)
f, err := ParseMaxStreamIDFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x06}
data = append(data, encodeVarInt(0xdeadbeefcafe13)...)
_, err := ParseMaxStreamIDFrame(bytes.NewReader(data), protocol.VersionWhatever)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := ParseMaxStreamIDFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever)
Expect(err).To(HaveOccurred())
}
})
})
Context("writing", func() {
It("writes a sample frame", func() {
b := &bytes.Buffer{}
frame := MaxStreamIDFrame{StreamID: 0x12345678}
frame.Write(b, protocol.VersionWhatever)
expected := []byte{0x6}
expected = append(expected, encodeVarInt(0x12345678)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := MaxStreamIDFrame{StreamID: 0x1337}
Expect(frame.MinLength(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337)))
})
})
})

View File

@ -28,6 +28,6 @@ func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
} }
// MinLength of a written frame // MinLength of a written frame
func (f *PingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *PingFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
return 1, nil return 1
} }

View File

@ -7,10 +7,12 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
// A RstStreamFrame in QUIC // A RstStreamFrame is a RST_STREAM frame in QUIC
type RstStreamFrame struct { type RstStreamFrame struct {
StreamID protocol.StreamID StreamID protocol.StreamID
ErrorCode uint32 // The error code is a uint32 in gQUIC, but a uint16 in IETF QUIC.
// protocol.ApplicaitonErrorCode is a uint16, so larger values in gQUIC frames will be truncated.
ErrorCode protocol.ApplicationErrorCode
ByteOffset protocol.ByteCount ByteOffset protocol.ByteCount
} }
@ -21,7 +23,7 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS
} }
var streamID protocol.StreamID var streamID protocol.StreamID
var errorCode uint32 var errorCode uint16
var byteOffset protocol.ByteCount var byteOffset protocol.ByteCount
if version.UsesIETFFrameFormat() { if version.UsesIETFFrameFormat() {
sid, err := utils.ReadVarInt(r) sid, err := utils.ReadVarInt(r)
@ -29,11 +31,10 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS
return nil, err return nil, err
} }
streamID = protocol.StreamID(sid) streamID = protocol.StreamID(sid)
ec, err := utils.BigEndian.ReadUint16(r) errorCode, err = utils.BigEndian.ReadUint16(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
errorCode = uint32(ec)
bo, err := utils.ReadVarInt(r) bo, err := utils.ReadVarInt(r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -54,12 +55,12 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS
if err != nil { if err != nil {
return nil, err return nil, err
} }
errorCode = uint32(ec) errorCode = uint16(ec)
} }
return &RstStreamFrame{ return &RstStreamFrame{
StreamID: streamID, StreamID: streamID,
ErrorCode: errorCode, ErrorCode: protocol.ApplicationErrorCode(errorCode),
ByteOffset: byteOffset, ByteOffset: byteOffset,
}, nil }, nil
} }
@ -74,15 +75,15 @@ func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber)
} else { } else {
utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) utils.BigEndian.WriteUint32(b, uint32(f.StreamID))
utils.BigEndian.WriteUint64(b, uint64(f.ByteOffset)) utils.BigEndian.WriteUint64(b, uint64(f.ByteOffset))
utils.BigEndian.WriteUint32(b, f.ErrorCode) utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode))
} }
return nil return nil
} }
// MinLength of a written frame // MinLength of a written frame
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if version.UsesIETFFrameFormat() { if version.UsesIETFFrameFormat() {
return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset)), nil return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset))
} }
return 1 + 4 + 8 + 4, nil return 1 + 4 + 8 + 4
} }

View File

@ -22,7 +22,7 @@ var _ = Describe("RST_STREAM frame", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x987654321))) Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x987654321)))
Expect(frame.ErrorCode).To(Equal(uint32(0x1337))) Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0x1337)))
}) })
It("errors on EOFs", func() { It("errors on EOFs", func() {
@ -44,13 +44,13 @@ var _ = Describe("RST_STREAM frame", func() {
b := bytes.NewReader([]byte{0x1, b := bytes.NewReader([]byte{0x1,
0xde, 0xad, 0xbe, 0xef, // stream id 0xde, 0xad, 0xbe, 0xef, // stream id
0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, // byte offset 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, // byte offset
0x34, 0x12, 0x37, 0x13, // error code 0x0, 0x0, 0xca, 0xfe, // error code
}) })
frame, err := ParseRstStreamFrame(b, versionBigEndian) frame, err := ParseRstStreamFrame(b, versionBigEndian)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x8877665544332211))) Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x8877665544332211)))
Expect(frame.ErrorCode).To(Equal(uint32(0x34123713))) Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0xcafe)))
}) })
It("errors on EOFs", func() { It("errors on EOFs", func() {
@ -103,7 +103,7 @@ var _ = Describe("RST_STREAM frame", func() {
frame := RstStreamFrame{ frame := RstStreamFrame{
StreamID: 0x1337, StreamID: 0x1337,
ByteOffset: 0x11223344decafbad, ByteOffset: 0x11223344decafbad,
ErrorCode: 0xdeadbeef, ErrorCode: 0xcafe,
} }
b := &bytes.Buffer{} b := &bytes.Buffer{}
err := frame.Write(b, versionBigEndian) err := frame.Write(b, versionBigEndian)
@ -111,7 +111,7 @@ var _ = Describe("RST_STREAM frame", func() {
Expect(b.Bytes()).To(Equal([]byte{0x01, Expect(b.Bytes()).To(Equal([]byte{0x01,
0x0, 0x0, 0x13, 0x37, // stream id 0x0, 0x0, 0x13, 0x37, // stream id
0x11, 0x22, 0x33, 0x44, 0xde, 0xca, 0xfb, 0xad, // byte offset 0x11, 0x22, 0x33, 0x44, 0xde, 0xca, 0xfb, 0xad, // byte offset
0xde, 0xad, 0xbe, 0xef, // error code 0x0, 0x0, 0xca, 0xfe, // error code
})) }))
}) })

View File

@ -0,0 +1,47 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A StopSendingFrame is a STOP_SENDING frame
type StopSendingFrame struct {
StreamID protocol.StreamID
ErrorCode protocol.ApplicationErrorCode
}
// ParseStopSendingFrame parses a STOP_SENDING frame
func ParseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
errorCode, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return nil, err
}
return &StopSendingFrame{
StreamID: protocol.StreamID(streamID),
ErrorCode: protocol.ApplicationErrorCode(errorCode),
}, nil
}
// MinLength of a written frame
func (f *StopSendingFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2
}
func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x0c)
utils.WriteVarInt(b, uint64(f.StreamID))
utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode))
return nil
}

View File

@ -0,0 +1,63 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STOP_SENDING frame", func() {
Context("when parsing", func() {
It("parses a sample frame", func() {
data := []byte{0x0c}
data = append(data, encodeVarInt(0xdecafbad)...) // stream ID
data = append(data, []byte{0x13, 0x37}...) // error code
b := bytes.NewReader(data)
frame, err := ParseStopSendingFrame(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0x1337)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x0c}
data = append(data, encodeVarInt(0xdecafbad)...) // stream ID
data = append(data, []byte{0x13, 0x37}...) // error code
_, err := ParseStopSendingFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := ParseStopSendingFrame(bytes.NewReader(data[:i]), versionIETFFrames)
Expect(err).To(HaveOccurred())
}
})
})
Context("when writing", func() {
It("writes", func() {
frame := &StopSendingFrame{
StreamID: 0xdeadbeefcafe,
ErrorCode: 0x10,
}
buf := &bytes.Buffer{}
err := frame.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x0c}
expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
expected = append(expected, []byte{0x0, 0x10}...)
Expect(buf.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := &StopSendingFrame{
StreamID: 0xdeadbeef,
ErrorCode: 0x10,
}
Expect(frame.MinLength(versionIETFFrames)).To(Equal(1 + 2 + utils.VarIntLen(0xdeadbeef)))
})
})
})

View File

@ -22,7 +22,10 @@ var (
errPacketNumberLenNotSet = errors.New("StopWaitingFrame: PacketNumberLen not set") errPacketNumberLenNotSet = errors.New("StopWaitingFrame: PacketNumberLen not set")
) )
func (f *StopWaitingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { func (f *StopWaitingFrame) Write(b *bytes.Buffer, v protocol.VersionNumber) error {
if v.UsesIETFFrameFormat() {
return errors.New("STOP_WAITING not defined in IETF QUIC")
}
// make sure the PacketNumber was set // make sure the PacketNumber was set
if f.PacketNumber == protocol.PacketNumber(0) { if f.PacketNumber == protocol.PacketNumber(0) {
return errPacketNumberNotSet return errPacketNumberNotSet
@ -49,14 +52,8 @@ func (f *StopWaitingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) erro
} }
// MinLength of a written frame // MinLength of a written frame
func (f *StopWaitingFrame) MinLength(_ protocol.VersionNumber) (protocol.ByteCount, error) { func (f *StopWaitingFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount {
minLength := protocol.ByteCount(1) // typeByte return 1 + protocol.ByteCount(f.PacketNumberLen)
if f.PacketNumberLen == protocol.PacketNumberLenInvalid {
return 0, errPacketNumberLenNotSet
}
minLength += protocol.ByteCount(f.PacketNumberLen)
return minLength, nil
} }
// ParseStopWaitingFrame parses a StopWaiting frame // ParseStopWaitingFrame parses a StopWaiting frame

View File

@ -84,7 +84,7 @@ var _ = Describe("StopWaitingFrame", func() {
LeastUnacked: 10, LeastUnacked: 10,
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
err := frame.Write(b, protocol.VersionWhatever) err := frame.Write(b, versionBigEndian)
Expect(err).To(MatchError(errPacketNumberNotSet)) Expect(err).To(MatchError(errPacketNumberNotSet))
}) })
@ -94,7 +94,7 @@ var _ = Describe("StopWaitingFrame", func() {
LeastUnacked: 10, LeastUnacked: 10,
PacketNumber: 13, PacketNumber: 13,
} }
err := frame.Write(b, protocol.VersionWhatever) err := frame.Write(b, versionBigEndian)
Expect(err).To(MatchError(errPacketNumberLenNotSet)) Expect(err).To(MatchError(errPacketNumberLenNotSet))
}) })
@ -105,10 +105,21 @@ var _ = Describe("StopWaitingFrame", func() {
PacketNumber: 5, PacketNumber: 5,
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
err := frame.Write(b, protocol.VersionWhatever) err := frame.Write(b, versionBigEndian)
Expect(err).To(MatchError(errLeastUnackedHigherThanPacketNumber)) Expect(err).To(MatchError(errLeastUnackedHigherThanPacketNumber))
}) })
It("refuses to write for IETF QUIC", func() {
b := &bytes.Buffer{}
frame := &StopWaitingFrame{
LeastUnacked: 10,
PacketNumber: 13,
PacketNumberLen: protocol.PacketNumberLen6,
}
err := frame.Write(b, versionIETFFrames)
Expect(err).To(MatchError("STOP_WAITING not defined in IETF QUIC"))
})
Context("LeastUnackedDelta length", func() { Context("LeastUnackedDelta length", func() {
Context("in big endian", func() { Context("in big endian", func() {
It("writes a 1-byte LeastUnackedDelta", func() { It("writes a 1-byte LeastUnackedDelta", func() {
@ -176,18 +187,10 @@ var _ = Describe("StopWaitingFrame", func() {
Expect(frame.MinLength(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(length + 1))) Expect(frame.MinLength(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(length + 1)))
} }
}) })
It("errors when packetNumberLen is not set", func() {
frame := &StopWaitingFrame{
LeastUnacked: 10,
}
_, err := frame.MinLength(0)
Expect(err).To(MatchError(errPacketNumberLenNotSet))
})
}) })
Context("self consistency", func() { Context("self consistency", func() {
It("reads a stop waiting frame that it wrote", func() { It("reads a STOP_WAITING frame that it wrote", func() {
packetNumber := protocol.PacketNumber(13) packetNumber := protocol.PacketNumber(13)
frame := &StopWaitingFrame{ frame := &StopWaitingFrame{
LeastUnacked: 10, LeastUnacked: 10,
@ -195,9 +198,9 @@ var _ = Describe("StopWaitingFrame", func() {
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
} }
b := &bytes.Buffer{} b := &bytes.Buffer{}
err := frame.Write(b, protocol.VersionWhatever) err := frame.Write(b, versionBigEndian)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, protocol.VersionWhatever) readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, versionBigEndian)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(readframe.LeastUnacked).To(Equal(frame.LeastUnacked)) Expect(readframe.LeastUnacked).To(Equal(frame.LeastUnacked))
}) })

View File

@ -10,10 +10,11 @@ import (
// A StreamBlockedFrame in QUIC // A StreamBlockedFrame in QUIC
type StreamBlockedFrame struct { type StreamBlockedFrame struct {
StreamID protocol.StreamID StreamID protocol.StreamID
Offset protocol.ByteCount
} }
// ParseStreamBlockedFrame parses a STREAM_BLOCKED frame // ParseStreamBlockedFrame parses a STREAM_BLOCKED frame
func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamBlockedFrame, error) { func ParseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err return nil, err
} }
@ -21,7 +22,14 @@ func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &StreamBlockedFrame{StreamID: protocol.StreamID(sid)}, nil offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &StreamBlockedFrame{
StreamID: protocol.StreamID(sid),
Offset: protocol.ByteCount(offset),
}, nil
} }
// Write writes a STREAM_BLOCKED frame // Write writes a STREAM_BLOCKED frame
@ -31,13 +39,14 @@ func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumb
} }
b.WriteByte(0x09) b.WriteByte(0x09)
utils.WriteVarInt(b, uint64(f.StreamID)) utils.WriteVarInt(b, uint64(f.StreamID))
utils.WriteVarInt(b, uint64(f.Offset))
return nil return nil
} }
// MinLength of a written frame // MinLength of a written frame
func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() { if !version.UsesIETFFrameFormat() {
return 1 + 4, nil return 1 + 4
} }
return 1 + utils.VarIntLen(uint64(f.StreamID)), nil return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.Offset))
} }

View File

@ -14,17 +14,20 @@ var _ = Describe("STREAM_BLOCKED frame", func() {
Context("parsing", func() { Context("parsing", func() {
It("accepts sample frame", func() { It("accepts sample frame", func() {
data := []byte{0x9} data := []byte{0x9}
data = append(data, encodeVarInt(0xdeadbeef)...) data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID
data = append(data, encodeVarInt(0xdecafbad)...) // offset
b := bytes.NewReader(data) b := bytes.NewReader(data)
frame, err := ParseStreamBlockedFrame(b, versionIETFFrames) frame, err := ParseStreamBlockedFrame(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad)))
Expect(b.Len()).To(BeZero()) Expect(b.Len()).To(BeZero())
}) })
It("errors on EOFs", func() { It("errors on EOFs", func() {
data := []byte{0x9} data := []byte{0x9}
data = append(data, encodeVarInt(0xdeadbeef)...) data = append(data, encodeVarInt(0xdeadbeef)...)
data = append(data, encodeVarInt(0xc0010ff)...)
_, err := ParseStreamBlockedFrame(bytes.NewReader(data), versionIETFFrames) _, err := ParseStreamBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
for i := range data { for i := range data {
@ -38,19 +41,22 @@ var _ = Describe("STREAM_BLOCKED frame", func() {
It("has proper min length", func() { It("has proper min length", func() {
f := &StreamBlockedFrame{ f := &StreamBlockedFrame{
StreamID: 0x1337, StreamID: 0x1337,
Offset: 0xdeadbeef,
} }
Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337))) Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337) + utils.VarIntLen(0xdeadbeef)))
}) })
It("writes a sample frame", func() { It("writes a sample frame", func() {
b := &bytes.Buffer{} b := &bytes.Buffer{}
f := &StreamBlockedFrame{ f := &StreamBlockedFrame{
StreamID: 0xdecafbad, StreamID: 0xdecafbad,
Offset: 0x1337,
} }
err := f.Write(b, versionIETFFrames) err := f.Write(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
expected := []byte{0x9} expected := []byte{0x9}
expected = append(expected, encodeVarInt(uint64(f.StreamID))...) expected = append(expected, encodeVarInt(uint64(f.StreamID))...)
expected = append(expected, encodeVarInt(uint64(f.Offset))...)
Expect(b.Bytes()).To(Equal(expected)) Expect(b.Bytes()).To(Equal(expected))
}) })
}) })

View File

@ -117,7 +117,7 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err
// MinLength returns the length of the header of a StreamFrame // MinLength returns the length of the header of a StreamFrame
// the total length of the frame is frame.MinLength() + frame.DataLen() // the total length of the frame is frame.MinLength() + frame.DataLen()
func (f *StreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *StreamFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() { if !version.UsesIETFFrameFormat() {
return f.minLengthLegacy(version) return f.minLengthLegacy(version)
} }
@ -128,5 +128,5 @@ func (f *StreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCo
if f.DataLenPresent { if f.DataLenPresent {
length += utils.VarIntLen(uint64(f.DataLen())) length += utils.VarIntLen(uint64(f.DataLen()))
} }
return length, nil return length
} }

View File

@ -183,12 +183,12 @@ func (f *StreamFrame) getOffsetLength() protocol.ByteCount {
return 8 return 8
} }
func (f *StreamFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount, error) { func (f *StreamFrame) minLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount {
length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength() length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength()
if f.DataLenPresent { if f.DataLenPresent {
length += 2 length += 2
} }
return length, nil return length
} }
// DataLen gives the length of data in bytes // DataLen gives the length of data in bytes

View File

@ -210,7 +210,7 @@ var _ = Describe("STREAM frame (for gQUIC)", func() {
} }
err := f.Write(b, versionBigEndian) err := f.Write(b, versionBigEndian)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
minLength, _ := f.MinLength(0) minLength := f.MinLength(0)
Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0x20))) Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0x20)))
Expect(b.Bytes()[minLength-2 : minLength]).To(Equal([]byte{0x13, 0x37})) Expect(b.Bytes()[minLength-2 : minLength]).To(Equal([]byte{0x13, 0x37}))
}) })
@ -229,9 +229,9 @@ var _ = Describe("STREAM frame (for gQUIC)", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0))) Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0)))
Expect(b.Bytes()[1 : b.Len()-dataLen]).ToNot(ContainSubstring(string([]byte{0x37, 0x13}))) Expect(b.Bytes()[1 : b.Len()-dataLen]).ToNot(ContainSubstring(string([]byte{0x37, 0x13})))
minLength, _ := f.MinLength(versionBigEndian) minLength := f.MinLength(versionBigEndian)
f.DataLenPresent = true f.DataLenPresent = true
minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian) minLengthWithoutDataLen := f.MinLength(versionBigEndian)
Expect(minLength).To(Equal(minLengthWithoutDataLen - 2)) Expect(minLength).To(Equal(minLengthWithoutDataLen - 2))
}) })
@ -242,7 +242,7 @@ var _ = Describe("STREAM frame (for gQUIC)", func() {
DataLenPresent: false, DataLenPresent: false,
Offset: 0xdeadbeef, Offset: 0xdeadbeef,
} }
minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian) minLengthWithoutDataLen := f.MinLength(versionBigEndian)
f.DataLenPresent = true f.DataLenPresent = true
Expect(f.MinLength(versionBigEndian)).To(Equal(minLengthWithoutDataLen + 2)) Expect(f.MinLength(versionBigEndian)).To(Equal(minLengthWithoutDataLen + 2))
}) })

View File

@ -0,0 +1,37 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A StreamIDBlockedFrame is a STREAM_ID_BLOCKED frame
type StreamIDBlockedFrame struct {
StreamID protocol.StreamID
}
// ParseStreamIDBlockedFrame parses a STREAM_ID_BLOCKED frame
func ParseStreamIDBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamIDBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &StreamIDBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil
}
func (f *StreamIDBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
typeByte := uint8(0x0a)
b.WriteByte(typeByte)
utils.WriteVarInt(b, uint64(f.StreamID))
return nil
}
// MinLength of a written frame
func (f *StreamIDBlockedFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID))
}

View File

@ -0,0 +1,53 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STREAM_ID_BLOCKED frame", func() {
Context("parsing", func() {
It("accepts sample frame", func() {
expected := []byte{0xa}
expected = append(expected, encodeVarInt(0xdecafbad)...)
b := bytes.NewReader(expected)
frame, err := ParseStreamIDBlockedFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0xa}
data = append(data, encodeVarInt(0x12345678)...)
_, err := ParseStreamIDBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
for i := range data {
_, err := ParseStreamIDBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("writing", func() {
It("writes a sample frame", func() {
b := &bytes.Buffer{}
frame := StreamIDBlockedFrame{StreamID: 0xdeadbeefcafe}
err := frame.Write(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0xa}
expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := StreamIDBlockedFrame{StreamID: 0x123456}
Expect(frame.MinLength(0)).To(Equal(protocol.ByteCount(1) + utils.VarIntLen(0x123456)))
})
})
})

View File

@ -56,6 +56,10 @@ func (mc *mintController) State() mint.State {
return mc.conn.State().HandshakeState return mc.conn.State().HandshakeState
} }
func (mc *mintController) ConnectionState() mint.ConnectionState {
return mc.conn.State()
}
func (mc *mintController) SetCryptoStream(stream io.ReadWriter) { func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
mc.csc.SetStream(stream) mc.csc.SetStream(stream)
} }
@ -73,6 +77,7 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
}, },
} }
if tlsConf != nil { if tlsConf != nil {
mconf.ServerName = tlsConf.ServerName
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
for i, certChain := range tlsConf.Certificates { for i, certChain := range tlsConf.Certificates {
mconf.Certificates[i] = &mint.Certificate{ mconf.Certificates[i] = &mint.Certificate{
@ -87,6 +92,13 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
mconf.Certificates[i].Chain[j] = c mconf.Certificates[i].Chain[j] = c
} }
} }
switch tlsConf.ClientAuth {
case tls.NoClientCert:
case tls.RequireAnyClientCert:
mconf.RequireClientAuth = true
default:
return nil, errors.New("mint currently only support ClientAuthType RequireAnyClientCert")
}
} }
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil { if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
return nil, err return nil, err
@ -128,14 +140,14 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio
// packUnencryptedPacket provides a low-overhead way to pack a packet. // packUnencryptedPacket provides a low-overhead way to pack a packet.
// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. // It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) { func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) {
raw := getPacketBuffer() raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw) buffer := bytes.NewBuffer(raw)
if err := hdr.Write(buffer, pers, hdr.Version); err != nil { if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
return nil, err return nil, err
} }
payloadStartIndex := buffer.Len() payloadStartIndex := buffer.Len()
if err := sf.Write(buffer, hdr.Version); err != nil { if err := f.Write(buffer, hdr.Version); err != nil {
return nil, err return nil, err
} }
raw = raw[0:buffer.Len()] raw = raw[0:buffer.Len()]
@ -144,7 +156,7 @@ func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFr
if utils.Debug() { if utils.Debug() {
utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted)
hdr.Log() hdr.Log()
wire.LogFrame(sf, true) wire.LogFrame(f, true)
} }
return raw, nil return raw, nil
} }

View File

@ -2,9 +2,11 @@ package quic
import ( import (
"bytes" "bytes"
"crypto/tls"
"github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -33,6 +35,45 @@ var _ = Describe("Packing and unpacking Initial packets", func() {
hdr.Raw = buf.Bytes() hdr.Raw = buf.Bytes()
}) })
Context("generating a mint.Config", func() {
It("sets non-blocking mode", func() {
mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.NonBlocking).To(BeTrue())
})
It("sets the server name", func() {
conf := &tls.Config{ServerName: "www.example.com"}
mintConf, err := tlsToMintConfig(conf, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.ServerName).To(Equal("www.example.com"))
})
It("sets the certificate chain", func() {
tlsConf := testdata.GetTLSConfig()
mintConf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.Certificates).ToNot(BeEmpty())
Expect(mintConf.Certificates).To(HaveLen(len(tlsConf.Certificates)))
})
It("requires client authentication", func() {
mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.RequireClientAuth).To(BeFalse())
conf := &tls.Config{ClientAuth: tls.RequireAnyClientCert}
mintConf, err = tlsToMintConfig(conf, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.RequireClientAuth).To(BeTrue())
})
It("rejects unsupported client auth types", func() {
conf := &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert}
_, err := tlsToMintConfig(conf, protocol.PerspectiveClient)
Expect(err).To(MatchError("mint currently only support ClientAuthType RequireAnyClientCert"))
})
})
Context("unpacking", func() { Context("unpacking", func() {
packPacket := func(frames []wire.Frame) []byte { packPacket := func(frames []wire.Frame) []byte {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}

View File

@ -0,0 +1,141 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: CryptoStream)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockCryptoStream is a mock of CryptoStream interface
type MockCryptoStream struct {
ctrl *gomock.Controller
recorder *MockCryptoStreamMockRecorder
}
// MockCryptoStreamMockRecorder is the mock recorder for MockCryptoStream
type MockCryptoStreamMockRecorder struct {
mock *MockCryptoStream
}
// NewMockCryptoStream creates a new mock instance
func NewMockCryptoStream(ctrl *gomock.Controller) *MockCryptoStream {
mock := &MockCryptoStream{ctrl: ctrl}
mock.recorder = &MockCryptoStreamMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder {
return m.recorder
}
// Read mocks base method
func (m *MockCryptoStream) Read(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockCryptoStreamMockRecorder) Read(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockCryptoStream)(nil).Read), arg0)
}
// StreamID mocks base method
func (m *MockCryptoStream) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockCryptoStreamMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockCryptoStream)(nil).StreamID))
}
// Write mocks base method
func (m *MockCryptoStream) Write(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockCryptoStreamMockRecorder) Write(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0)
}
// closeForShutdown mocks base method
func (m *MockCryptoStream) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockCryptoStreamMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockCryptoStream)(nil).closeForShutdown), arg0)
}
// getWindowUpdate mocks base method
func (m *MockCryptoStream) getWindowUpdate() protocol.ByteCount {
ret := m.ctrl.Call(m, "getWindowUpdate")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// getWindowUpdate indicates an expected call of getWindowUpdate
func (mr *MockCryptoStreamMockRecorder) getWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockCryptoStream)(nil).getWindowUpdate))
}
// handleMaxStreamDataFrame mocks base method
func (m *MockCryptoStream) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
}
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
func (mr *MockCryptoStreamMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleMaxStreamDataFrame), arg0)
}
// handleStreamFrame mocks base method
func (m *MockCryptoStream) handleStreamFrame(arg0 *wire.StreamFrame) error {
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleStreamFrame indicates an expected call of handleStreamFrame
func (mr *MockCryptoStreamMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0)
}
// popStreamFrame mocks base method
func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// popStreamFrame indicates an expected call of popStreamFrame
func (mr *MockCryptoStreamMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).popStreamFrame), arg0)
}
// setReadOffset mocks base method
func (m *MockCryptoStream) setReadOffset(arg0 protocol.ByteCount) {
m.ctrl.Call(m, "setReadOffset", arg0)
}
// setReadOffset indicates an expected call of setReadOffset
func (mr *MockCryptoStreamMockRecorder) setReadOffset(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setReadOffset", reflect.TypeOf((*MockCryptoStream)(nil).setReadOffset), arg0)
}

View File

@ -0,0 +1,132 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: ReceiveStreamI)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockReceiveStreamI is a mock of ReceiveStreamI interface
type MockReceiveStreamI struct {
ctrl *gomock.Controller
recorder *MockReceiveStreamIMockRecorder
}
// MockReceiveStreamIMockRecorder is the mock recorder for MockReceiveStreamI
type MockReceiveStreamIMockRecorder struct {
mock *MockReceiveStreamI
}
// NewMockReceiveStreamI creates a new mock instance
func NewMockReceiveStreamI(ctrl *gomock.Controller) *MockReceiveStreamI {
mock := &MockReceiveStreamI{ctrl: ctrl}
mock.recorder = &MockReceiveStreamIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockReceiveStreamI) EXPECT() *MockReceiveStreamIMockRecorder {
return m.recorder
}
// CancelRead mocks base method
func (m *MockReceiveStreamI) CancelRead(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelRead", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelRead indicates an expected call of CancelRead
func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0)
}
// Read mocks base method
func (m *MockReceiveStreamI) Read(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockReceiveStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), arg0)
}
// SetReadDeadline mocks base method
func (m *MockReceiveStreamI) SetReadDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetReadDeadline indicates an expected call of SetReadDeadline
func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), arg0)
}
// StreamID mocks base method
func (m *MockReceiveStreamI) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockReceiveStreamIMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockReceiveStreamI)(nil).StreamID))
}
// closeForShutdown mocks base method
func (m *MockReceiveStreamI) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0)
}
// getWindowUpdate mocks base method
func (m *MockReceiveStreamI) getWindowUpdate() protocol.ByteCount {
ret := m.ctrl.Call(m, "getWindowUpdate")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// getWindowUpdate indicates an expected call of getWindowUpdate
func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate))
}
// handleRstStreamFrame mocks base method
func (m *MockReceiveStreamI) handleRstStreamFrame(arg0 *wire.RstStreamFrame) error {
ret := m.ctrl.Call(m, "handleRstStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleRstStreamFrame indicates an expected call of handleRstStreamFrame
func (mr *MockReceiveStreamIMockRecorder) handleRstStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleRstStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleRstStreamFrame), arg0)
}
// handleStreamFrame mocks base method
func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error {
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleStreamFrame indicates an expected call of handleStreamFrame
func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0)
}

View File

@ -0,0 +1,154 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: SendStreamI)
// Package quic is a generated GoMock package.
package quic
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockSendStreamI is a mock of SendStreamI interface
type MockSendStreamI struct {
ctrl *gomock.Controller
recorder *MockSendStreamIMockRecorder
}
// MockSendStreamIMockRecorder is the mock recorder for MockSendStreamI
type MockSendStreamIMockRecorder struct {
mock *MockSendStreamI
}
// NewMockSendStreamI creates a new mock instance
func NewMockSendStreamI(ctrl *gomock.Controller) *MockSendStreamI {
mock := &MockSendStreamI{ctrl: ctrl}
mock.recorder = &MockSendStreamIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder {
return m.recorder
}
// CancelWrite mocks base method
func (m *MockSendStreamI) CancelWrite(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelWrite", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelWrite indicates an expected call of CancelWrite
func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0)
}
// Close mocks base method
func (m *MockSendStreamI) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockSendStreamIMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close))
}
// Context mocks base method
func (m *MockSendStreamI) Context() context.Context {
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context
func (mr *MockSendStreamIMockRecorder) Context() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context))
}
// SetWriteDeadline mocks base method
func (m *MockSendStreamI) SetWriteDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetWriteDeadline indicates an expected call of SetWriteDeadline
func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), arg0)
}
// StreamID mocks base method
func (m *MockSendStreamI) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockSendStreamIMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID))
}
// Write mocks base method
func (m *MockSendStreamI) Write(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockSendStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), arg0)
}
// closeForShutdown mocks base method
func (m *MockSendStreamI) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0)
}
// handleMaxStreamDataFrame mocks base method
func (m *MockSendStreamI) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
}
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
func (mr *MockSendStreamIMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleMaxStreamDataFrame), arg0)
}
// handleStopSendingFrame mocks base method
func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) {
m.ctrl.Call(m, "handleStopSendingFrame", arg0)
}
// handleStopSendingFrame indicates an expected call of handleStopSendingFrame
func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0)
}
// popStreamFrame mocks base method
func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// popStreamFrame indicates an expected call of popStreamFrame
func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0)
}

View File

@ -0,0 +1,72 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamFrameSource)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockStreamFrameSource is a mock of StreamFrameSource interface
type MockStreamFrameSource struct {
ctrl *gomock.Controller
recorder *MockStreamFrameSourceMockRecorder
}
// MockStreamFrameSourceMockRecorder is the mock recorder for MockStreamFrameSource
type MockStreamFrameSourceMockRecorder struct {
mock *MockStreamFrameSource
}
// NewMockStreamFrameSource creates a new mock instance
func NewMockStreamFrameSource(ctrl *gomock.Controller) *MockStreamFrameSource {
mock := &MockStreamFrameSource{ctrl: ctrl}
mock.recorder = &MockStreamFrameSourceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamFrameSource) EXPECT() *MockStreamFrameSourceMockRecorder {
return m.recorder
}
// HasCryptoStreamData mocks base method
func (m *MockStreamFrameSource) HasCryptoStreamData() bool {
ret := m.ctrl.Call(m, "HasCryptoStreamData")
ret0, _ := ret[0].(bool)
return ret0
}
// HasCryptoStreamData indicates an expected call of HasCryptoStreamData
func (mr *MockStreamFrameSourceMockRecorder) HasCryptoStreamData() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCryptoStreamData", reflect.TypeOf((*MockStreamFrameSource)(nil).HasCryptoStreamData))
}
// PopCryptoStreamFrame mocks base method
func (m *MockStreamFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame {
ret := m.ctrl.Call(m, "PopCryptoStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
return ret0
}
// PopCryptoStreamFrame indicates an expected call of PopCryptoStreamFrame
func (mr *MockStreamFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockStreamFrameSource)(nil).PopCryptoStreamFrame), arg0)
}
// PopStreamFrames mocks base method
func (m *MockStreamFrameSource) PopStreamFrames(arg0 protocol.ByteCount) []*wire.StreamFrame {
ret := m.ctrl.Call(m, "PopStreamFrames", arg0)
ret0, _ := ret[0].([]*wire.StreamFrame)
return ret0
}
// PopStreamFrames indicates an expected call of PopStreamFrames
func (mr *MockStreamFrameSourceMockRecorder) PopStreamFrames(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopStreamFrames", reflect.TypeOf((*MockStreamFrameSource)(nil).PopStreamFrames), arg0)
}

View File

@ -0,0 +1,61 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamGetter)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockStreamGetter is a mock of StreamGetter interface
type MockStreamGetter struct {
ctrl *gomock.Controller
recorder *MockStreamGetterMockRecorder
}
// MockStreamGetterMockRecorder is the mock recorder for MockStreamGetter
type MockStreamGetterMockRecorder struct {
mock *MockStreamGetter
}
// NewMockStreamGetter creates a new mock instance
func NewMockStreamGetter(ctrl *gomock.Controller) *MockStreamGetter {
mock := &MockStreamGetter{ctrl: ctrl}
mock.recorder = &MockStreamGetterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder {
return m.recorder
}
// GetOrOpenReceiveStream mocks base method
func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
ret0, _ := ret[0].(receiveStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream
func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0)
}
// GetOrOpenSendStream mocks base method
func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
ret0, _ := ret[0].(sendStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream
func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0)
}

View File

@ -0,0 +1,239 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamI)
// Package quic is a generated GoMock package.
package quic
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockStreamI is a mock of StreamI interface
type MockStreamI struct {
ctrl *gomock.Controller
recorder *MockStreamIMockRecorder
}
// MockStreamIMockRecorder is the mock recorder for MockStreamI
type MockStreamIMockRecorder struct {
mock *MockStreamI
}
// NewMockStreamI creates a new mock instance
func NewMockStreamI(ctrl *gomock.Controller) *MockStreamI {
mock := &MockStreamI{ctrl: ctrl}
mock.recorder = &MockStreamIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamI) EXPECT() *MockStreamIMockRecorder {
return m.recorder
}
// CancelRead mocks base method
func (m *MockStreamI) CancelRead(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelRead", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelRead indicates an expected call of CancelRead
func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0)
}
// CancelWrite mocks base method
func (m *MockStreamI) CancelWrite(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelWrite", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelWrite indicates an expected call of CancelWrite
func (mr *MockStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0)
}
// Close mocks base method
func (m *MockStreamI) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockStreamIMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStreamI)(nil).Close))
}
// Context mocks base method
func (m *MockStreamI) Context() context.Context {
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context
func (mr *MockStreamIMockRecorder) Context() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStreamI)(nil).Context))
}
// Read mocks base method
func (m *MockStreamI) Read(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0)
}
// SetDeadline mocks base method
func (m *MockStreamI) SetDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetDeadline indicates an expected call of SetDeadline
func (mr *MockStreamIMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), arg0)
}
// SetReadDeadline mocks base method
func (m *MockStreamI) SetReadDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetReadDeadline indicates an expected call of SetReadDeadline
func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), arg0)
}
// SetWriteDeadline mocks base method
func (m *MockStreamI) SetWriteDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetWriteDeadline indicates an expected call of SetWriteDeadline
func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0)
}
// StreamID mocks base method
func (m *MockStreamI) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockStreamIMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID))
}
// Write mocks base method
func (m *MockStreamI) Write(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), arg0)
}
// closeForShutdown mocks base method
func (m *MockStreamI) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0)
}
// getWindowUpdate mocks base method
func (m *MockStreamI) getWindowUpdate() protocol.ByteCount {
ret := m.ctrl.Call(m, "getWindowUpdate")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// getWindowUpdate indicates an expected call of getWindowUpdate
func (mr *MockStreamIMockRecorder) getWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).getWindowUpdate))
}
// handleMaxStreamDataFrame mocks base method
func (m *MockStreamI) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
}
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
func (mr *MockStreamIMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockStreamI)(nil).handleMaxStreamDataFrame), arg0)
}
// handleRstStreamFrame mocks base method
func (m *MockStreamI) handleRstStreamFrame(arg0 *wire.RstStreamFrame) error {
ret := m.ctrl.Call(m, "handleRstStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleRstStreamFrame indicates an expected call of handleRstStreamFrame
func (mr *MockStreamIMockRecorder) handleRstStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleRstStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleRstStreamFrame), arg0)
}
// handleStopSendingFrame mocks base method
func (m *MockStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) {
m.ctrl.Call(m, "handleStopSendingFrame", arg0)
}
// handleStopSendingFrame indicates an expected call of handleStopSendingFrame
func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0)
}
// handleStreamFrame mocks base method
func (m *MockStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error {
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleStreamFrame indicates an expected call of handleStreamFrame
func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0)
}
// popStreamFrame mocks base method
func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// popStreamFrame indicates an expected call of popStreamFrame
func (mr *MockStreamIMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), arg0)
}

View File

@ -0,0 +1,146 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamManager)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockStreamManager is a mock of StreamManager interface
type MockStreamManager struct {
ctrl *gomock.Controller
recorder *MockStreamManagerMockRecorder
}
// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager
type MockStreamManagerMockRecorder struct {
mock *MockStreamManager
}
// NewMockStreamManager creates a new mock instance
func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager {
mock := &MockStreamManager{ctrl: ctrl}
mock.recorder = &MockStreamManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
return m.recorder
}
// AcceptStream mocks base method
func (m *MockStreamManager) AcceptStream() (Stream, error) {
ret := m.ctrl.Call(m, "AcceptStream")
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptStream indicates an expected call of AcceptStream
func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream))
}
// CloseWithError mocks base method
func (m *MockStreamManager) CloseWithError(arg0 error) {
m.ctrl.Call(m, "CloseWithError", arg0)
}
// CloseWithError indicates an expected call of CloseWithError
func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0)
}
// DeleteStream mocks base method
func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error {
ret := m.ctrl.Call(m, "DeleteStream", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteStream indicates an expected call of DeleteStream
func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0)
}
// GetOrOpenReceiveStream mocks base method
func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
ret0, _ := ret[0].(receiveStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream
func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0)
}
// GetOrOpenSendStream mocks base method
func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
ret0, _ := ret[0].(sendStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream
func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
}
// GetOrOpenStream mocks base method
func (m *MockStreamManager) GetOrOpenStream(arg0 protocol.StreamID) (streamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenStream", arg0)
ret0, _ := ret[0].(streamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenStream indicates an expected call of GetOrOpenStream
func (mr *MockStreamManagerMockRecorder) GetOrOpenStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenStream), arg0)
}
// OpenStream mocks base method
func (m *MockStreamManager) OpenStream() (Stream, error) {
ret := m.ctrl.Call(m, "OpenStream")
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStream indicates an expected call of OpenStream
func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream))
}
// OpenStreamSync mocks base method
func (m *MockStreamManager) OpenStreamSync() (Stream, error) {
ret := m.ctrl.Call(m, "OpenStreamSync")
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync
func (mr *MockStreamManagerMockRecorder) OpenStreamSync() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync))
}
// UpdateLimits mocks base method
func (m *MockStreamManager) UpdateLimits(arg0 *handshake.TransportParameters) {
m.ctrl.Call(m, "UpdateLimits", arg0)
}
// UpdateLimits indicates an expected call of UpdateLimits
func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0)
}

View File

@ -0,0 +1,76 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamSender)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockStreamSender is a mock of StreamSender interface
type MockStreamSender struct {
ctrl *gomock.Controller
recorder *MockStreamSenderMockRecorder
}
// MockStreamSenderMockRecorder is the mock recorder for MockStreamSender
type MockStreamSenderMockRecorder struct {
mock *MockStreamSender
}
// NewMockStreamSender creates a new mock instance
func NewMockStreamSender(ctrl *gomock.Controller) *MockStreamSender {
mock := &MockStreamSender{ctrl: ctrl}
mock.recorder = &MockStreamSenderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder {
return m.recorder
}
// onHasStreamData mocks base method
func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) {
m.ctrl.Call(m, "onHasStreamData", arg0)
}
// onHasStreamData indicates an expected call of onHasStreamData
func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0)
}
// onHasWindowUpdate mocks base method
func (m *MockStreamSender) onHasWindowUpdate(arg0 protocol.StreamID) {
m.ctrl.Call(m, "onHasWindowUpdate", arg0)
}
// onHasWindowUpdate indicates an expected call of onHasWindowUpdate
func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0)
}
// onStreamCompleted mocks base method
func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) {
m.ctrl.Call(m, "onStreamCompleted", arg0)
}
// onStreamCompleted indicates an expected call of onStreamCompleted
func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0)
}
// queueControlFrame mocks base method
func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) {
m.ctrl.Call(m, "queueControlFrame", arg0)
}
// queueControlFrame indicates an expected call of queueControlFrame
func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0)
}

12
vendor/github.com/lucas-clemente/quic-go/mockgen.go generated vendored Normal file
View File

@ -0,0 +1,12 @@
package quic
//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI ReceiveStreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI SendStreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource"
//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager"
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go"
//go:generate sh -c "goimports -w mock*_test.go"

19
vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh generated vendored Executable file
View File

@ -0,0 +1,19 @@
#!/bin/bash
# Mockgen refuses to generate mocks private types.
# This script copies the quic package to a temporary directory, and adds an public alias for the private type.
# It then creates a mock for this public (alias) type.
TEMP_DIR=$(mktemp -d)
mkdir -p $TEMP_DIR/src/github.com/lucas-clemente/quic-go/
# copy all .go files to a temporary directory
# golang.org/x/crypto/curve25519/ uses Go compiler directives, which is confusing to mockgen
rsync -r --exclude 'vendor/golang.org/x/crypto/curve25519/' --include='*.go' --include '*/' --exclude '*' $GOPATH/src/github.com/lucas-clemente/quic-go/ $TEMP_DIR/src/github.com/lucas-clemente/quic-go/
echo "type $5 = $4" >> $TEMP_DIR/src/github.com/lucas-clemente/quic-go/interface.go
export GOPATH="$TEMP_DIR:$GOPATH"
mockgen -package $1 -self_package $1 -destination $2 $3 $5
rm -r "$TEMP_DIR"

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"sync"
"github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/ackhandler"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
@ -18,6 +19,12 @@ type packedPacket struct {
encryptionLevel protocol.EncryptionLevel encryptionLevel protocol.EncryptionLevel
} }
type streamFrameSource interface {
HasCryptoStreamData() bool
PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame
PopStreamFrames(protocol.ByteCount) []*wire.StreamFrame
}
type packetPacker struct { type packetPacker struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
perspective protocol.Perspective perspective protocol.Perspective
@ -25,20 +32,23 @@ type packetPacker struct {
cryptoSetup handshake.CryptoSetup cryptoSetup handshake.CryptoSetup
packetNumberGenerator *packetNumberGenerator packetNumberGenerator *packetNumberGenerator
streamFramer *streamFramer streams streamFrameSource
controlFrameMutex sync.Mutex
controlFrames []wire.Frame controlFrames []wire.Frame
stopWaiting *wire.StopWaitingFrame stopWaiting *wire.StopWaitingFrame
ackFrame *wire.AckFrame ackFrame *wire.AckFrame
leastUnacked protocol.PacketNumber leastUnacked protocol.PacketNumber
omitConnectionID bool omitConnectionID bool
hasSentPacket bool // has the packetPacker already sent a packet hasSentPacket bool // has the packetPacker already sent a packet
makeNextPacketRetransmittable bool
} }
func newPacketPacker(connectionID protocol.ConnectionID, func newPacketPacker(connectionID protocol.ConnectionID,
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
cryptoSetup handshake.CryptoSetup, cryptoSetup handshake.CryptoSetup,
streamFramer *streamFramer, streamFramer streamFrameSource,
perspective protocol.Perspective, perspective protocol.Perspective,
version protocol.VersionNumber, version protocol.VersionNumber,
) *packetPacker { ) *packetPacker {
@ -47,7 +57,7 @@ func newPacketPacker(connectionID protocol.ConnectionID,
connectionID: connectionID, connectionID: connectionID,
perspective: perspective, perspective: perspective,
version: version, version: version,
streamFramer: streamFramer, streams: streamFramer,
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
} }
} }
@ -73,7 +83,7 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) {
encLevel, sealer := p.cryptoSetup.GetSealer() encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel) header := p.getHeader(encLevel)
frames := []wire.Frame{p.ackFrame} frames := []wire.Frame{p.ackFrame}
if p.stopWaiting != nil { if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC
p.stopWaiting.PacketNumber = header.PacketNumber p.stopWaiting.PacketNumber = header.PacketNumber
p.stopWaiting.PacketNumberLen = header.PacketNumberLen p.stopWaiting.PacketNumberLen = header.PacketNumberLen
frames = append(frames, p.stopWaiting) frames = append(frames, p.stopWaiting)
@ -98,14 +108,20 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
if p.stopWaiting == nil {
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")
}
header := p.getHeader(packet.EncryptionLevel) header := p.getHeader(packet.EncryptionLevel)
p.stopWaiting.PacketNumber = header.PacketNumber var frames []wire.Frame
p.stopWaiting.PacketNumberLen = header.PacketNumberLen if !p.version.UsesIETFFrameFormat() { // for gQUIC: pack a STOP_WAITING first
frames := append([]wire.Frame{p.stopWaiting}, packet.Frames...) if p.stopWaiting == nil {
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame")
}
swf := p.stopWaiting
swf.PacketNumber = header.PacketNumber
swf.PacketNumberLen = header.PacketNumberLen
p.stopWaiting = nil p.stopWaiting = nil
frames = append([]wire.Frame{swf}, packet.Frames...)
} else {
frames = packet.Frames
}
raw, err := p.writeAndSealPacket(header, frames, sealer) raw, err := p.writeAndSealPacket(header, frames, sealer)
return &packedPacket{ return &packedPacket{
header: header, header: header,
@ -118,7 +134,7 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
// PackPacket packs a new packet // PackPacket packs a new packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket() (*packedPacket, error) { func (p *packetPacker) PackPacket() (*packedPacket, error) {
hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame() hasCryptoStreamFrame := p.streams.HasCryptoStreamData()
// if this is the first packet to be send, make sure it contains stream data // if this is the first packet to be send, make sure it contains stream data
if !p.hasSentPacket && !hasCryptoStreamFrame { if !p.hasSentPacket && !hasCryptoStreamFrame {
return nil, nil return nil, nil
@ -153,6 +169,15 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
if len(payloadFrames) == 1 && p.stopWaiting != nil { if len(payloadFrames) == 1 && p.stopWaiting != nil {
return nil, nil return nil, nil
} }
// check if this packet only contains an ACK and / or STOP_WAITING
if !ackhandler.HasRetransmittableFrames(payloadFrames) {
if p.makeNextPacketRetransmittable {
payloadFrames = append(payloadFrames, &wire.PingFrame{})
p.makeNextPacketRetransmittable = false
}
} else { // this packet already contains a retransmittable frame. No need to send a PING
p.makeNextPacketRetransmittable = false
}
p.stopWaiting = nil p.stopWaiting = nil
p.ackFrame = nil p.ackFrame = nil
@ -176,7 +201,9 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) {
return nil, err return nil, err
} }
maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength
frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} sf := p.streams.PopCryptoStreamFrame(maxLen)
sf.DataLenPresent = false
frames := []wire.Frame{sf}
raw, err := p.writeAndSealPacket(header, frames, sealer) raw, err := p.writeAndSealPacket(header, frames, sealer)
if err != nil { if err != nil {
return nil, err return nil, err
@ -197,29 +224,20 @@ func (p *packetPacker) composeNextPacket(
var payloadFrames []wire.Frame var payloadFrames []wire.Frame
// STOP_WAITING and ACK will always fit // STOP_WAITING and ACK will always fit
if p.stopWaiting != nil { if p.ackFrame != nil { // ACKs need to go first, so that the sentPacketHandler will recognize them
payloadFrames = append(payloadFrames, p.stopWaiting)
l, err := p.stopWaiting.MinLength(p.version)
if err != nil {
return nil, err
}
payloadLength += l
}
if p.ackFrame != nil {
payloadFrames = append(payloadFrames, p.ackFrame) payloadFrames = append(payloadFrames, p.ackFrame)
l, err := p.ackFrame.MinLength(p.version) l := p.ackFrame.MinLength(p.version)
if err != nil {
return nil, err
}
payloadLength += l payloadLength += l
} }
if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC
payloadFrames = append(payloadFrames, p.stopWaiting)
payloadLength += p.stopWaiting.MinLength(p.version)
}
p.controlFrameMutex.Lock()
for len(p.controlFrames) > 0 { for len(p.controlFrames) > 0 {
frame := p.controlFrames[len(p.controlFrames)-1] frame := p.controlFrames[len(p.controlFrames)-1]
minLength, err := frame.MinLength(p.version) minLength := frame.MinLength(p.version)
if err != nil {
return nil, err
}
if payloadLength+minLength > maxFrameSize { if payloadLength+minLength > maxFrameSize {
break break
} }
@ -227,6 +245,7 @@ func (p *packetPacker) composeNextPacket(
payloadLength += minLength payloadLength += minLength
p.controlFrames = p.controlFrames[:len(p.controlFrames)-1] p.controlFrames = p.controlFrames[:len(p.controlFrames)-1]
} }
p.controlFrameMutex.Unlock()
if payloadLength > maxFrameSize { if payloadLength > maxFrameSize {
return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize) return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize)
@ -247,20 +266,14 @@ func (p *packetPacker) composeNextPacket(
maxFrameSize += 2 maxFrameSize += 2
} }
fs := p.streamFramer.PopStreamFrames(maxFrameSize - payloadLength) fs := p.streams.PopStreamFrames(maxFrameSize - payloadLength)
if len(fs) != 0 { if len(fs) != 0 {
fs[len(fs)-1].DataLenPresent = false fs[len(fs)-1].DataLenPresent = false
} }
// TODO: Simplify
for _, f := range fs { for _, f := range fs {
payloadFrames = append(payloadFrames, f) payloadFrames = append(payloadFrames, f)
} }
for b := p.streamFramer.PopBlockedFrame(); b != nil; b = p.streamFramer.PopBlockedFrame() {
p.controlFrames = append(p.controlFrames, b)
}
return payloadFrames, nil return payloadFrames, nil
} }
@ -271,7 +284,9 @@ func (p *packetPacker) QueueControlFrame(frame wire.Frame) {
case *wire.AckFrame: case *wire.AckFrame:
p.ackFrame = f p.ackFrame = f
default: default:
p.controlFrameMutex.Lock()
p.controlFrames = append(p.controlFrames, f) p.controlFrames = append(p.controlFrames, f)
p.controlFrameMutex.Unlock()
} }
} }
@ -377,3 +392,7 @@ func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) {
func (p *packetPacker) SetOmitConnectionID() { func (p *packetPacker) SetOmitConnectionID() {
p.omitConnectionID = true p.omitConnectionID = true
} }
func (p *packetPacker) MakeNextPacketRetransmittable() {
p.makeNextPacketRetransmittable = true
}

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"math" "math"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/ackhandler"
"github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
@ -49,29 +50,32 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
} }
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce } func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce } func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
func (m *mockCryptoSetup) ConnectionState() ConnectionState { panic("not implemented") }
var _ = Describe("Packet packer", func() { var _ = Describe("Packet packer", func() {
var ( var (
packer *packetPacker packer *packetPacker
publicHeaderLen protocol.ByteCount publicHeaderLen protocol.ByteCount
maxFrameSize protocol.ByteCount maxFrameSize protocol.ByteCount
streamFramer *streamFramer cryptoStream cryptoStreamI
cryptoStream *stream mockStreamFramer *MockStreamFrameSource
) )
BeforeEach(func() { BeforeEach(func() {
version := versionGQUICFrames version := versionGQUICFrames
cryptoStream = &stream{streamID: version.CryptoStreamID(), flowController: flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)} mockSender := NewMockStreamSender(mockCtrl)
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
streamFramer = newStreamFramer(cryptoStream, streamsMap, nil, versionGQUICFrames) cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version)
mockStreamFramer = NewMockStreamFrameSource(mockCtrl)
packer = &packetPacker{ packer = newPacketPacker(
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, 0x1337,
connectionID: 0x1337, 1,
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength), &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
streamFramer: streamFramer, mockStreamFramer,
perspective: protocol.PerspectiveServer, protocol.PerspectiveServer,
} version,
)
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
packer.hasSentPacket = true packer.hasSentPacket = true
@ -79,33 +83,36 @@ var _ = Describe("Packet packer", func() {
}) })
It("returns nil when no packet is queued", func() { It("returns nil when no packet is queued", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("packs single packets", func() { It("packs single packets", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
f := &wire.StreamFrame{ f := &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
} }
streamFramer.AddFrameForRetransmission(f) mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f})
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
b := &bytes.Buffer{} b := &bytes.Buffer{}
f.Write(b, packer.version) f.Write(b, packer.version)
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
}) })
It("stores the encryption level a packet was sealed with", func() { It("stores the encryption level a packet was sealed with", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure mockStreamFramer.EXPECT().HasCryptoStreamData()
f := &wire.StreamFrame{ mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{
StreamID: 5, StreamID: 5,
Data: []byte("foobar"), Data: []byte("foobar"),
} }})
streamFramer.AddFrameForRetransmission(f) packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
@ -213,7 +220,7 @@ var _ = Describe("Packet packer", func() {
}) })
}) })
It("packs a ConnectionClose", func() { It("packs a CONNECTION_CLOSE", func() {
ccf := wire.ConnectionCloseFrame{ ccf := wire.ConnectionCloseFrame{
ErrorCode: 0x1337, ErrorCode: 0x1337,
ReasonPhrase: "foobar", ReasonPhrase: "foobar",
@ -224,23 +231,21 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames[0]).To(Equal(&ccf)) Expect(p.frames[0]).To(Equal(&ccf))
}) })
It("doesn't send any other frames when sending a ConnectionClose", func() { It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() {
ccf := wire.ConnectionCloseFrame{ // expect no mockStreamFramer.PopStreamFrames
ccf := &wire.ConnectionCloseFrame{
ErrorCode: 0x1337, ErrorCode: 0x1337,
ReasonPhrase: "foobar", ReasonPhrase: "foobar",
} }
packer.controlFrames = []wire.Frame{&wire.MaxStreamDataFrame{StreamID: 37}} packer.controlFrames = []wire.Frame{&wire.MaxStreamDataFrame{StreamID: 37}}
streamFramer.AddFrameForRetransmission(&wire.StreamFrame{ p, err := packer.PackConnectionClose(ccf)
StreamID: 5,
Data: []byte("foobar"),
})
p, err := packer.PackConnectionClose(&ccf)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(Equal([]wire.Frame{ccf}))
Expect(p.frames[0]).To(Equal(&ccf))
}) })
It("packs only control frames", func() { It("packs only control frames", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.RstStreamFrame{})
packer.QueueControlFrame(&wire.MaxDataFrame{}) packer.QueueControlFrame(&wire.MaxDataFrame{})
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -251,6 +256,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("increases the packet number", func() { It("increases the packet number", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.RstStreamFrame{})
p1, err := packer.PackPacket() p1, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -263,6 +270,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("packs a STOP_WAITING frame first", func() { It("packs a STOP_WAITING frame first", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.packetNumberGenerator.next = 15 packer.packetNumberGenerator.next = 15
swf := &wire.StopWaitingFrame{LeastUnacked: 10} swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.RstStreamFrame{})
@ -275,6 +284,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
packer.packetNumberGenerator.next = packetNumber packer.packetNumberGenerator.next = packetNumber
swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
@ -286,6 +297,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("does not pack a packet containing only a STOP_WAITING frame", func() { It("does not pack a packet containing only a STOP_WAITING frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
swf := &wire.StopWaitingFrame{LeastUnacked: 10} swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(swf) packer.QueueControlFrame(swf)
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -294,6 +307,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("packs a packet if it has queued control frames, but no new control frames", func() { It("packs a packet if it has queued control frames, but no new control frames", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -301,6 +316,7 @@ var _ = Describe("Packet packer", func() {
}) })
It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
packer.hasSentPacket = false packer.hasSentPacket = false
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -329,8 +345,7 @@ var _ = Describe("Packet packer", func() {
It("packs a lot of control frames into 2 packets if they don't fit into one", func() { It("packs a lot of control frames into 2 packets if they don't fit into one", func() {
blockedFrame := &wire.BlockedFrame{} blockedFrame := &wire.BlockedFrame{}
minLength, _ := blockedFrame.MinLength(packer.version) maxFramesPerPacket := int(maxFrameSize) / int(blockedFrame.MinLength(packer.version))
maxFramesPerPacket := int(maxFrameSize) / int(minLength)
var controlFrames []wire.Frame var controlFrames []wire.Frame
for i := 0; i < maxFramesPerPacket+10; i++ { for i := 0; i < maxFramesPerPacket+10; i++ {
controlFrames = append(controlFrames, blockedFrame) controlFrames = append(controlFrames, blockedFrame)
@ -345,16 +360,17 @@ var _ = Describe("Packet packer", func() {
}) })
It("only increases the packet number when there is an actual packet to send", func() { It("only increases the packet number when there is an actual packet to send", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.packetNumberGenerator.nextToSkip = 1000 packer.packetNumberGenerator.nextToSkip = 1000
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1)))
f := &wire.StreamFrame{ mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{
StreamID: 5, StreamID: 5,
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
} }})
streamFramer.AddFrameForRetransmission(f)
p, err = packer.PackPacket() p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
@ -362,320 +378,207 @@ var _ = Describe("Packet packer", func() {
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2))) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2)))
}) })
Context("STREAM Frame handling", func() { It("adds a PING frame when it's supposed to send a retransmittable packet", func() {
It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() { mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
f := &wire.StreamFrame{ mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
Offset: 1, packer.QueueControlFrame(&wire.AckFrame{})
StreamID: 5, packer.QueueControlFrame(&wire.StopWaitingFrame{})
DataLenPresent: false, packer.MakeNextPacketRetransmittable()
} p, err := packer.PackPacket()
minLength, _ := f.MinLength(packer.version) Expect(p).ToNot(BeNil())
maxStreamFrameDataLen := maxFrameSize - minLength
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen))
streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(3))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames).To(ContainElement(&wire.PingFrame{}))
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) // make sure the next packet doesn't contain another PING
packer.QueueControlFrame(&wire.AckFrame{})
p, err = packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty()) Expect(p.frames).To(HaveLen(1))
}) })
It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() { It("waits until there's something to send before adding a PING frame", func() {
packer.version = versionIETFFrames mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
streamFramer.version = versionIETFFrames mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.MakeNextPacketRetransmittable()
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
packer.QueueControlFrame(&wire.AckFrame{})
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(2))
Expect(p.frames).To(ContainElement(&wire.PingFrame{}))
})
It("doesn't send a PING if it already sent another retransmittable frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.MakeNextPacketRetransmittable()
packer.QueueControlFrame(&wire.MaxDataFrame{})
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
packer.QueueControlFrame(&wire.AckFrame{})
p, err = packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
})
Context("STREAM frame handling", func() {
It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame {
f := &wire.StreamFrame{ f := &wire.StreamFrame{
Offset: 1, Offset: 1,
StreamID: 5, StreamID: 5,
DataLenPresent: true, DataLenPresent: true,
} }
minLength, _ := f.MinLength(packer.version) f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version)))
// for IETF draft style STREAM frames, we don't know the size of the DataLen, because it is a variable length integer return []*wire.StreamFrame{f}
// in the general case, we therefore use a STREAM frame that is 1 byte smaller than the maximum size
maxStreamFrameDataLen := maxFrameSize - minLength - 1
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen))
streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty())
}) })
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
It("correctly handles a STREAM frame with one byte less than maximum size", func() {
maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - 1
f1 := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)),
}
f2 := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1)))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
})
It("packs multiple small STREAM frames into single packet", func() {
f1 := &wire.StreamFrame{
StreamID: 5,
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
}
f2 := &wire.StreamFrame{
StreamID: 5,
Data: []byte{0xBE, 0xEF, 0x13, 0x37},
}
f3 := &wire.StreamFrame{
StreamID: 3,
Data: []byte{0xCA, 0xFE},
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
streamFramer.AddFrameForRetransmission(f3)
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
b := &bytes.Buffer{}
f1.Write(b, 0)
f2.Write(b, 0)
f3.Write(b, 0)
Expect(p.frames).To(HaveLen(3))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(p.raw).To(ContainSubstring(string(f1.Data)))
Expect(p.raw).To(ContainSubstring(string(f2.Data)))
Expect(p.raw).To(ContainSubstring(string(f3.Data)))
})
It("splits one STREAM frame larger than maximum size", func() {
f := &wire.StreamFrame{
StreamID: 7,
Offset: 1,
}
minLength, _ := f.MinLength(packer.version)
maxStreamFrameDataLen := maxFrameSize - minLength
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200)
streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(int(maxStreamFrameDataLen)))
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(200))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty())
})
It("packs 2 STREAM frames that are too big for one packet correctly", func() {
maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2)
f1 := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100),
Offset: 1,
}
f2 := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100),
Offset: 1,
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket()
Expect(p.frames).To(HaveLen(2))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket()
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
p, err = packer.PackPacket() p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
It("packs a packet that has the maximum packet size when given a large enough STREAM frame", func() { It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() {
packer.version = versionIETFFrames
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame {
f := &wire.StreamFrame{ f := &wire.StreamFrame{
StreamID: 5,
Offset: 1, Offset: 1,
StreamID: 5,
DataLenPresent: true,
} }
minLength, _ := f.MinLength(packer.version) f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version)))
f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header return []*wire.StreamFrame{f}
streamFramer.AddFrameForRetransmission(f) })
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p.frames).To(HaveLen(1))
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
}) })
It("splits a STREAM frame larger than the maximum size", func() { It("packs multiple small STREAM frames into single packet", func() {
f := &wire.StreamFrame{ f1 := &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Offset: 1, Data: []byte("frame 1"),
DataLenPresent: true,
} }
minLength, _ := f.MinLength(packer.version) f2 := &wire.StreamFrame{
f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header StreamID: 5,
Data: []byte("frame 2"),
streamFramer.AddFrameForRetransmission(f) DataLenPresent: true,
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) }
f3 := &wire.StreamFrame{
StreamID: 3,
Data: []byte("frame 3"),
DataLenPresent: true,
}
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f1, f2, f3})
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(3))
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1")))
Expect(err).ToNot(HaveOccurred()) Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2")))
Expect(payloadFrames).To(HaveLen(1)) Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3")))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
}) })
It("refuses to send unencrypted stream data on a data stream", func() { It("refuses to send unencrypted stream data on a data stream", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
// don't expect a call to mockStreamFramer.PopStreamFrames
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
f := &wire.StreamFrame{
StreamID: 3,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
It("sends non forward-secure data as the client", func() { It("sends non forward-secure data as the client", func() {
packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
f := &wire.StreamFrame{ f := &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Data: []byte("foobar"), Data: []byte("foobar"),
} }
streamFramer.AddFrameForRetransmission(f) mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f})
packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(p.frames[0]).To(Equal(f)) Expect(p.frames).To(Equal([]wire.Frame{f}))
}) })
It("does not send non forward-secure data as the server", func() { It("does not send non forward-secure data as the server", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
// don't expect a call to mockStreamFramer.PopStreamFrames
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
f := &wire.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
It("sends unencrypted stream data on the crypto stream", func() { It("sends unencrypted stream data on the crypto stream", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted f := &wire.StreamFrame{
cryptoStream.dataForWriting = []byte("foobar")
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(), StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"), Data: []byte("foobar"),
})) }
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
}) })
It("sends encrypted stream data on the crypto stream", func() { It("sends encrypted stream data on the crypto stream", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure f := &wire.StreamFrame{
cryptoStream.dataForWriting = []byte("foobar")
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(), StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"), Data: []byte("foobar"),
})) }
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
}) })
It("does not pack stream frames if not allowed", func() { It("does not pack STREAM frames if not allowed", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
// don't expect a call to mockStreamFramer.PopStreamFrames
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
packer.QueueControlFrame(&wire.AckFrame{}) ack := &wire.AckFrame{LargestAcked: 10}
streamFramer.AddFrameForRetransmission(&wire.StreamFrame{StreamID: 3, Data: []byte("foobar")}) packer.QueueControlFrame(ack)
p, err := packer.PackPacket() p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(Equal([]wire.Frame{ack}))
Expect(func() { _ = p.frames[0].(*wire.AckFrame) }).NotTo(Panic())
}) })
}) })
Context("BLOCKED frames", func() {
It("queues a BLOCKED frame", func() {
length := 100
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
f := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, length),
}
streamFramer.AddFrameForRetransmission(f)
_, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(packer.controlFrames[0]).To(Equal(&wire.StreamBlockedFrame{StreamID: 5}))
})
It("removes the dataLen attribute from the last StreamFrame, even if it queued a BLOCKED frame", func() {
length := 100
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
f := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, length),
}
streamFramer.AddFrameForRetransmission(f)
p, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(HaveLen(1))
Expect(p[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
})
It("packs a connection-level BlockedFrame", func() {
streamFramer.blockedFrameQueue = []wire.Frame{&wire.BlockedFrame{}}
f := &wire.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
_, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(packer.controlFrames[0]).To(Equal(&wire.BlockedFrame{}))
})
})
It("returns nil if we only have a single STOP_WAITING", func() {
packer.QueueControlFrame(&wire.StopWaitingFrame{})
p, err := packer.PackPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil())
})
It("packs a single ACK", func() { It("packs a single ACK", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
ack := &wire.AckFrame{LargestAcked: 42} ack := &wire.AckFrame{LargestAcked: 42}
packer.QueueControlFrame(ack) packer.QueueControlFrame(ack)
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -685,6 +588,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("does not return nil if we only have a single ACK but request it to be sent", func() { It("does not return nil if we only have a single ACK but request it to be sent", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
ack := &wire.AckFrame{} ack := &wire.AckFrame{}
packer.QueueControlFrame(ack) packer.QueueControlFrame(ack)
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -692,15 +597,6 @@ var _ = Describe("Packet packer", func() {
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
}) })
It("queues a control frame to be sent in the next packet", func() {
msd := &wire.MaxStreamDataFrame{StreamID: 5}
packer.QueueControlFrame(msd)
p, err := packer.PackPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(msd))
})
Context("retransmitting of handshake packets", func() { Context("retransmitting of handshake packets", func() {
swf := &wire.StopWaitingFrame{LeastUnacked: 1} swf := &wire.StopWaitingFrame{LeastUnacked: 1}
sf := &wire.StreamFrame{ sf := &wire.StreamFrame{
@ -719,8 +615,19 @@ var _ = Describe("Packet packer", func() {
} }
p, err := packer.PackHandshakeRetransmission(packet) p, err := packer.PackHandshakeRetransmission(packet)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(ContainElement(sf)) Expect(p.frames).To(Equal([]wire.Frame{swf, sf}))
Expect(p.frames).To(ContainElement(swf)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
})
It("doesn't add a STOP_WAITING frame for IETF QUIC", func() {
packer.version = versionIETFFrames
packet := &ackhandler.Packet{
EncryptionLevel: protocol.EncryptionUnencrypted,
Frames: []wire.Frame{sf},
}
p, err := packer.PackHandshakeRetransmission(packet)
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(Equal([]wire.Frame{sf}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
}) })
@ -733,8 +640,7 @@ var _ = Describe("Packet packer", func() {
} }
p, err := packer.PackHandshakeRetransmission(packet) p, err := packer.PackHandshakeRetransmission(packet)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(ContainElement(sf)) Expect(p.frames).To(Equal([]wire.Frame{swf, sf}))
Expect(p.frames).To(ContainElement(swf))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
// a packet sent by the server with initial encryption contains the SHLO // a packet sent by the server with initial encryption contains the SHLO
// it needs to have a diversification nonce // it needs to have a diversification nonce
@ -768,11 +674,16 @@ var _ = Describe("Packet packer", func() {
}) })
It("pads Initial packets to the required minimum packet size", func() { It("pads Initial packets to the required minimum packet size", func() {
f := &wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"),
}
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
packer.version = protocol.VersionTLS packer.version = protocol.VersionTLS
packer.hasSentPacket = false packer.hasSentPacket = false
packer.perspective = protocol.PerspectiveClient packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
cryptoStream.dataForWriting = []byte("foobar")
packet, err := packer.PackPacket() packet, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize))
@ -795,7 +706,7 @@ var _ = Describe("Packet packer", func() {
_, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{ _, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.EncryptionSecure, EncryptionLevel: protocol.EncryptionSecure,
}) })
Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")) Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame"))
}) })
}) })
@ -807,7 +718,7 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames).To(Equal([]wire.Frame{&wire.AckFrame{DelayTime: math.MaxInt64}})) Expect(p.frames).To(Equal([]wire.Frame{&wire.AckFrame{DelayTime: math.MaxInt64}}))
}) })
It("packs ACK packets with SWFs", func() { It("packs ACK packets with STOP_WAITING frames", func() {
packer.QueueControlFrame(&wire.AckFrame{}) packer.QueueControlFrame(&wire.AckFrame{})
packer.QueueControlFrame(&wire.StopWaitingFrame{}) packer.QueueControlFrame(&wire.StopWaitingFrame{})
p, err := packer.PackAckPacket() p, err := packer.PackAckPacket()

View File

@ -107,21 +107,32 @@ func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wir
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
} }
case 0x6: case 0x6:
// TODO(#964): remove STOP_WAITING frames frame, err = wire.ParseMaxStreamIDFrame(r, u.version)
// TODO(#878): implement the MAX_STREAM_ID frame
frame, err = wire.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version)
if err != nil { if err != nil {
err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) err = qerr.Error(qerr.InvalidFrameData, err.Error())
} }
case 0x7: case 0x7:
frame, err = wire.ParsePingFrame(r, u.version) frame, err = wire.ParsePingFrame(r, u.version)
case 0x8: case 0x8:
frame, err = wire.ParseBlockedFrame(r, u.version) frame, err = wire.ParseBlockedFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0x9: case 0x9:
frame, err = wire.ParseStreamBlockedFrame(r, u.version) frame, err = wire.ParseStreamBlockedFrame(r, u.version)
if err != nil { if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error()) err = qerr.Error(qerr.InvalidBlockedData, err.Error())
} }
case 0xa:
frame, err = wire.ParseStreamIDBlockedFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xc:
frame, err = wire.ParseStopSendingFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xe: case 0xe:
frame, err = wire.ParseAckFrame(r, u.version) frame, err = wire.ParseAckFrame(r, u.version)
if err != nil { if err != nil {

View File

@ -102,7 +102,7 @@ var _ = Describe("Packet unpacker", func() {
f := &wire.RstStreamFrame{ f := &wire.RstStreamFrame{
StreamID: 0xdeadbeef, StreamID: 0xdeadbeef,
ByteOffset: 0xdecafbad11223344, ByteOffset: 0xdecafbad11223344,
ErrorCode: 0x13371234, ErrorCode: 0x1337,
} }
err := f.Write(buf, versionGQUICFrames) err := f.Write(buf, versionGQUICFrames)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -342,8 +342,19 @@ var _ = Describe("Packet unpacker", func() {
Expect(packet.frames).To(Equal([]wire.Frame{f})) Expect(packet.frames).To(Equal([]wire.Frame{f}))
}) })
It("unpacks MAX_STREAM_ID frames", func() {
f := &wire.MaxStreamIDFrame{StreamID: 0x1337}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("unpacks connection-level BLOCKED frames", func() { It("unpacks connection-level BLOCKED frames", func() {
f := &wire.BlockedFrame{} f := &wire.BlockedFrame{Offset: 0x1234}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames) err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -354,7 +365,32 @@ var _ = Describe("Packet unpacker", func() {
}) })
It("unpacks stream-level BLOCKED frames", func() { It("unpacks stream-level BLOCKED frames", func() {
f := &wire.StreamBlockedFrame{StreamID: 0xdeadbeef} f := &wire.StreamBlockedFrame{
StreamID: 0xdeadbeef,
Offset: 0xdead,
}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("unpacks STREAM_ID_BLOCKED frames", func() {
f := &wire.StreamIDBlockedFrame{StreamID: 0x1234567}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("unpacks STOP_SENDING frames", func() {
f := &wire.StopSendingFrame{StreamID: 0x42}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames) err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -392,9 +428,13 @@ var _ = Describe("Packet unpacker", func() {
0x02: qerr.InvalidConnectionCloseData, 0x02: qerr.InvalidConnectionCloseData,
0x04: qerr.InvalidWindowUpdateData, 0x04: qerr.InvalidWindowUpdateData,
0x05: qerr.InvalidWindowUpdateData, 0x05: qerr.InvalidWindowUpdateData,
0x06: qerr.InvalidFrameData,
0x08: qerr.InvalidBlockedData,
0x09: qerr.InvalidBlockedData, 0x09: qerr.InvalidBlockedData,
0x0a: qerr.InvalidFrameData,
0x0c: qerr.InvalidFrameData,
0x0e: qerr.InvalidAckData,
0x10: qerr.InvalidStreamData, 0x10: qerr.InvalidStreamData,
0xe: qerr.InvalidAckData,
} { } {
setData([]byte{b}) setData([]byte{b})
_, err := unpacker.Unpack(hdrBin, hdr, data) _, err := unpacker.Unpack(hdrBin, hdr, data)

View File

@ -1,8 +1,8 @@
// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT // Code generated by "stringer -type=ErrorCode"; DO NOT EDIT.
package qerr package qerr
import "fmt" import "strconv"
const ( const (
_ErrorCode_name_0 = "InternalErrorStreamDataAfterTerminationInvalidPacketHeaderInvalidFrameDataInvalidFecDataInvalidRstStreamDataInvalidConnectionCloseDataInvalidGoawayDataInvalidAckDataInvalidVersionNegotiationPacketInvalidPublicRstPacketDecryptionFailureEncryptionFailurePacketTooLarge" _ErrorCode_name_0 = "InternalErrorStreamDataAfterTerminationInvalidPacketHeaderInvalidFrameDataInvalidFecDataInvalidRstStreamDataInvalidConnectionCloseDataInvalidGoawayDataInvalidAckDataInvalidVersionNegotiationPacketInvalidPublicRstPacketDecryptionFailureEncryptionFailurePacketTooLarge"
@ -19,7 +19,6 @@ var (
_ErrorCode_index_2 = [...]uint16{0, 15, 37, 57, 75, 96, 112, 127, 147, 167, 191, 226, 250, 279, 309, 340, 366, 385, 410, 425, 445, 457, 475, 505, 530, 547} _ErrorCode_index_2 = [...]uint16{0, 15, 37, 57, 75, 96, 112, 127, 147, 167, 191, 226, 250, 279, 309, 340, 366, 385, 410, 425, 445, 457, 475, 505, 530, 547}
_ErrorCode_index_3 = [...]uint16{0, 14, 29, 50, 65, 90, 119, 158, 184, 208, 231, 249, 279, 301, 322, 340, 366, 390, 425} _ErrorCode_index_3 = [...]uint16{0, 14, 29, 50, 65, 90, 119, 158, 184, 208, 231, 249, 279, 301, 322, 340, 366, 390, 425}
_ErrorCode_index_4 = [...]uint16{0, 16, 45, 78, 97, 114, 144, 169, 192, 215, 238, 256, 276, 292, 308, 346, 379, 410, 448, 459, 477, 498, 532} _ErrorCode_index_4 = [...]uint16{0, 16, 45, 78, 97, 114, 144, 169, 192, 215, 238, 256, 276, 292, 308, 346, 379, 410, 448, 459, 477, 498, 532}
_ErrorCode_index_5 = [...]uint8{0, 34}
) )
func (i ErrorCode) String() string { func (i ErrorCode) String() string {
@ -42,6 +41,6 @@ func (i ErrorCode) String() string {
case i == 97: case i == 97:
return _ErrorCode_name_5 return _ErrorCode_name_5
default: default:
return fmt.Sprintf("ErrorCode(%d)", i) return "ErrorCode(" + strconv.FormatInt(int64(i), 10) + ")"
} }
} }

View File

@ -0,0 +1,284 @@
package quic
import (
"fmt"
"io"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receiveStreamI interface {
ReceiveStream
handleStreamFrame(*wire.StreamFrame) error
handleRstStreamFrame(*wire.RstStreamFrame) error
closeForShutdown(error)
getWindowUpdate() protocol.ByteCount
}
type receiveStream struct {
mutex sync.Mutex
streamID protocol.StreamID
sender streamSender
frameQueue *streamFrameSorter
readPosInFrame int
readOffset protocol.ByteCount
closeForShutdownErr error
cancelReadErr error
resetRemotelyErr StreamError
closedForShutdown bool // set when CloseForShutdown() is called
finRead bool // set once we read a frame with a FinBit
canceledRead bool // set when CancelRead() is called
resetRemotely bool // set when HandleRstStreamFrame() is called
readChan chan struct{}
readDeadline time.Time
flowController flowcontrol.StreamFlowController
version protocol.VersionNumber
}
var _ ReceiveStream = &receiveStream{}
var _ receiveStreamI = &receiveStream{}
func newReceiveStream(
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
) *receiveStream {
return &receiveStream{
streamID: streamID,
sender: sender,
flowController: flowController,
frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1),
}
}
func (s *receiveStream) StreamID() protocol.StreamID {
return s.streamID
}
// Read implements io.Reader. It is not thread safe!
func (s *receiveStream) Read(p []byte) (int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.finRead {
return 0, io.EOF
}
if s.canceledRead {
return 0, s.cancelReadErr
}
if s.resetRemotely {
return 0, s.resetRemotelyErr
}
if s.closedForShutdown {
return 0, s.closeForShutdownErr
}
bytesRead := 0
for bytesRead < len(p) {
frame := s.frameQueue.Head()
if frame == nil && bytesRead > 0 {
return bytesRead, s.closeForShutdownErr
}
for {
// Stop waiting on errors
if s.closedForShutdown {
return bytesRead, s.closeForShutdownErr
}
if s.canceledRead {
return bytesRead, s.cancelReadErr
}
if s.resetRemotely {
return bytesRead, s.resetRemotelyErr
}
deadline := s.readDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
return bytesRead, errDeadline
}
if frame != nil {
s.readPosInFrame = int(s.readOffset - frame.Offset)
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.readChan
} else {
select {
case <-s.readChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
frame = s.frameQueue.Head()
}
if bytesRead > len(p) {
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > int(frame.DataLen()) {
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen())
}
s.mutex.Unlock()
copy(p[bytesRead:], frame.Data[s.readPosInFrame:])
m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame)
s.readPosInFrame += m
bytesRead += m
s.readOffset += protocol.ByteCount(m)
s.mutex.Lock()
// when a RST_STREAM was received, the was already informed about the final byteOffset for this stream
if !s.resetRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}
// this call triggers the flow controller to increase the flow control window, if necessary
if s.flowController.HasWindowUpdate() {
s.sender.onHasWindowUpdate(s.streamID)
}
if s.readPosInFrame >= int(frame.DataLen()) {
s.frameQueue.Pop()
s.finRead = frame.FinBit
if frame.FinBit {
s.sender.onStreamCompleted(s.streamID)
return bytesRead, io.EOF
}
}
}
return bytesRead, nil
}
func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.finRead {
return nil
}
if s.canceledRead {
return nil
}
s.canceledRead = true
s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode)
s.signalRead()
if s.version.UsesIETFFrameFormat() {
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
ErrorCode: errorCode,
})
}
return nil
}
func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil {
return err
}
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData {
return err
}
s.signalRead()
return nil
}
func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closedForShutdown {
return nil
}
if err := s.flowController.UpdateHighestReceived(frame.ByteOffset, true); err != nil {
return err
}
// In gQUIC, error code 0 has a special meaning.
// The peer will reliably continue transmitting, but is not interested in reading from the stream.
// We should therefore just continue reading from the stream, until we encounter the FIN bit.
if !s.version.UsesIETFFrameFormat() && frame.ErrorCode == 0 {
return nil
}
// ignore duplicate RST_STREAM frames for this stream (after checking their final offset)
if s.resetRemotely {
return nil
}
s.resetRemotely = true
s.resetRemotelyErr = streamCanceledError{
errorCode: frame.ErrorCode,
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
}
s.signalRead()
s.sender.onStreamCompleted(s.streamID)
return nil
}
func (s *receiveStream) CloseRemote(offset protocol.ByteCount) {
s.handleStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset})
}
func (s *receiveStream) onClose(offset protocol.ByteCount) {
if s.canceledRead && !s.version.UsesIETFFrameFormat() {
s.sender.queueControlFrame(&wire.RstStreamFrame{
StreamID: s.streamID,
ByteOffset: offset,
ErrorCode: 0,
})
}
}
func (s *receiveStream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.readDeadline
s.readDeadline = t
s.mutex.Unlock()
// if the new deadline is before the currently set deadline, wake up Read()
if t.Before(oldDeadline) {
s.signalRead()
}
return nil
}
// CloseForShutdown closes a stream abruptly.
// It makes Read unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *receiveStream) closeForShutdown(err error) {
s.mutex.Lock()
s.closedForShutdown = true
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalRead()
}
func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate()
}
// signalRead performs a non-blocking send on the readChan
func (s *receiveStream) signalRead() {
select {
case s.readChan <- struct{}{}:
default:
}
}

View File

@ -0,0 +1,648 @@
package quic
import (
"errors"
"io"
"runtime"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
var _ = Describe("Receive Stream", func() {
const streamID protocol.StreamID = 1337
var (
str *receiveStream
strWithTimeout io.Reader // str wrapped with gbytes.TimeoutReader
mockFC *mocks.MockStreamFlowController
mockSender *MockStreamSender
)
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
str = newReceiveStream(streamID, mockSender, mockFC)
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = gbytes.TimeoutReader(str, timeout)
})
It("gets stream id", func() {
Expect(str.StreamID()).To(Equal(protocol.StreamID(1337)))
})
Context("reading", func() {
It("reads a single STREAM frame", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().HasWindowUpdate()
frame := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
})
It("reads a single STREAM frame in multiple goes", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 2)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
Expect(b).To(Equal([]byte{0xDE, 0xAD}))
n, err = strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
Expect(b).To(Equal([]byte{0xBE, 0xEF}))
})
It("reads all data available", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
frame2 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00}))
})
It("assembles multiple STREAM frames", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
frame2 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
})
It("waits until data is available", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().HasWindowUpdate()
go func() {
defer GinkgoRecover()
frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}}
time.Sleep(10 * time.Millisecond)
err := str.handleStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
}()
b := make([]byte, 2)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
})
It("handles STREAM frames in wrong order", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
}
frame2 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
})
It("ignores duplicate STREAM frames", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
frame2 := wire.StreamFrame{
Offset: 0,
Data: []byte{0x13, 0x37},
}
frame3 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame3)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
})
It("doesn't rejects a STREAM frames with an overlapping data range", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte("foob"),
}
frame2 := wire.StreamFrame{
Offset: 2,
Data: []byte("obar"),
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})
It("passes on errors from the streamFrameSorter", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), false)
err := str.handleStreamFrame(&wire.StreamFrame{StreamID: streamID}) // STREAM frame without data
Expect(err).To(MatchError(errEmptyStreamData))
})
It("calls the onHasWindowUpdate callback, when the a MAX_STREAM_DATA should be sent", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
mockFC.EXPECT().HasWindowUpdate().Return(true)
mockSender.EXPECT().onHasWindowUpdate(streamID)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte("foobar"),
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
})
Context("deadlines", func() {
It("the deadline error has the right net.Error properties", func() {
Expect(errDeadline.Temporary()).To(BeTrue())
Expect(errDeadline.Timeout()).To(BeTrue())
Expect(errDeadline).To(MatchError("deadline exceeded"))
})
It("returns an error when Read is called after the deadline", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes()
f := &wire.StreamFrame{Data: []byte("foobar")}
err := str.handleStreamFrame(f)
Expect(err).ToNot(HaveOccurred())
str.SetReadDeadline(time.Now().Add(-time.Second))
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
})
It("unblocks after the deadline", func() {
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetReadDeadline(deadline)
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(10*time.Millisecond)))
})
It("doesn't unblock if the deadline is changed before the first one expires", func() {
deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond))
str.SetReadDeadline(deadline1)
go func() {
defer GinkgoRecover()
time.Sleep(scaleDuration(20 * time.Millisecond))
str.SetReadDeadline(deadline2)
// make sure that this was actually execute before the deadline expires
Expect(time.Now()).To(BeTemporally("<", deadline1))
}()
runtime.Gosched()
b := make([]byte, 10)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond)))
})
It("unblocks earlier, when a new deadline is set", func() {
deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond))
go func() {
defer GinkgoRecover()
time.Sleep(scaleDuration(10 * time.Millisecond))
str.SetReadDeadline(deadline2)
// make sure that this was actually execute before the deadline expires
Expect(time.Now()).To(BeTemporally("<", deadline2))
}()
str.SetReadDeadline(deadline1)
runtime.Gosched()
b := make([]byte, 10)
_, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(errDeadline))
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond)))
})
})
Context("closing", func() {
Context("with FIN bit", func() {
It("returns EOFs", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().HasWindowUpdate()
str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
FinBit: true,
})
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
n, err = strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
It("handles out-of-order frames", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
FinBit: true,
}
frame2 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
n, err = strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
It("returns EOFs with partial read", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().HasWindowUpdate()
err := str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{0xde, 0xad},
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(2))
Expect(b[:n]).To(Equal([]byte{0xde, 0xad}))
})
It("handles immediate FINs", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
err := str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
})
It("closes when CloseRemote is called", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
str.CloseRemote(0)
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 8)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
})
Context("closing for shutdown", func() {
testErr := errors.New("test error")
It("immediately returns all reads", func() {
done := make(chan struct{})
b := make([]byte, 4)
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
str.closeForShutdown(testErr)
Eventually(done).Should(BeClosed())
})
It("errors for all following reads", func() {
str.closeForShutdown(testErr)
b := make([]byte, 1)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
})
})
Context("stream cancelations", func() {
Context("canceling read", func() {
It("unblocks Read", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
err := str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
})
It("doesn't allow further calls to Read", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
err := str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234"))
})
It("does nothing when CancelRead is called twice", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
err := str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
err = str.CancelRead(2345)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234"))
})
It("doesn't send a RST_STREAM frame, if the FIN was already read", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
mockFC.EXPECT().HasWindowUpdate()
// no calls to mockSender.queueControlFrame
err := str.handleStreamFrame(&wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
_, err = strWithTimeout.Read(make([]byte, 100))
Expect(err).To(MatchError(io.EOF))
err = str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
})
It("queues a STOP_SENDING frame, for IETF QUIC", func() {
mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 1234,
})
err := str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't queue a STOP_SENDING frame, for gQUIC", func() {
})
})
Context("receiving RST_STREAM frames", func() {
rst := &wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 42,
ErrorCode: 1234,
}
It("unblocks Read", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Stream 1337 was reset with error code 1234"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234)))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleRstStreamFrame(rst)
Eventually(done).Should(BeClosed())
})
It("doesn't allow further calls to Read", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Stream 1337 was reset with error code 1234"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234)))
})
It("errors when receiving a RST_STREAM with an inconsistent offset", func() {
testErr := errors.New("already received a different final offset before")
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Return(testErr)
err := str.handleRstStreamFrame(rst)
Expect(err).To(MatchError(testErr))
})
It("ignores duplicate RST_STREAM frames", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
err = str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't do anyting when it was closed for shutdown", func() {
str.closeForShutdown(nil)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
})
Context("for gQUIC", func() {
BeforeEach(func() {
str.version = versionGQUICFrames
})
It("unblocks Read when receiving a RST_STREAM frame with non-zero error code", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
readReturned := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Stream 1337 was reset with error code 1234"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234)))
close(readReturned)
}()
Consistently(readReturned).ShouldNot(BeClosed())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
Eventually(readReturned).Should(BeClosed())
})
It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2)
gomock.InOrder(
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)),
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
mockSender.EXPECT().onStreamCompleted(streamID),
)
mockFC.EXPECT().HasWindowUpdate().Times(2)
readReturned := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Read(make([]byte, 4))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
n, err = strWithTimeout.Read(make([]byte, 4))
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(2))
close(readReturned)
}()
Consistently(readReturned).ShouldNot(BeClosed())
err := str.handleStreamFrame(&wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
err = str.handleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 6,
ErrorCode: 0,
})
Expect(err).ToNot(HaveOccurred())
Eventually(readReturned).Should(BeClosed())
})
})
})
})
Context("flow control", func() {
It("errors when a STREAM frame causes a flow control violation", func() {
testErr := errors.New("flow control violation")
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr)
frame := wire.StreamFrame{
Offset: 2,
Data: []byte("foobar"),
}
err := str.handleStreamFrame(&frame)
Expect(err).To(MatchError(testErr))
})
It("gets a window update", func() {
mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x100))
Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100)))
})
})
})

313
vendor/github.com/lucas-clemente/quic-go/send_stream.go generated vendored Normal file
View File

@ -0,0 +1,313 @@
package quic
import (
"context"
"fmt"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type sendStreamI interface {
SendStream
handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type sendStream struct {
mutex sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
streamID protocol.StreamID
sender streamSender
writeOffset protocol.ByteCount
cancelWriteErr error
closeForShutdownErr error
closedForShutdown bool // set when CloseForShutdown() is called
finishedWriting bool // set once Close() is called
canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
finSent bool // set when a STREAM_FRAME with FIN bit has b
dataForWriting []byte
writeChan chan struct{}
writeDeadline time.Time
flowController flowcontrol.StreamFlowController
version protocol.VersionNumber
}
var _ SendStream = &sendStream{}
var _ sendStreamI = &sendStream{}
func newSendStream(
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
version protocol.VersionNumber,
) *sendStream {
s := &sendStream{
streamID: streamID,
sender: sender,
flowController: flowController,
writeChan: make(chan struct{}, 1),
version: version,
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
}
func (s *sendStream) StreamID() protocol.StreamID {
return s.streamID // same for receiveStream and sendStream
}
func (s *sendStream) Write(p []byte) (int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.finishedWriting {
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if s.canceledWrite {
return 0, s.cancelWriteErr
}
if s.closeForShutdownErr != nil {
return 0, s.closeForShutdownErr
}
if !s.writeDeadline.IsZero() && !time.Now().Before(s.writeDeadline) {
return 0, errDeadline
}
if len(p) == 0 {
return 0, nil
}
s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p)
s.sender.onHasStreamData(s.streamID)
var bytesWritten int
var err error
for {
bytesWritten = len(p) - len(s.dataForWriting)
deadline := s.writeDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
s.dataForWriting = nil
err = errDeadline
break
}
if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.writeChan
} else {
select {
case <-s.writeChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
}
if s.closeForShutdownErr != nil {
err = s.closeForShutdownErr
} else if s.cancelWriteErr != nil {
err = s.cancelWriteErr
}
return bytesWritten, err
}
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
// maxBytes is the maximum length this frame (including frame header) will have.
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closeForShutdownErr != nil {
return nil, false
}
frame := &wire.StreamFrame{
StreamID: s.streamID,
Offset: s.writeOffset,
DataLenPresent: true,
}
frameLen := frame.MinLength(s.version)
if frameLen >= maxBytes { // a STREAM frame must have at least one byte of data
return nil, s.dataForWriting != nil
}
frame.Data, frame.FinBit = s.getDataForWriting(maxBytes - frameLen)
if len(frame.Data) == 0 && !frame.FinBit {
// this can happen if:
// - popStreamFrame is called but there's no data for writing
// - there's data for writing, but the stream is stream-level flow control blocked
// - there's data for writing, but the stream is connection-level flow control blocked
if s.dataForWriting == nil {
return nil, false
}
isBlocked, _ := s.flowController.IsBlocked()
return nil, !isBlocked
}
if frame.FinBit {
s.finSent = true
s.sender.onStreamCompleted(s.streamID)
} else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream
if isBlocked, offset := s.flowController.IsBlocked(); isBlocked {
s.sender.queueControlFrame(&wire.StreamBlockedFrame{
StreamID: s.streamID,
Offset: offset,
})
return frame, false
}
}
return frame, s.dataForWriting != nil
}
func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
if s.dataForWriting == nil {
return nil, s.finishedWriting && !s.finSent
}
// TODO(#657): Flow control for the crypto stream
if s.streamID != s.version.CryptoStreamID() {
maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize())
}
if maxBytes == 0 {
return nil, false
}
var ret []byte
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
ret = s.dataForWriting[:maxBytes]
s.dataForWriting = s.dataForWriting[maxBytes:]
} else {
ret = s.dataForWriting
s.dataForWriting = nil
s.signalWrite()
}
s.writeOffset += protocol.ByteCount(len(ret))
s.flowController.AddBytesSent(protocol.ByteCount(len(ret)))
return ret, s.finishedWriting && s.dataForWriting == nil && !s.finSent
}
func (s *sendStream) Close() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.canceledWrite {
return fmt.Errorf("Close called for canceled stream %d", s.streamID)
}
s.finishedWriting = true
s.sender.onHasStreamData(s.streamID) // need to send the FIN
s.ctxCancel()
return nil
}
func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) error {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode))
}
// must be called after locking the mutex
func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) error {
if s.canceledWrite {
return nil
}
if s.finishedWriting {
return fmt.Errorf("CancelWrite for closed stream %d", s.streamID)
}
s.canceledWrite = true
s.cancelWriteErr = writeErr
s.signalWrite()
s.sender.queueControlFrame(&wire.RstStreamFrame{
StreamID: s.streamID,
ByteOffset: s.writeOffset,
ErrorCode: errorCode,
})
// TODO(#991): cancel retransmissions for this stream
s.ctxCancel()
s.sender.onStreamCompleted(s.streamID)
return nil
}
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.handleStopSendingFrameImpl(frame)
}
func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) {
s.flowController.UpdateSendWindow(frame.ByteOffset)
s.mutex.Lock()
if s.dataForWriting != nil {
s.sender.onHasStreamData(s.streamID)
}
s.mutex.Unlock()
}
// must be called after locking the mutex
func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) {
writeErr := streamCanceledError{
errorCode: frame.ErrorCode,
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
}
errorCode := errorCodeStopping
if !s.version.UsesIETFFrameFormat() {
errorCode = errorCodeStoppingGQUIC
}
s.cancelWriteImpl(errorCode, writeErr)
}
func (s *sendStream) Context() context.Context {
return s.ctx
}
func (s *sendStream) SetWriteDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.writeDeadline
s.writeDeadline = t
s.mutex.Unlock()
if t.Before(oldDeadline) {
s.signalWrite()
}
return nil
}
// CloseForShutdown closes a stream abruptly.
// It makes Write unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock()
s.closedForShutdown = true
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalWrite()
s.ctxCancel()
}
func (s *sendStream) getWriteOffset() protocol.ByteCount {
return s.writeOffset
}
// signalWrite performs a non-blocking send on the writeChan
func (s *sendStream) signalWrite() {
select {
case s.writeChan <- struct{}{}:
default:
}
}

View File

@ -0,0 +1,636 @@
package quic
import (
"bytes"
"errors"
"io"
"runtime"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
var _ = Describe("Send Stream", func() {
const streamID protocol.StreamID = 1337
var (
str *sendStream
strWithTimeout io.Writer // str wrapped with gbytes.TimeoutWriter
mockFC *mocks.MockStreamFlowController
mockSender *MockStreamSender
)
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
str = newSendStream(streamID, mockSender, mockFC, protocol.VersionWhatever)
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = gbytes.TimeoutWriter(str, timeout)
})
waitForWrite := func() {
EventuallyWithOffset(0, func() []byte {
str.mutex.Lock()
data := str.dataForWriting
str.mutex.Unlock()
return data
}).ShouldNot(BeEmpty())
}
It("gets stream id", func() {
Expect(str.StreamID()).To(Equal(protocol.StreamID(1337)))
})
Context("writing", func() {
It("writes and gets all data at once", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
mockFC.EXPECT().IsBlocked()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
close(done)
}()
waitForWrite()
f, _ := str.popStreamFrame(1000)
Expect(f.Data).To(Equal([]byte("foobar")))
Expect(f.FinBit).To(BeFalse())
Expect(f.Offset).To(BeZero())
Expect(f.DataLenPresent).To(BeTrue())
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
Expect(str.dataForWriting).To(BeNil())
Eventually(done).Should(BeClosed())
})
It("writes and gets data in two turns", func() {
mockSender.EXPECT().onHasStreamData(streamID)
frameHeaderLen := protocol.ByteCount(4)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
mockFC.EXPECT().AddBytesSent(gomock.Any() /* protocol.ByteCount(3)*/).Times(2)
mockFC.EXPECT().IsBlocked().Times(2)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
close(done)
}()
waitForWrite()
f, _ := str.popStreamFrame(3 + frameHeaderLen)
Expect(f.Data).To(Equal([]byte("foo")))
Expect(f.FinBit).To(BeFalse())
Expect(f.Offset).To(BeZero())
Expect(f.DataLenPresent).To(BeTrue())
f, _ = str.popStreamFrame(100)
Expect(f.Data).To(Equal([]byte("bar")))
Expect(f.FinBit).To(BeFalse())
Expect(f.Offset).To(Equal(protocol.ByteCount(3)))
Expect(f.DataLenPresent).To(BeTrue())
Expect(str.popStreamFrame(1000)).To(BeNil())
Eventually(done).Should(BeClosed())
})
It("popStreamFrame returns nil if no data is available", func() {
frame, hasMoreData := str.popStreamFrame(1000)
Expect(frame).To(BeNil())
Expect(hasMoreData).To(BeFalse())
})
It("says if it has more data for writing", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2)
mockFC.EXPECT().IsBlocked().Times(2)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(100))
close(done)
}()
waitForWrite()
frame, hasMoreData := str.popStreamFrame(50)
Expect(hasMoreData).To(BeTrue())
frame, hasMoreData = str.popStreamFrame(1000)
Expect(frame).ToNot(BeNil())
Expect(hasMoreData).To(BeFalse())
frame, _ = str.popStreamFrame(1000)
Expect(frame).To(BeNil())
Eventually(done).Should(BeClosed())
})
It("copies the slice while writing", func() {
mockSender.EXPECT().onHasStreamData(streamID)
frameHeaderSize := protocol.ByteCount(4)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2))
mockFC.EXPECT().IsBlocked().Times(2)
s := []byte("foo")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write(s)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
close(done)
}()
waitForWrite()
frame, _ := str.popStreamFrame(frameHeaderSize + 1)
Expect(frame.Data).To(Equal([]byte("f")))
s[1] = 'e'
f, _ := str.popStreamFrame(100)
Expect(f).ToNot(BeNil())
Expect(f.Data).To(Equal([]byte("oo")))
Eventually(done).Should(BeClosed())
})
It("returns when given a nil input", func() {
n, err := strWithTimeout.Write(nil)
Expect(n).To(BeZero())
Expect(err).ToNot(HaveOccurred())
})
It("returns when given an empty slice", func() {
n, err := strWithTimeout.Write([]byte(""))
Expect(n).To(BeZero())
Expect(err).ToNot(HaveOccurred())
})
It("cancels the context when Close is called", func() {
mockSender.EXPECT().onHasStreamData(streamID)
Expect(str.Context().Done()).ToNot(BeClosed())
str.Close()
Expect(str.Context().Done()).To(BeClosed())
})
Context("flow control blocking", func() {
It("returns nil when it is blocked", func() {
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0))
mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10))
mockSender.EXPECT().onHasStreamData(streamID)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).To(BeNil())
Expect(hasMoreData).To(BeFalse())
// make the Write go routine return
str.closeForShutdown(nil)
Eventually(done).Should(BeClosed())
})
It("queues a BLOCKED frame if the stream is flow control blocked", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().queueControlFrame(&wire.StreamBlockedFrame{
StreamID: streamID,
Offset: 10,
})
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
// don't use offset 6 here, to make sure the BLOCKED frame contains the number returned by the flow controller
mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
Expect(hasMoreData).To(BeFalse())
Eventually(done).Should(BeClosed())
})
It("says that it doesn't have any more data, when it is flow control blocked", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
f, hasMoreData := str.popStreamFrame(50)
Expect(f).ToNot(BeNil())
Expect(hasMoreData).To(BeFalse())
// make the Write go routine return
str.closeForShutdown(nil)
Eventually(done).Should(BeClosed())
})
It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() {
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
// don't EXPECT a call to mockFC.IsBlocked
// don't EXPECT a call to mockSender.queueControlFrame
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
Expect(str.Close()).To(Succeed())
f, hasMoreData := str.popStreamFrame(1000)
Expect(hasMoreData).To(BeFalse())
Expect(f).ToNot(BeNil())
Expect(f.FinBit).To(BeTrue())
Eventually(done).Should(BeClosed())
})
})
Context("deadlines", func() {
It("returns an error when Write is called after the deadline", func() {
str.SetWriteDeadline(time.Now().Add(-time.Second))
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
})
It("unblocks after the deadline", func() {
mockSender.EXPECT().onHasStreamData(streamID)
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetWriteDeadline(deadline)
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond)))
})
It("returns the number of bytes written, when the deadline expires", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes()
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked()
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetWriteDeadline(deadline)
var n int
writeReturned := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).To(MatchError(errDeadline))
Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond)))
close(writeReturned)
}()
waitForWrite()
frame, hasMoreData := str.popStreamFrame(50)
Expect(frame).ToNot(BeNil())
Expect(hasMoreData).To(BeTrue())
Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed())
Expect(n).To(BeEquivalentTo(frame.DataLen()))
})
It("doesn't pop any data after the deadline expired", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes()
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked()
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetWriteDeadline(deadline)
writeReturned := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).To(MatchError(errDeadline))
close(writeReturned)
}()
waitForWrite()
frame, hasMoreData := str.popStreamFrame(50)
Expect(frame).ToNot(BeNil())
Expect(hasMoreData).To(BeTrue())
Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed())
frame, hasMoreData = str.popStreamFrame(50)
Expect(frame).To(BeNil())
Expect(hasMoreData).To(BeFalse())
})
It("doesn't unblock if the deadline is changed before the first one expires", func() {
mockSender.EXPECT().onHasStreamData(streamID)
deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond))
str.SetWriteDeadline(deadline1)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
time.Sleep(scaleDuration(20 * time.Millisecond))
str.SetWriteDeadline(deadline2)
// make sure that this was actually execute before the deadline expires
Expect(time.Now()).To(BeTemporally("<", deadline1))
close(done)
}()
runtime.Gosched()
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond)))
Eventually(done).Should(BeClosed())
})
It("unblocks earlier, when a new deadline is set", func() {
mockSender.EXPECT().onHasStreamData(streamID)
deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
time.Sleep(scaleDuration(10 * time.Millisecond))
str.SetWriteDeadline(deadline2)
// make sure that this was actually execute before the deadline expires
Expect(time.Now()).To(BeTemporally("<", deadline2))
close(done)
}()
str.SetWriteDeadline(deadline1)
runtime.Gosched()
_, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError(errDeadline))
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond)))
Eventually(done).Should(BeClosed())
})
})
Context("closing", func() {
It("doesn't allow writes after it has been closed", func() {
mockSender.EXPECT().onHasStreamData(streamID)
str.Close()
_, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError("write on closed stream 1337"))
})
It("allows FIN", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
str.Close()
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
Expect(f.Data).To(BeEmpty())
Expect(f.FinBit).To(BeTrue())
Expect(hasMoreData).To(BeFalse())
})
It("doesn't send a FIN when there's still data", func() {
mockSender.EXPECT().onHasStreamData(streamID)
frameHeaderLen := protocol.ByteCount(4)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2)
mockFC.EXPECT().IsBlocked()
str.dataForWriting = []byte("foobar")
Expect(str.Close()).To(Succeed())
f, _ := str.popStreamFrame(3 + frameHeaderLen)
Expect(f).ToNot(BeNil())
Expect(f.Data).To(Equal([]byte("foo")))
Expect(f.FinBit).To(BeFalse())
mockSender.EXPECT().onStreamCompleted(streamID)
f, _ = str.popStreamFrame(100)
Expect(f.Data).To(Equal([]byte("bar")))
Expect(f.FinBit).To(BeTrue())
})
It("doesn't allow FIN after it is closed for shutdown", func() {
str.closeForShutdown(errors.New("test"))
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).To(BeNil())
Expect(hasMoreData).To(BeFalse())
})
It("doesn't allow FIN twice", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
str.Close()
f, _ := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
Expect(f.Data).To(BeEmpty())
Expect(f.FinBit).To(BeTrue())
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).To(BeNil())
Expect(hasMoreData).To(BeFalse())
})
})
Context("closing for shutdown", func() {
testErr := errors.New("test")
It("returns errors when the stream is cancelled", func() {
str.closeForShutdown(testErr)
n, err := strWithTimeout.Write([]byte("foo"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
It("doesn't get data for writing if an error occurred", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 500))
Expect(err).To(MatchError(testErr))
close(done)
}()
waitForWrite()
frame, hasMoreData := str.popStreamFrame(50) // get a STREAM frame containing some data, but not all
Expect(frame).ToNot(BeNil())
Expect(hasMoreData).To(BeTrue())
str.closeForShutdown(testErr)
frame, hasMoreData = str.popStreamFrame(1000)
Expect(frame).To(BeNil())
Expect(hasMoreData).To(BeFalse())
Eventually(done).Should(BeClosed())
})
It("cancels the context", func() {
Expect(str.Context().Done()).ToNot(BeClosed())
str.closeForShutdown(testErr)
Expect(str.Context().Done()).To(BeClosed())
})
})
})
Context("handling MAX_STREAM_DATA frames", func() {
It("informs the flow controller", func() {
mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337))
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: streamID,
ByteOffset: 0x1337,
})
})
It("says when it has data for sending", func() {
mockFC.EXPECT().UpdateSendWindow(gomock.Any())
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for Write, once for the MAX_STREAM_DATA frame
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: streamID,
ByteOffset: 42,
})
// make sure the Write go routine returns
str.closeForShutdown(nil)
Eventually(done).Should(BeClosed())
})
})
Context("stream cancelations", func() {
Context("canceling writing", func() {
It("queues a RST_STREAM frame", func() {
mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 1234,
ErrorCode: 9876,
})
mockSender.EXPECT().onStreamCompleted(streamID)
str.writeOffset = 1234
err := str.CancelWrite(9876)
Expect(err).ToNot(HaveOccurred())
})
It("unblocks Write", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked()
writeReturned := make(chan struct{})
var n int
go func() {
defer GinkgoRecover()
var err error
n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234"))
close(writeReturned)
}()
waitForWrite()
frame, _ := str.popStreamFrame(50)
Expect(frame).ToNot(BeNil())
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
Eventually(writeReturned).Should(BeClosed())
Expect(n).To(BeEquivalentTo(frame.DataLen()))
})
It("cancels the context", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.Context().Done()).ToNot(BeClosed())
str.CancelWrite(1234)
Expect(str.Context().Done()).To(BeClosed())
})
It("doesn't allow further calls to Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234"))
})
It("only cancels once", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
err = str.CancelWrite(4321)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't cancel when the stream was already closed", func() {
mockSender.EXPECT().onHasStreamData(streamID)
err := str.Close()
Expect(err).ToNot(HaveOccurred())
err = str.CancelWrite(123)
Expect(err).To(MatchError("CancelWrite for closed stream 1337"))
})
})
Context("receiving STOP_SENDING frames", func() {
It("queues a RST_STREAM frames with error code Stopping", func() {
mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{
StreamID: streamID,
ErrorCode: errorCodeStopping,
})
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 101,
})
})
It("unblocks Write", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().queueControlFrame(gomock.Any())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).To(MatchError("Stream 1337 was reset with error code 123"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123)))
close(done)
}()
waitForWrite()
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
})
Eventually(done).Should(BeClosed())
})
It("doesn't allow further calls to Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
})
_, err := str.Write([]byte("foobar"))
Expect(err).To(MatchError("Stream 1337 was reset with error code 123"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123)))
})
})
})
})

View File

@ -19,8 +19,8 @@ import (
// packetHandler handles packets // packetHandler handles packets
type packetHandler interface { type packetHandler interface {
Session Session
getCryptoStream() cryptoStream getCryptoStream() cryptoStreamI
handshakeStatus() <-chan handshakeEvent handshakeStatus() <-chan error
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber GetVersion() protocol.VersionNumber
run() error run() error
@ -40,15 +40,17 @@ type server struct {
certChain crypto.CertChain certChain crypto.CertChain
scfg *handshake.ServerConfig scfg *handshake.ServerConfig
sessions map[protocol.ConnectionID]packetHandler
sessionsMutex sync.RWMutex sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration sessions map[protocol.ConnectionID]packetHandler
closed bool
serverError error serverError error
sessionQueue chan Session sessionQueue chan Session
errorChan chan struct{} errorChan chan struct{}
// set as members, so they can be set in the tests
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error) newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error)
deleteClosedSessionsAfter time.Duration
} }
var _ Listener = &server{} var _ Listener = &server{}
@ -240,6 +242,12 @@ func (s *server) Accept() (Session, error) {
// Close the server // Close the server
func (s *server) Close() error { func (s *server) Close() error {
s.sessionsMutex.Lock() s.sessionsMutex.Lock()
if s.closed {
s.sessionsMutex.Unlock()
return nil
}
s.closed = true
var wg sync.WaitGroup var wg sync.WaitGroup
for _, session := range s.sessions { for _, session := range s.sessions {
if session != nil { if session != nil {
@ -254,10 +262,9 @@ func (s *server) Close() error {
s.sessionsMutex.Unlock() s.sessionsMutex.Unlock()
wg.Wait() wg.Wait()
if s.conn == nil { err := s.conn.Close()
return nil <-s.errorChan // wait for serve() to return
} return err
return s.conn.Close()
} }
// Addr returns the server's network address // Addr returns the server's network address
@ -384,15 +391,9 @@ func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.C
}() }()
go func() { go func() {
for { if err := <-session.handshakeStatus(); err != nil {
ev := <-session.handshakeStatus()
if ev.err != nil {
return return
} }
if ev.encLevel == protocol.EncryptionForwardSecure {
break
}
}
s.sessionQueue <- session s.sessionQueue <- session
}() }()
} }

View File

@ -28,8 +28,7 @@ type mockSession struct {
closeReason error closeReason error
closedRemote bool closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan handshakeEvent handshakeChan chan error
handshakeComplete chan error // for WaitUntilHandshakeComplete
} }
func (s *mockSession) handlePacket(*receivedPacket) { func (s *mockSession) handlePacket(*receivedPacket) {
@ -40,9 +39,6 @@ func (s *mockSession) run() error {
<-s.stopRunLoop <-s.stopRunLoop
return s.closeReason return s.closeReason
} }
func (s *mockSession) WaitUntilHandshakeComplete() error {
return <-s.handshakeComplete
}
func (s *mockSession) Close(e error) error { func (s *mockSession) Close(e error) error {
if s.closed { if s.closed {
return nil return nil
@ -59,19 +55,19 @@ func (s *mockSession) closeRemote(e error) {
close(s.stopRunLoop) close(s.stopRunLoop)
} }
func (s *mockSession) OpenStream() (Stream, error) { func (s *mockSession) OpenStream() (Stream, error) {
return &stream{streamID: 1337}, nil return &stream{}, nil
} }
func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") }
func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") }
func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") }
func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") }
func (*mockSession) Context() context.Context { panic("not implemented") } func (*mockSession) Context() context.Context { panic("not implemented") }
func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") }
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") } func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
var _ Session = &mockSession{} var _ Session = &mockSession{}
var _ NonFWSession = &mockSession{}
func newMockSession( func newMockSession(
_ connection, _ connection,
@ -83,8 +79,7 @@ func newMockSession(
) (packetHandler, error) { ) (packetHandler, error) {
s := mockSession{ s := mockSession{
connectionID: connectionID, connectionID: connectionID,
handshakeChan: make(chan handshakeEvent), handshakeChan: make(chan error),
handshakeComplete: make(chan error),
stopRunLoop: make(chan struct{}), stopRunLoop: make(chan struct{}),
} }
return &s, nil return &s, nil
@ -155,9 +150,8 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[connID].(*mockSession) sess := serv.sessions[connID].(*mockSession)
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Consistently(func() Session { return acceptedSess }).Should(BeNil()) Consistently(func() Session { return acceptedSess }).Should(BeNil())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionForwardSecure} close(sess.handshakeChan)
Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
close(done) close(done)
}, 0.5) }, 0.5)
@ -173,7 +167,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[connID].(*mockSession) sess := serv.sessions[connID].(*mockSession)
sess.handshakeChan <- handshakeEvent{err: errors.New("handshake failed")} sess.handshakeChan <- errors.New("handshake failed")
Consistently(func() bool { return accepted }).Should(BeFalse()) Consistently(func() bool { return accepted }).Should(BeFalse())
close(done) close(done)
}) })
@ -222,6 +216,7 @@ var _ = Describe("Server", func() {
}) })
It("closes sessions and the connection when Close is called", func() { It("closes sessions and the connection when Close is called", func() {
go serv.serve()
session, _ := newMockSession(nil, 0, 0, nil, nil, nil) session, _ := newMockSession(nil, 0, 0, nil, nil, nil)
serv.sessions[1] = session serv.sessions[1] = session
err := serv.Close() err := serv.Close()

View File

@ -12,6 +12,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
) )
type nullAEAD struct { type nullAEAD struct {
@ -98,6 +99,26 @@ func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.V
return tls, extHandler.GetPeerParams(), nil return tls, extHandler.GetPeerParams(), nil
} }
func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error {
ccf := &wire.ConnectionCloseFrame{
ErrorCode: qerr.HandshakeFailed,
ReasonPhrase: closeErr.Error(),
}
replyHdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
ConnectionID: clientHdr.ConnectionID, // echo the client's connection ID
PacketNumber: 1, // random packet number
Version: clientHdr.Version,
}
data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer)
if err != nil {
return err
}
_, err = s.conn.WriteTo(data, remoteAddr)
return err
}
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) { func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) {
if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize { if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize {
return nil, errors.New("dropping too small Initial packet") return nil, errors.New("dropping too small Initial packet")
@ -110,19 +131,30 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
} }
// unpack packet and check stream frame contents // unpack packet and check stream frame contents
version := hdr.Version aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, hdr.Version)
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
frame, err := unpackInitialPacket(aead, hdr, data, version) frame, err := unpackInitialPacket(aead, hdr, data, hdr.Version)
if err != nil { if err != nil {
utils.Debugf("Error unpacking initial packet: %s", err) utils.Debugf("Error unpacking initial packet: %s", err)
return nil, nil return nil, nil
} }
sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
if err != nil {
if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil {
utils.Debugf("Error sending CONNECTION_CLOSE: ", ccerr)
}
return nil, err
}
return sess, nil
}
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) {
version := hdr.Version
bc := handshake.NewCryptoStreamConn(remoteAddr) bc := handshake.NewCryptoStreamConn(remoteAddr)
bc.AddDataForReading(frame.Data) bc.AddDataForReading(frame.Data)
tls, paramsChan, err := s.newMintConn(bc, hdr.Version) tls, paramsChan, err := s.newMintConn(bc, version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -176,7 +208,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
return nil, err return nil, err
} }
cs := sess.getCryptoStream() cs := sess.getCryptoStream()
cs.SetReadOffset(frame.DataLen()) cs.setReadOffset(frame.DataLen())
bc.SetStream(cs) bc.SetStream(cs)
return sess, nil return sess, nil
} }

View File

@ -4,15 +4,16 @@ import (
"bytes" "bytes"
"io" "io"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -65,6 +66,18 @@ var _ = Describe("Stateless TLS handling", func() {
return hdr, data return hdr, data
} }
unpackPacket := func(data []byte) (*wire.Header, []byte) {
r := bytes.NewReader(conn.dataWritten.Bytes())
hdr, err := wire.ParseHeaderSentByServer(r, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
hdr.Raw = data[:len(data)-r.Len()]
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.ConnectionID, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
payload, err := aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw)
Expect(err).ToNot(HaveOccurred())
return hdr, payload
}
It("sends a version negotiation packet if it doesn't support the version", func() { It("sends a version negotiation packet if it doesn't support the version", func() {
server.HandleInitial(nil, &wire.Header{Version: 0x1337}, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) server.HandleInitial(nil, &wire.Header{Version: 0x1337}, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize))
Expect(conn.dataWritten.Len()).ToNot(BeZero()) Expect(conn.dataWritten.Len()).ToNot(BeZero())
@ -124,4 +137,20 @@ var _ = Describe("Stateless TLS handling", func() {
Eventually(sessionChan).Should(Receive()) Eventually(sessionChan).Should(Receive())
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("sends a CONNECTION_CLOSE, if mint returns an error", func() {
mintTLS.EXPECT().Handshake().Return(mint.AlertAccessDenied)
extHandler.EXPECT().GetPeerParams()
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
server.HandleInitial(nil, hdr, data)
// the Handshake packet is written by the session
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
// unpack the packet to check that it actually contains a CONNECTION_CLOSE
hdr, data = unpackPacket(conn.dataWritten.Bytes())
Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake))
ccf, err := wire.ParseConnectionCloseFrame(bytes.NewReader(data), protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
Expect(ccf.ErrorCode).To(Equal(qerr.HandshakeFailed))
Expect(ccf.ReasonPhrase).To(Equal(mint.AlertAccessDenied.String()))
})
}) })

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"time" "time"
@ -24,6 +23,23 @@ type unpacker interface {
Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error)
} }
type streamGetter interface {
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
}
type streamManager interface {
GetOrOpenStream(protocol.StreamID) (streamI, error)
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
OpenStream() (Stream, error)
OpenStreamSync() (Stream, error)
AcceptStream() (Stream, error)
DeleteStream(protocol.StreamID) error
UpdateLimits(*handshake.TransportParameters)
CloseWithError(error)
}
type receivedPacket struct { type receivedPacket struct {
remoteAddr net.Addr remoteAddr net.Addr
header *wire.Header header *wire.Header
@ -36,11 +52,6 @@ var (
newCryptoSetupClient = handshake.NewCryptoSetupClient newCryptoSetupClient = handshake.NewCryptoSetupClient
) )
type handshakeEvent struct {
encLevel protocol.EncryptionLevel
err error
}
type closeError struct { type closeError struct {
err error err error
remote bool remote bool
@ -55,15 +66,15 @@ type session struct {
conn connection conn connection
streamsMap *streamsMap streamsMap streamManager
cryptoStream cryptoStream cryptoStream cryptoStreamI
rttStats *congestion.RTTStats rttStats *congestion.RTTStats
sentPacketHandler ackhandler.SentPacketHandler sentPacketHandler ackhandler.SentPacketHandler
receivedPacketHandler ackhandler.ReceivedPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler
streamFramer *streamFramer streamFramer *streamFramer
windowUpdateQueue *windowUpdateQueue
connFlowController flowcontrol.ConnectionFlowController connFlowController flowcontrol.ConnectionFlowController
unpacker unpacker unpacker unpacker
@ -87,17 +98,14 @@ type session struct {
// this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them // this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them
paramsChan <-chan handshake.TransportParameters paramsChan <-chan handshake.TransportParameters
// this channel is passed to the CryptoSetup and receives the current encryption level // the handshakeEvent channel is passed to the CryptoSetup.
// it is closed as soon as the handshake is complete // It receives when it makes sense to try decrypting undecryptable packets.
aeadChanged <-chan protocol.EncryptionLevel handshakeEvent <-chan struct{}
// handshakeChan is returned by handshakeStatus.
// It receives any error that might occur during the handshake.
// It is closed when the handshake is complete.
handshakeChan chan error
handshakeComplete bool handshakeComplete bool
// will be closed as soon as the handshake completes, and receive any error that might occur until then
// it is used to block WaitUntilHandshakeComplete()
handshakeCompleteChan chan error
// handshakeChan receives handshake events and is closed as soon the handshake completes
// the receiving end of this channel is passed to the creator of the session
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error
handshakeChan chan handshakeEvent
lastRcvdPacketNumber protocol.PacketNumber lastRcvdPacketNumber protocol.PacketNumber
// Used to calculate the next packet number from the truncated wire // Used to calculate the next packet number from the truncated wire
@ -116,6 +124,7 @@ type session struct {
} }
var _ Session = &session{} var _ Session = &session{}
var _ streamSender = &session{}
// newSession makes a new session // newSession makes a new session
func newSession( func newSession(
@ -127,14 +136,14 @@ func newSession(
config *Config, config *Config,
) (packetHandler, error) { ) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters) paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
config: config, config: config,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
paramsChan: paramsChan, paramsChan: paramsChan,
} }
s.preSetup() s.preSetup()
@ -154,7 +163,7 @@ func newSession(
s.config.Versions, s.config.Versions,
s.config.AcceptCookie, s.config.AcceptCookie,
paramsChan, paramsChan,
aeadChanged, handshakeEvent,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -175,14 +184,14 @@ var newClientSession = func(
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
) (packetHandler, error) { ) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters) paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
config: config, config: config,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
paramsChan: paramsChan, paramsChan: paramsChan,
} }
s.preSetup() s.preSetup()
@ -201,7 +210,7 @@ var newClientSession = func(
tlsConf, tlsConf,
transportParams, transportParams,
paramsChan, paramsChan,
aeadChanged, handshakeEvent,
initialVersion, initialVersion,
negotiatedVersions, negotiatedVersions,
) )
@ -223,21 +232,21 @@ func newTLSServerSession(
peerParams *handshake.TransportParameters, peerParams *handshake.TransportParameters,
v protocol.VersionNumber, v protocol.VersionNumber,
) (packetHandler, error) { ) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
config: config, config: config,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
} }
s.preSetup() s.preSetup()
s.cryptoSetup = handshake.NewCryptoSetupTLSServer( s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
tls, tls,
cryptoStreamConn, cryptoStreamConn,
nullAEAD, nullAEAD,
aeadChanged, handshakeEvent,
v, v,
) )
if err := s.postSetup(initialPacketNumber); err != nil { if err := s.postSetup(initialPacketNumber); err != nil {
@ -260,14 +269,14 @@ var newTLSClientSession = func(
paramsChan <-chan handshake.TransportParameters, paramsChan <-chan handshake.TransportParameters,
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
) (packetHandler, error) { ) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
config: config, config: config,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
paramsChan: paramsChan, paramsChan: paramsChan,
} }
s.preSetup() s.preSetup()
@ -276,7 +285,7 @@ var newTLSClientSession = func(
s.cryptoStream, s.cryptoStream,
s.connectionID, s.connectionID,
hostname, hostname,
aeadChanged, handshakeEvent,
tls, tls,
v, v,
) )
@ -294,12 +303,11 @@ func (s *session) preSetup() {
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats, s.rttStats,
) )
s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream) s.cryptoStream = s.newCryptoStream()
} }
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.handshakeChan = make(chan handshakeEvent, 3) s.handshakeChan = make(chan error, 1)
s.handshakeCompleteChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
@ -314,9 +322,12 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) if s.version.UsesTLS() {
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController, s.version) s.streamsMap = newStreamsMap(s.newStream, s.perspective)
} else {
s.streamsMap = newStreamsMapLegacy(s.newStream, s.perspective)
}
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
s.packer = newPacketPacker(s.connectionID, s.packer = newPacketPacker(s.connectionID,
initialPacketNumber, initialPacketNumber,
s.cryptoSetup, s.cryptoSetup,
@ -324,6 +335,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.perspective, s.perspective,
s.version, s.version,
) )
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return nil return nil
} }
@ -339,7 +351,7 @@ func (s *session) run() error {
}() }()
var closeErr closeError var closeErr closeError
aeadChanged := s.aeadChanged handshakeEvent := s.handshakeEvent
runLoop: runLoop:
for { for {
@ -377,16 +389,20 @@ runLoop:
putPacketBuffer(p.header.Raw) putPacketBuffer(p.header.Raw)
case p := <-s.paramsChan: case p := <-s.paramsChan:
s.processTransportParameters(&p) s.processTransportParameters(&p)
case l, ok := <-aeadChanged: case _, ok := <-handshakeEvent:
if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
s.handshakeComplete = true s.handshakeComplete = true
aeadChanged = nil // prevent this case from ever being selected again handshakeEvent = nil // prevent this case from ever being selected again
s.sentPacketHandler.SetHandshakeComplete() s.sentPacketHandler.SetHandshakeComplete()
if !s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient {
// In gQUIC, there's no equivalent to the Finished message in TLS
// The server knows that the handshake is complete when it receives the first forward-secure packet sent by the client.
// We need to make sure that the client actually sends such a packet.
s.packer.QueueControlFrame(&wire.PingFrame{})
}
close(s.handshakeChan) close(s.handshakeChan)
close(s.handshakeCompleteChan)
} else { } else {
s.tryDecryptingQueuedPackets() s.tryDecryptingQueuedPackets()
s.handshakeChan <- handshakeEvent{encLevel: l}
} }
} }
@ -403,9 +419,26 @@ runLoop:
s.keepAlivePingSent = true s.keepAlivePingSent = true
} }
if err := s.sendPacket(); err != nil { sendingAllowed := s.sentPacketHandler.SendingAllowed()
if !sendingAllowed { // if congestion limited, at least try sending an ACK frame
if err := s.maybeSendAckOnlyPacket(); err != nil {
s.closeLocal(err) s.closeLocal(err)
} }
} else {
// repeatedly try sending until we don't have any more data, or run out of the congestion window
for sendingAllowed {
sentPacket, err := s.sendPacket()
if err != nil {
s.closeLocal(err)
break
}
if !sentPacket {
break
}
sendingAllowed = s.sentPacketHandler.SendingAllowed()
}
}
if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 { if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 {
s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
} }
@ -415,17 +448,12 @@ runLoop:
if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout { if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout {
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
} }
if err := s.streamsMap.DeleteClosedStreams(); err != nil {
s.closeLocal(err)
}
} }
// only send the error the handshakeChan when the handshake is not completed yet // only send the error the handshakeChan when the handshake is not completed yet
// otherwise this chan will already be closed // otherwise this chan will already be closed
if !s.handshakeComplete { if !s.handshakeComplete {
s.handshakeCompleteChan <- closeErr.err s.handshakeChan <- closeErr.err
s.handshakeChan <- handshakeEvent{err: closeErr.err}
} }
s.handleCloseError(closeErr) s.handleCloseError(closeErr)
return closeErr.err return closeErr.err
@ -435,6 +463,10 @@ func (s *session) Context() context.Context {
return s.ctx return s.ctx
} }
func (s *session) ConnectionState() ConnectionState {
return s.cryptoSetup.ConnectionState()
}
func (s *session) maybeResetTimer() { func (s *session) maybeResetTimer() {
var deadline time.Time var deadline time.Time
if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent { if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent {
@ -504,7 +536,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames) isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames)
if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, isRetransmittable); err != nil { if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, p.rcvTime, isRetransmittable); err != nil {
return err return err
} }
@ -524,8 +556,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
case *wire.GoawayFrame: case *wire.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames") err = errors.New("unimplemented: handling GOAWAY frames")
case *wire.StopWaitingFrame: case *wire.StopWaitingFrame: // ignore STOP_WAITINGs
s.receivedPacketHandler.IgnoreBelow(frame.LeastUnacked)
case *wire.RstStreamFrame: case *wire.RstStreamFrame:
err = s.handleRstStreamFrame(frame) err = s.handleRstStreamFrame(frame)
case *wire.MaxDataFrame: case *wire.MaxDataFrame:
@ -534,6 +565,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
err = s.handleMaxStreamDataFrame(frame) err = s.handleMaxStreamDataFrame(frame)
case *wire.BlockedFrame: case *wire.BlockedFrame:
case *wire.StreamBlockedFrame: case *wire.StreamBlockedFrame:
case *wire.StopSendingFrame:
err = s.handleStopSendingFrame(frame)
case *wire.PingFrame: case *wire.PingFrame:
default: default:
return errors.New("Session BUG: unexpected frame type") return errors.New("Session BUG: unexpected frame type")
@ -563,9 +596,12 @@ func (s *session) handlePacket(p *receivedPacket) {
func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
if frame.StreamID == s.version.CryptoStreamID() { if frame.StreamID == s.version.CryptoStreamID() {
return s.cryptoStream.AddStreamFrame(frame) if frame.FinBit {
return errors.New("Received STREAM frame with FIN bit for the crypto stream")
} }
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) return s.cryptoStream.handleStreamFrame(frame)
}
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
if err != nil { if err != nil {
return err return err
} }
@ -574,7 +610,7 @@ func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
// ignore this StreamFrame // ignore this StreamFrame
return nil return nil
} }
return str.AddStreamFrame(frame) return str.handleStreamFrame(frame)
} }
func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) {
@ -582,7 +618,11 @@ func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) {
} }
func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error {
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if frame.StreamID == s.version.CryptoStreamID() {
s.cryptoStream.handleMaxStreamDataFrame(frame)
return nil
}
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
if err != nil { if err != nil {
return err return err
} }
@ -590,12 +630,15 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error
// stream is closed and already garbage collected // stream is closed and already garbage collected
return nil return nil
} }
str.UpdateSendWindow(frame.ByteOffset) str.handleMaxStreamDataFrame(frame)
return nil return nil
} }
func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if frame.StreamID == s.version.CryptoStreamID() {
return errors.New("Received RST_STREAM frame for the crypto stream")
}
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
if err != nil { if err != nil {
return err return err
} }
@ -603,11 +646,31 @@ func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
// stream is closed and already garbage collected // stream is closed and already garbage collected
return nil return nil
} }
return str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode), frame.ByteOffset) return str.handleRstStreamFrame(frame)
}
func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error {
if frame.StreamID == s.version.CryptoStreamID() {
return errors.New("Received a STOP_SENDING frame for the crypto stream")
}
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
if err != nil {
return err
}
if str == nil {
// stream is closed and already garbage collected
return nil
}
str.handleStopSendingFrame(frame)
return nil
} }
func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error {
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime) if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil {
return err
}
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
return nil
} }
func (s *session) closeLocal(e error) { func (s *session) closeLocal(e error) {
@ -647,7 +710,7 @@ func (s *session) handleCloseError(closeErr closeError) error {
utils.Errorf("Closing session with error: %s", closeErr.err.Error()) utils.Errorf("Closing session with error: %s", closeErr.err.Error())
} }
s.cryptoStream.Cancel(quicErr) s.cryptoStream.closeForShutdown(quicErr)
s.streamsMap.CloseWithError(quicErr) s.streamsMap.CloseWithError(quicErr)
if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry { if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry {
@ -669,42 +732,27 @@ func (s *session) handleCloseError(closeErr closeError) error {
func (s *session) processTransportParameters(params *handshake.TransportParameters) { func (s *session) processTransportParameters(params *handshake.TransportParameters) {
s.peerParams = params s.peerParams = params
s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams) s.streamsMap.UpdateLimits(params)
if params.OmitConnectionID { if params.OmitConnectionID {
s.packer.SetOmitConnectionID() s.packer.SetOmitConnectionID()
} }
s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow)
s.streamsMap.Range(func(str streamI) { // the crypto stream is the only open stream at this moment
str.UpdateSendWindow(params.StreamFlowControlWindow) // so we don't need to update stream flow control windows
})
}
func (s *session) sendPacket() error {
s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
// Get MAX_DATA and MAX_STREAM_DATA frames
// this call triggers the flow controller to increase the flow control windows, if necessary
windowUpdates := s.getWindowUpdates()
for _, f := range windowUpdates {
s.packer.QueueControlFrame(f)
} }
func (s *session) maybeSendAckOnlyPacket() error {
ack := s.receivedPacketHandler.GetAckFrame() ack := s.receivedPacketHandler.GetAckFrame()
if ack != nil {
s.packer.QueueControlFrame(ack)
}
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
for {
if !s.sentPacketHandler.SendingAllowed() {
if ack == nil { if ack == nil {
return nil return nil
} }
// If we aren't allowed to send, at least try sending an ACK frame s.packer.QueueControlFrame(ack)
swf := s.sentPacketHandler.GetStopWaitingFrame(false)
if swf != nil { if !s.version.UsesIETFFrameFormat() { // for gQUIC, maybe add a STOP_WAITING
if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil {
s.packer.QueueControlFrame(swf) s.packer.QueueControlFrame(swf)
} }
}
packet, err := s.packer.PackAckPacket() packet, err := s.packer.PackAckPacket()
if err != nil { if err != nil {
return err return err
@ -712,6 +760,22 @@ func (s *session) sendPacket() error {
return s.sendPackedPacket(packet) return s.sendPackedPacket(packet)
} }
func (s *session) sendPacket() (bool, error) {
s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
if offset := s.connFlowController.GetWindowUpdate(); offset != 0 {
s.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: offset})
}
if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked {
s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset})
}
s.windowUpdateQueue.QueueAll()
ack := s.receivedPacketHandler.GetAckFrame()
if ack != nil {
s.packer.QueueControlFrame(ack)
}
// check for retransmissions first // check for retransmissions first
for { for {
retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission()
@ -719,21 +783,23 @@ func (s *session) sendPacket() error {
break break
} }
// retransmit handshake packets
if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure {
if s.handshakeComplete {
// Don't retransmit handshake packets when the handshake is complete
continue
}
utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber)
if !s.version.UsesIETFFrameFormat() {
s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true))
}
packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket)
if err != nil { if err != nil {
return err return false, err
} }
if err = s.sendPackedPacket(packet); err != nil { if err := s.sendPackedPacket(packet); err != nil {
return err return false, err
} }
} else { return true, nil
}
// queue all retransmittable frames sent in forward-secure packets
utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber)
// resend the frames that were in the packet // resend the frames that were in the packet
for _, frame := range retransmitPacket.GetFramesForRetransmission() { for _, frame := range retransmitPacket.GetFramesForRetransmission() {
@ -746,34 +812,25 @@ func (s *session) sendPacket() error {
} }
} }
} }
}
hasRetransmission := s.streamFramer.HasFramesForRetransmission() hasRetransmission := s.streamFramer.HasFramesForRetransmission()
if ack != nil || hasRetransmission { if !s.version.UsesIETFFrameFormat() && (ack != nil || hasRetransmission) {
swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) if swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission); swf != nil {
if swf != nil {
s.packer.QueueControlFrame(swf) s.packer.QueueControlFrame(swf)
} }
} }
// add a retransmittable frame // add a retransmittable frame
if s.sentPacketHandler.ShouldSendRetransmittablePacket() { if s.sentPacketHandler.ShouldSendRetransmittablePacket() {
s.packer.QueueControlFrame(&wire.PingFrame{}) s.packer.MakeNextPacketRetransmittable()
} }
packet, err := s.packer.PackPacket() packet, err := s.packer.PackPacket()
if err != nil || packet == nil { if err != nil || packet == nil {
return err return false, err
} }
if err = s.sendPackedPacket(packet); err != nil { if err := s.sendPackedPacket(packet); err != nil {
return err return false, err
}
// send every window update twice
for _, f := range windowUpdates {
s.packer.QueueControlFrame(f)
}
windowUpdates = nil
ack = nil
} }
return true, nil
} }
func (s *session) sendPackedPacket(packet *packedPacket) error { func (s *session) sendPackedPacket(packet *packedPacket) error {
@ -824,7 +881,7 @@ func (s *session) GetOrOpenStream(id protocol.StreamID) (Stream, error) {
return str, err return str, err
} }
// make sure to return an actual nil value here, not an Stream with value nil // make sure to return an actual nil value here, not an Stream with value nil
return nil, err return str, err
} }
// AcceptStream returns the next stream openend by the peer // AcceptStream returns the next stream openend by the peer
@ -841,18 +898,6 @@ func (s *session) OpenStreamSync() (Stream, error) {
return s.streamsMap.OpenStreamSync() return s.streamsMap.OpenStreamSync()
} }
func (s *session) WaitUntilHandshakeComplete() error {
return <-s.handshakeCompleteChan
}
func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {
s.packer.QueueControlFrame(&wire.RstStreamFrame{
StreamID: id,
ByteOffset: offset,
})
s.scheduleSending()
}
func (s *session) newStream(id protocol.StreamID) streamI { func (s *session) newStream(id protocol.StreamID) streamI {
var initialSendWindow protocol.ByteCount var initialSendWindow protocol.ByteCount
if s.peerParams != nil { if s.peerParams != nil {
@ -867,7 +912,21 @@ func (s *session) newStream(id protocol.StreamID) streamI {
initialSendWindow, initialSendWindow,
s.rttStats, s.rttStats,
) )
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version) return newStream(id, s, flowController, s.version)
}
func (s *session) newCryptoStream() cryptoStreamI {
id := s.version.CryptoStreamID()
flowController := flowcontrol.NewStreamFlowController(
id,
s.version.StreamContributesToConnectionFlowControl(id),
s.connFlowController,
protocol.ReceiveStreamFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow),
0,
s.rttStats,
)
return newCryptoStream(s, flowController, s.version)
} }
func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {
@ -908,22 +967,25 @@ func (s *session) tryDecryptingQueuedPackets() {
s.undecryptablePackets = s.undecryptablePackets[:0] s.undecryptablePackets = s.undecryptablePackets[:0]
} }
func (s *session) getWindowUpdates() []wire.Frame { func (s *session) queueControlFrame(f wire.Frame) {
var res []wire.Frame s.packer.QueueControlFrame(f)
s.streamsMap.Range(func(str streamI) { s.scheduleSending()
if offset := str.GetWindowUpdate(); offset != 0 {
res = append(res, &wire.MaxStreamDataFrame{
StreamID: str.StreamID(),
ByteOffset: offset,
})
} }
})
if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { func (s *session) onHasWindowUpdate(id protocol.StreamID) {
res = append(res, &wire.MaxDataFrame{ s.windowUpdateQueue.Add(id)
ByteOffset: offset, s.scheduleSending()
}) }
func (s *session) onHasStreamData(id protocol.StreamID) {
s.streamFramer.AddActiveStream(id)
s.scheduleSending()
}
func (s *session) onStreamCompleted(id protocol.StreamID) {
if err := s.streamsMap.DeleteStream(id); err != nil {
s.Close(err)
} }
return res
} }
func (s *session) LocalAddr() net.Addr { func (s *session) LocalAddr() net.Addr {
@ -935,11 +997,11 @@ func (s *session) RemoteAddr() net.Addr {
return s.conn.RemoteAddr() return s.conn.RemoteAddr()
} }
func (s *session) handshakeStatus() <-chan handshakeEvent { func (s *session) handshakeStatus() <-chan error {
return s.handshakeChan return s.handshakeChan
} }
func (s *session) getCryptoStream() cryptoStream { func (s *session) getCryptoStream() cryptoStreamI {
return s.cryptoStream return s.cryptoStream
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,88 +1,85 @@
package quic package quic
import ( import (
"context"
"fmt"
"io"
"net" "net"
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
const (
errorCodeStopping protocol.ApplicationErrorCode = 0
errorCodeStoppingGQUIC protocol.ApplicationErrorCode = 7
)
// The streamSender is notified by the stream about various events.
type streamSender interface {
queueControlFrame(wire.Frame)
onHasWindowUpdate(protocol.StreamID)
onHasStreamData(protocol.StreamID)
onStreamCompleted(protocol.StreamID)
}
// Each of the both stream halves gets its own uniStreamSender.
// This is necessary in order to keep track when both halves have been completed.
type uniStreamSender struct {
streamSender
onStreamCompletedImpl func()
}
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
s.streamSender.queueControlFrame(f)
}
func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) {
s.streamSender.onHasWindowUpdate(id)
}
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
s.streamSender.onHasStreamData(id)
}
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
s.onStreamCompletedImpl()
}
var _ streamSender = &uniStreamSender{}
type streamI interface { type streamI interface {
Stream Stream
closeForShutdown(error)
AddStreamFrame(*wire.StreamFrame) error // for receiving
RegisterRemoteError(error, protocol.ByteCount) error handleStreamFrame(*wire.StreamFrame) error
HasDataForWriting() bool handleRstStreamFrame(*wire.RstStreamFrame) error
GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool) getWindowUpdate() protocol.ByteCount
GetWriteOffset() protocol.ByteCount // for sending
Finished() bool handleStopSendingFrame(*wire.StopSendingFrame)
Cancel(error) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
// methods needed for flow control handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
GetWindowUpdate() protocol.ByteCount
UpdateSendWindow(protocol.ByteCount)
IsFlowControlBlocked() bool
} }
type cryptoStream interface { var _ receiveStreamI = (streamI)(nil)
streamI var _ sendStreamI = (streamI)(nil)
SetReadOffset(protocol.ByteCount)
}
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
// //
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
type stream struct { type stream struct {
mutex sync.Mutex receiveStream
sendStream
ctx context.Context completedMutex sync.Mutex
ctxCancel context.CancelFunc sender streamSender
receiveStreamCompleted bool
sendStreamCompleted bool
streamID protocol.StreamID
onData func()
// onReset is a callback that should send a RST_STREAM
onReset func(protocol.StreamID, protocol.ByteCount)
readPosInFrame int
writeOffset protocol.ByteCount
readOffset protocol.ByteCount
// Once set, the errors must not be changed!
err error
// cancelled is set when Cancel() is called
cancelled utils.AtomicBool
// finishedReading is set once we read a frame with a FinBit
finishedReading utils.AtomicBool
// finisedWriting is set once Close() is called
finishedWriting utils.AtomicBool
// resetLocally is set if Reset() is called
resetLocally utils.AtomicBool
// resetRemotely is set if RegisterRemoteError() is called
resetRemotely utils.AtomicBool
frameQueue *streamFrameSorter
readChan chan struct{}
readDeadline time.Time
dataForWriting []byte
finSent utils.AtomicBool
rstSent utils.AtomicBool
writeChan chan struct{}
writeDeadline time.Time
flowController flowcontrol.StreamFlowController
version protocol.VersionNumber version protocol.VersionNumber
} }
var _ Stream = &stream{} var _ Stream = &stream{}
var _ streamI = &stream{}
type deadlineError struct{} type deadlineError struct{}
@ -92,290 +89,58 @@ func (deadlineError) Timeout() bool { return true }
var errDeadline net.Error = &deadlineError{} var errDeadline net.Error = &deadlineError{}
type streamCanceledError struct {
error
errorCode protocol.ApplicationErrorCode
}
func (streamCanceledError) Canceled() bool { return true }
func (e streamCanceledError) ErrorCode() protocol.ApplicationErrorCode { return e.errorCode }
var _ StreamError = &streamCanceledError{}
// newStream creates a new Stream // newStream creates a new Stream
func newStream(StreamID protocol.StreamID, func newStream(streamID protocol.StreamID,
onData func(), sender streamSender,
onReset func(protocol.StreamID, protocol.ByteCount),
flowController flowcontrol.StreamFlowController, flowController flowcontrol.StreamFlowController,
version protocol.VersionNumber, version protocol.VersionNumber,
) *stream { ) *stream {
s := &stream{ s := &stream{sender: sender}
onData: onData, senderForSendStream := &uniStreamSender{
onReset: onReset, streamSender: sender,
streamID: StreamID, onStreamCompletedImpl: func() {
flowController: flowController, s.completedMutex.Lock()
frameQueue: newStreamFrameSorter(), s.sendStreamCompleted = true
readChan: make(chan struct{}, 1), s.checkIfCompleted()
writeChan: make(chan struct{}, 1), s.completedMutex.Unlock()
version: version, },
} }
s.ctx, s.ctxCancel = context.WithCancel(context.Background()) s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version)
senderForReceiveStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {
s.completedMutex.Lock()
s.receiveStreamCompleted = true
s.checkIfCompleted()
s.completedMutex.Unlock()
},
}
s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController)
return s return s
} }
// Read implements io.Reader. It is not thread safe! // need to define StreamID() here, since both receiveStream and readStream have a StreamID()
func (s *stream) Read(p []byte) (int, error) { func (s *stream) StreamID() protocol.StreamID {
s.mutex.Lock() // the result is same for receiveStream and sendStream
err := s.err return s.sendStream.StreamID()
s.mutex.Unlock()
if s.cancelled.Get() || s.resetLocally.Get() {
return 0, err
}
if s.finishedReading.Get() {
return 0, io.EOF
} }
bytesRead := 0
for bytesRead < len(p) {
s.mutex.Lock()
frame := s.frameQueue.Head()
if frame == nil && bytesRead > 0 {
err = s.err
s.mutex.Unlock()
return bytesRead, err
}
var err error
for {
// Stop waiting on errors
if s.resetLocally.Get() || s.cancelled.Get() {
err = s.err
break
}
deadline := s.readDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
err = errDeadline
break
}
if frame != nil {
s.readPosInFrame = int(s.readOffset - frame.Offset)
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.readChan
} else {
select {
case <-s.readChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
frame = s.frameQueue.Head()
}
s.mutex.Unlock()
if err != nil {
return bytesRead, err
}
m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame)
if bytesRead > len(p) {
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > int(frame.DataLen()) {
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen())
}
copy(p[bytesRead:], frame.Data[s.readPosInFrame:])
s.readPosInFrame += m
bytesRead += m
s.readOffset += protocol.ByteCount(m)
// when a RST_STREAM was received, the was already informed about the final byteOffset for this stream
if !s.resetRemotely.Get() {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}
s.onData() // so that a possible WINDOW_UPDATE is sent
if s.readPosInFrame >= int(frame.DataLen()) {
fin := frame.FinBit
s.mutex.Lock()
s.frameQueue.Pop()
s.mutex.Unlock()
if fin {
s.finishedReading.Set(true)
return bytesRead, io.EOF
}
}
}
return bytesRead, nil
}
func (s *stream) Write(p []byte) (int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.resetLocally.Get() || s.err != nil {
return 0, s.err
}
if s.finishedWriting.Get() {
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if len(p) == 0 {
return 0, nil
}
s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p)
s.onData()
var err error
for {
deadline := s.writeDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
err = errDeadline
break
}
if s.dataForWriting == nil || s.err != nil {
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.writeChan
} else {
select {
case <-s.writeChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
}
if err != nil {
return 0, err
}
if s.err != nil {
return len(p) - len(s.dataForWriting), s.err
}
return len(p), nil
}
func (s *stream) GetWriteOffset() protocol.ByteCount {
return s.writeOffset
}
// HasDataForWriting says if there's stream available to be dequeued for writing
func (s *stream) HasDataForWriting() bool {
s.mutex.Lock()
hasData := s.err == nil && // nothing should be sent if an error occurred
(len(s.dataForWriting) > 0 || // there is data queued for sending
s.finishedWriting.Get() && !s.finSent.Get()) // if there is no data, but writing finished and the FIN hasn't been sent yet
s.mutex.Unlock()
return hasData
}
func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
data, shouldSendFin := s.getDataForWritingImpl(maxBytes)
if shouldSendFin {
s.finSent.Set(true)
}
return data, shouldSendFin
}
func (s *stream) getDataForWritingImpl(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.err != nil || s.dataForWriting == nil {
return nil, s.finishedWriting.Get() && !s.finSent.Get()
}
// TODO(#657): Flow control for the crypto stream
if s.streamID != s.version.CryptoStreamID() {
maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize())
}
if maxBytes == 0 {
return nil, false
}
var ret []byte
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
ret = s.dataForWriting[:maxBytes]
s.dataForWriting = s.dataForWriting[maxBytes:]
} else {
ret = s.dataForWriting
s.dataForWriting = nil
s.signalWrite()
}
s.writeOffset += protocol.ByteCount(len(ret))
s.flowController.AddBytesSent(protocol.ByteCount(len(ret)))
return ret, s.finishedWriting.Get() && s.dataForWriting == nil && !s.finSent.Get()
}
// Close implements io.Closer
func (s *stream) Close() error { func (s *stream) Close() error {
s.finishedWriting.Set(true) if err := s.sendStream.Close(); err != nil {
s.ctxCancel()
s.onData()
return nil
}
func (s *stream) shouldSendReset() bool {
if s.rstSent.Get() {
return false
}
return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin()
}
// AddStreamFrame adds a new stream frame
func (s *stream) AddStreamFrame(frame *wire.StreamFrame) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil {
return err return err
} }
// in gQUIC, we need to send a RST_STREAM with the final offset if CancelRead() was called
s.mutex.Lock() s.receiveStream.onClose(s.sendStream.getWriteOffset())
defer s.mutex.Unlock()
if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData {
return err
}
s.signalRead()
return nil
}
// signalRead performs a non-blocking send on the readChan
func (s *stream) signalRead() {
select {
case s.readChan <- struct{}{}:
default:
}
}
// signalRead performs a non-blocking send on the writeChan
func (s *stream) signalWrite() {
select {
case s.writeChan <- struct{}{}:
default:
}
}
func (s *stream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.readDeadline
s.readDeadline = t
s.mutex.Unlock()
// if the new deadline is before the currently set deadline, wake up Read()
if t.Before(oldDeadline) {
s.signalRead()
}
return nil
}
func (s *stream) SetWriteDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.writeDeadline
s.writeDeadline = t
s.mutex.Unlock()
if t.Before(oldDeadline) {
s.signalWrite()
}
return nil return nil
} }
@ -385,107 +150,31 @@ func (s *stream) SetDeadline(t time.Time) error {
return nil return nil
} }
// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset // CloseForShutdown closes a stream abruptly.
func (s *stream) CloseRemote(offset protocol.ByteCount) { // It makes Read and Write unblock (and return the error) immediately.
s.AddStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset}) // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *stream) closeForShutdown(err error) {
s.sendStream.closeForShutdown(err)
s.receiveStream.closeForShutdown(err)
} }
// Cancel is called by session to indicate that an error occurred func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
// The stream should will be closed immediately if err := s.receiveStream.handleRstStreamFrame(frame); err != nil {
func (s *stream) Cancel(err error) {
s.mutex.Lock()
s.cancelled.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
s.signalRead()
s.signalWrite()
}
s.mutex.Unlock()
}
// resets the stream locally
func (s *stream) Reset(err error) {
if s.resetLocally.Get() {
return
}
s.mutex.Lock()
s.resetLocally.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
s.signalRead()
s.signalWrite()
}
if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset)
s.rstSent.Set(true)
}
s.mutex.Unlock()
}
// resets the stream remotely
func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error {
if s.resetRemotely.Get() {
return nil
}
s.mutex.Lock()
s.resetRemotely.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
s.signalWrite()
}
if err := s.flowController.UpdateHighestReceived(offset, true); err != nil {
return err return err
} }
if s.shouldSendReset() { if !s.version.UsesIETFFrameFormat() {
s.onReset(s.streamID, s.writeOffset) s.handleStopSendingFrame(&wire.StopSendingFrame{
s.rstSent.Set(true) StreamID: s.StreamID(),
ErrorCode: frame.ErrorCode,
})
} }
s.mutex.Unlock()
return nil return nil
} }
func (s *stream) finishedWriteAndSentFin() bool { // checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed.
return s.finishedWriting.Get() && s.finSent.Get() // It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed.
func (s *stream) checkIfCompleted() {
if s.sendStreamCompleted && s.receiveStreamCompleted {
s.sender.onStreamCompleted(s.StreamID())
} }
func (s *stream) Finished() bool {
return s.cancelled.Get() ||
(s.finishedReading.Get() && s.finishedWriteAndSentFin()) ||
(s.resetRemotely.Get() && s.rstSent.Get()) ||
(s.finishedReading.Get() && s.rstSent.Get()) ||
(s.finishedWriteAndSentFin() && s.resetRemotely.Get())
}
func (s *stream) Context() context.Context {
return s.ctx
}
func (s *stream) StreamID() protocol.StreamID {
return s.streamID
}
func (s *stream) UpdateSendWindow(n protocol.ByteCount) {
s.flowController.UpdateSendWindow(n)
}
func (s *stream) IsFlowControlBlocked() bool {
return s.flowController.IsBlocked()
}
func (s *stream) GetWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate()
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *stream) SetReadOffset(offset protocol.ByteCount) {
s.readOffset = offset
s.frameQueue.readPosition = offset
} }

View File

@ -1,32 +1,34 @@
package quic package quic
import ( import (
"github.com/lucas-clemente/quic-go/internal/flowcontrol" "sync"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
type streamFramer struct { type streamFramer struct {
streamsMap *streamsMap streamGetter streamGetter
cryptoStream streamI cryptoStream cryptoStreamI
version protocol.VersionNumber version protocol.VersionNumber
connFlowController flowcontrol.ConnectionFlowController
retransmissionQueue []*wire.StreamFrame retransmissionQueue []*wire.StreamFrame
blockedFrameQueue []wire.Frame
streamQueueMutex sync.Mutex
activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID
hasCryptoStreamData bool
} }
func newStreamFramer( func newStreamFramer(
cryptoStream streamI, cryptoStream cryptoStreamI,
streamsMap *streamsMap, streamGetter streamGetter,
cfc flowcontrol.ConnectionFlowController,
v protocol.VersionNumber, v protocol.VersionNumber,
) *streamFramer { ) *streamFramer {
return &streamFramer{ return &streamFramer{
streamsMap: streamsMap, streamGetter: streamGetter,
cryptoStream: cryptoStream, cryptoStream: cryptoStream,
connFlowController: cfc, activeStreams: make(map[protocol.StreamID]struct{}),
version: v, version: v,
} }
} }
@ -35,114 +37,101 @@ func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) {
f.retransmissionQueue = append(f.retransmissionQueue, frame) f.retransmissionQueue = append(f.retransmissionQueue, frame)
} }
func (f *streamFramer) AddActiveStream(id protocol.StreamID) {
if id == f.version.CryptoStreamID() { // the crypto stream is handled separately
f.streamQueueMutex.Lock()
f.hasCryptoStreamData = true
f.streamQueueMutex.Unlock()
return
}
f.streamQueueMutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue = append(f.streamQueue, id)
f.activeStreams[id] = struct{}{}
}
f.streamQueueMutex.Unlock()
}
func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame { func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame {
fs, currentLen := f.maybePopFramesForRetransmission(maxLen) fs, currentLen := f.maybePopFramesForRetransmission(maxLen)
return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...) return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...)
} }
func (f *streamFramer) PopBlockedFrame() wire.Frame {
if len(f.blockedFrameQueue) == 0 {
return nil
}
frame := f.blockedFrameQueue[0]
f.blockedFrameQueue = f.blockedFrameQueue[1:]
return frame
}
func (f *streamFramer) HasFramesForRetransmission() bool { func (f *streamFramer) HasFramesForRetransmission() bool {
return len(f.retransmissionQueue) > 0 return len(f.retransmissionQueue) > 0
} }
func (f *streamFramer) HasCryptoStreamFrame() bool { func (f *streamFramer) HasCryptoStreamData() bool {
return f.cryptoStream.HasDataForWriting() f.streamQueueMutex.Lock()
hasCryptoStreamData := f.hasCryptoStreamData
f.streamQueueMutex.Unlock()
return hasCryptoStreamData
} }
// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames.
func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame {
if !f.HasCryptoStreamFrame() { f.streamQueueMutex.Lock()
return nil frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen)
} f.hasCryptoStreamData = hasMoreData
frame := &wire.StreamFrame{ f.streamQueueMutex.Unlock()
StreamID: f.cryptoStream.StreamID(),
Offset: f.cryptoStream.GetWriteOffset(),
}
frameHeaderBytes, _ := frame.MinLength(f.version) // can never error
frame.Data, frame.FinBit = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes)
return frame return frame
} }
func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) {
for len(f.retransmissionQueue) > 0 { for len(f.retransmissionQueue) > 0 {
frame := f.retransmissionQueue[0] frame := f.retransmissionQueue[0]
frame.DataLenPresent = true frame.DataLenPresent = true
frameHeaderLen, _ := frame.MinLength(f.version) // can never error frameHeaderLen := frame.MinLength(f.version) // can never error
if currentLen+frameHeaderLen >= maxLen { maxLen := maxTotalLen - currentLen
if frameHeaderLen+frame.DataLen() > maxLen && maxLen < protocol.MinStreamFrameSize {
break break
} }
currentLen += frameHeaderLen splitFrame := maybeSplitOffFrame(frame, maxLen-frameHeaderLen)
splitFrame := maybeSplitOffFrame(frame, maxLen-currentLen)
if splitFrame != nil { // StreamFrame was split if splitFrame != nil { // StreamFrame was split
res = append(res, splitFrame) res = append(res, splitFrame)
currentLen += splitFrame.DataLen() currentLen += frameHeaderLen + splitFrame.DataLen()
break break
} }
f.retransmissionQueue = f.retransmissionQueue[1:] f.retransmissionQueue = f.retransmissionQueue[1:]
res = append(res, frame) res = append(res, frame)
currentLen += frame.DataLen() currentLen += frameHeaderLen + frame.DataLen()
} }
return return
} }
func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*wire.StreamFrame) { func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame {
frame := &wire.StreamFrame{DataLenPresent: true}
var currentLen protocol.ByteCount var currentLen protocol.ByteCount
var frames []*wire.StreamFrame
fn := func(s streamI) (bool, error) { f.streamQueueMutex.Lock()
if s == nil { // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
return true, nil numActiveStreams := len(f.streamQueue)
for i := 0; i < numActiveStreams; i++ {
if maxTotalLen-currentLen < protocol.MinStreamFrameSize {
break
} }
id := f.streamQueue[0]
frame.StreamID = s.StreamID() f.streamQueue = f.streamQueue[1:]
frame.Offset = s.GetWriteOffset() str, err := f.streamGetter.GetOrOpenSendStream(id)
// not perfect, but thread-safe since writeOffset is only written when getting data if err != nil { // can happen if the stream completed after it said it had data
frameHeaderBytes, _ := frame.MinLength(f.version) // can never error delete(f.activeStreams, id)
if currentLen+frameHeaderBytes > maxBytes { continue
return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here
} }
maxLen := maxBytes - currentLen - frameHeaderBytes frame, hasMoreData := str.popStreamFrame(maxTotalLen - currentLen)
if hasMoreData { // put the stream back in the queue (at the end)
if s.HasDataForWriting() { f.streamQueue = append(f.streamQueue, id)
frame.Data, frame.FinBit = s.GetDataForWriting(maxLen) } else { // no more data to send. Stream is not active any more
delete(f.activeStreams, id)
} }
if len(frame.Data) == 0 && !frame.FinBit { if frame == nil { // can happen if the receiveStream was canceled after it said it had data
return true, nil continue
} }
frames = append(frames, frame)
// Finally, check if we are now FC blocked and should queue a BLOCKED frame currentLen += frame.MinLength(f.version) + frame.DataLen()
if !frame.FinBit && s.IsFlowControlBlocked() {
f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.StreamBlockedFrame{StreamID: s.StreamID()})
} }
if f.connFlowController.IsBlocked() { f.streamQueueMutex.Unlock()
f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{}) return frames
}
res = append(res, frame)
currentLen += frameHeaderBytes + frame.DataLen()
if currentLen == maxBytes {
return false, nil
}
frame = &wire.StreamFrame{DataLenPresent: true}
return true, nil
}
f.streamsMap.RoundRobinIterate(fn)
return
} }
// maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified. // maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified.

View File

@ -2,9 +2,9 @@ package quic
import ( import (
"bytes" "bytes"
"errors"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -21,12 +21,13 @@ var _ = Describe("Stream Framer", func() {
var ( var (
retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame
framer *streamFramer framer *streamFramer
streamsMap *streamsMap cryptoStream *MockCryptoStream
stream1, stream2 *mocks.MockStreamI stream1, stream2 *MockSendStreamI
connFC *mocks.MockConnectionFlowController streamGetter *MockStreamGetter
) )
BeforeEach(func() { BeforeEach(func() {
streamGetter = NewMockStreamGetter(mockCtrl)
retransmittedFrame1 = &wire.StreamFrame{ retransmittedFrame1 = &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Data: []byte{0x13, 0x37}, Data: []byte{0x13, 0x37},
@ -36,25 +37,14 @@ var _ = Describe("Stream Framer", func() {
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
} }
stream1 = mocks.NewMockStreamI(mockCtrl) stream1 = NewMockSendStreamI(mockCtrl)
stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes() stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes()
stream2 = mocks.NewMockStreamI(mockCtrl) stream2 = NewMockSendStreamI(mockCtrl)
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
cryptoStream = NewMockCryptoStream(mockCtrl)
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) framer = newStreamFramer(cryptoStream, streamGetter, versionGQUICFrames)
streamsMap.putStream(stream1)
streamsMap.putStream(stream2)
connFC = mocks.NewMockConnectionFlowController(mockCtrl)
framer = newStreamFramer(nil, streamsMap, connFC, versionGQUICFrames)
}) })
setNoData := func(str *mocks.MockStreamI) {
str.EXPECT().HasDataForWriting().Return(false).AnyTimes()
str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, false).AnyTimes()
str.EXPECT().GetWriteOffset().AnyTimes()
}
It("says if it has retransmissions", func() { It("says if it has retransmissions", func() {
Expect(framer.HasFramesForRetransmission()).To(BeFalse()) Expect(framer.HasFramesForRetransmission()).To(BeFalse())
framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame1)
@ -62,119 +52,220 @@ var _ = Describe("Stream Framer", func() {
}) })
It("sets the DataLenPresent for dequeued retransmitted frames", func() { It("sets the DataLenPresent for dequeued retransmitted frames", func() {
setNoData(stream1)
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(protocol.MaxByteCount) fs := framer.PopStreamFrames(protocol.MaxByteCount)
Expect(fs).To(HaveLen(1)) Expect(fs).To(HaveLen(1))
Expect(fs[0].DataLenPresent).To(BeTrue()) Expect(fs[0].DataLenPresent).To(BeTrue())
}) })
It("sets the DataLenPresent for dequeued normal frames", func() { Context("handling the crypto stream", func() {
connFC.EXPECT().IsBlocked() It("says if it has crypto stream data", func() {
setNoData(stream2) Expect(framer.HasCryptoStreamData()).To(BeFalse())
stream1.EXPECT().GetWriteOffset() framer.AddActiveStream(framer.version.CryptoStreamID())
stream1.EXPECT().HasDataForWriting().Return(true) Expect(framer.HasCryptoStreamData()).To(BeTrue())
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) })
stream1.EXPECT().IsFlowControlBlocked()
fs := framer.PopStreamFrames(protocol.MaxByteCount) It("says that it doesn't have crypto stream data after popping all data", func() {
Expect(fs).To(HaveLen(1)) streamID := framer.version.CryptoStreamID()
Expect(fs[0].DataLenPresent).To(BeTrue()) f := &wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
}
cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, false)
framer.AddActiveStream(streamID)
Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f))
Expect(framer.HasCryptoStreamData()).To(BeFalse())
})
It("says that it has more crypto stream data if not all data was popped", func() {
streamID := framer.version.CryptoStreamID()
f := &wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
}
cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, true)
framer.AddActiveStream(streamID)
Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f))
Expect(framer.HasCryptoStreamData()).To(BeTrue())
})
}) })
Context("Popping", func() { Context("Popping", func() {
BeforeEach(func() {
// nothing is blocked here
connFC.EXPECT().IsBlocked().AnyTimes()
stream1.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes()
stream2.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes()
})
It("returns nil when popping an empty framer", func() { It("returns nil when popping an empty framer", func() {
setNoData(stream1)
setNoData(stream2)
Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) Expect(framer.PopStreamFrames(1000)).To(BeEmpty())
}) })
It("pops frames for retransmission", func() { It("pops frames for retransmission", func() {
setNoData(stream1)
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame1)
framer.AddFrameForRetransmission(retransmittedFrame2) framer.AddFrameForRetransmission(retransmittedFrame2)
fs := framer.PopStreamFrames(1000) fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(2)) Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, retransmittedFrame2}))
Expect(fs[0]).To(Equal(retransmittedFrame1)) // make sure the frames are actually removed, and not returned a second time
Expect(fs[1]).To(Equal(retransmittedFrame2))
Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) Expect(framer.PopStreamFrames(1000)).To(BeEmpty())
}) })
It("returns normal frames", func() { It("doesn't pop frames for retransmission, if the size would be smaller than the minimum STREAM frame size", func() {
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) framer.AddFrameForRetransmission(&wire.StreamFrame{
stream1.EXPECT().HasDataForWriting().Return(true) StreamID: id1,
stream1.EXPECT().GetWriteOffset() Data: bytes.Repeat([]byte{'a'}, int(protocol.MinStreamFrameSize)),
setNoData(stream2)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(1))
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
Expect(fs[0].Data).To(Equal([]byte("foobar")))
Expect(fs[0].FinBit).To(BeFalse())
}) })
fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
It("returns multiple normal frames", func() {
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz"), false)
stream2.EXPECT().HasDataForWriting().Return(true)
stream2.EXPECT().GetWriteOffset()
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(2))
// Swap if we dequeued in other order
if fs[0].StreamID != stream1.StreamID() {
fs[0], fs[1] = fs[1], fs[0]
}
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
Expect(fs[0].Data).To(Equal([]byte("foobar")))
Expect(fs[1].StreamID).To(Equal(stream2.StreamID()))
Expect(fs[1].Data).To(Equal([]byte("foobaz")))
})
It("returns retransmission frames before normal frames", func() {
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(2))
Expect(fs[0]).To(Equal(retransmittedFrame1))
Expect(fs[1].StreamID).To(Equal(stream1.StreamID()))
})
It("does not pop empty frames", func() {
stream1.EXPECT().HasDataForWriting().Return(false)
stream1.EXPECT().GetWriteOffset()
setNoData(stream2)
fs := framer.PopStreamFrames(5)
Expect(fs).To(BeEmpty()) Expect(fs).To(BeEmpty())
}) })
It("uses the round-robin scheduling", func() { It("pops frames for retransmission, even if the remaining space in the packet is too small, if the frame doesn't need to be split", func() {
streamFrameHeaderLen := protocol.ByteCount(4) framer.AddFrameForRetransmission(retransmittedFrame1)
stream1.EXPECT().GetDataForWriting(10-streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen)), false) fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
stream1.EXPECT().HasDataForWriting().Return(true) Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1}))
stream1.EXPECT().GetWriteOffset() })
stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10-streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen)), false)
stream2.EXPECT().HasDataForWriting().Return(true) It("pops frames for retransmission, if the remaining size is the miniumum STREAM frame size", func() {
stream2.EXPECT().GetWriteOffset() framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(10) fs := framer.PopStreamFrames(protocol.MinStreamFrameSize)
Expect(fs).To(HaveLen(1)) Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1}))
// it doesn't matter here if this data is from stream1 or from stream2... })
firstStreamID := fs[0].StreamID
fs = framer.PopStreamFrames(10) It("returns normal frames", func() {
Expect(fs).To(HaveLen(1)) streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
// ... but the data popped this time has to be from the other stream f := &wire.StreamFrame{
Expect(fs[0].StreamID).ToNot(Equal(firstStreamID)) StreamID: id1,
Data: []byte("foobar"),
Offset: 42,
}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
framer.AddActiveStream(id1)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(Equal([]*wire.StreamFrame{f}))
})
It("skips a stream that was reported active, but was completed shortly after", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(nil, errors.New("stream was already deleted"))
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
}
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f}))
})
It("skips a stream that was reported active, but doesn't have any data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(nil, false)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f}))
})
It("pops from a stream multiple times, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true)
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
framer.AddActiveStream(id1) // only add it once
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f1}))
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2}))
// no further calls to popStreamFrame, after popStreamFrame said there's no more data
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(BeNil())
})
It("re-queues a stream at the end, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f11, true)
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f12, false)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
framer.AddActiveStream(id1) // only add it once
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f11})) // first a frame from stream 1
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) // then a frame from stream 2
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f12})) // then another frame from stream 1
})
It("only dequeues data from each stream once per packet", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
// both streams have more data, and will be re-queued
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, true)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f1, f2}))
})
It("returns multiple normal frames in the order they were reported active", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{Data: []byte("foobar")}
f2 := &wire.StreamFrame{Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
framer.AddActiveStream(id2)
framer.AddActiveStream(id1)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f2, f1}))
})
It("only asks a stream for data once, even if it was reported active multiple times", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function
framer.AddActiveStream(id1)
framer.AddActiveStream(id1)
Expect(framer.PopStreamFrames(1000)).To(HaveLen(1))
})
It("returns retransmission frames before normal frames", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
framer.AddActiveStream(id1)
f1 := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, f1}))
})
It("does not pop empty frames", func() {
fs := framer.PopStreamFrames(500)
Expect(fs).To(BeEmpty())
})
It("pops frames that have the minimum size", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}, false)
framer.AddActiveStream(id1)
framer.PopStreamFrames(protocol.MinStreamFrameSize)
})
It("does not pop frames smaller than the mimimum size", func() {
// don't expect a call to PopStreamFrame()
framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
})
It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
// pop a frame such that the remaining size is one byte less than the minimum STREAM frame size
f := &wire.StreamFrame{
StreamID: id1,
Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)),
}
stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f, false)
framer.AddActiveStream(id1)
fs := framer.PopStreamFrames(500)
Expect(fs).To(Equal([]*wire.StreamFrame{f}))
}) })
Context("splitting of frames", func() { Context("splitting of frames", func() {
@ -212,139 +303,28 @@ var _ = Describe("Stream Framer", func() {
}) })
It("splits a frame", func() { It("splits a frame", func() {
setNoData(stream1) frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, 600)}
setNoData(stream2) framer.AddFrameForRetransmission(frame)
framer.AddFrameForRetransmission(retransmittedFrame2) fs := framer.PopStreamFrames(500)
origlen := retransmittedFrame2.DataLen()
fs := framer.PopStreamFrames(6)
Expect(fs).To(HaveLen(1)) Expect(fs).To(HaveLen(1))
minLength, _ := fs[0].MinLength(framer.version) minLength := fs[0].MinLength(framer.version)
Expect(minLength + fs[0].DataLen()).To(Equal(protocol.ByteCount(6))) Expect(minLength + fs[0].DataLen()).To(Equal(protocol.ByteCount(500)))
Expect(framer.retransmissionQueue[0].Data).To(HaveLen(int(origlen - fs[0].DataLen()))) Expect(framer.retransmissionQueue[0].Data).To(HaveLen(int(600 - fs[0].DataLen())))
Expect(framer.retransmissionQueue[0].Offset).To(Equal(fs[0].DataLen())) Expect(framer.retransmissionQueue[0].Offset).To(Equal(fs[0].DataLen()))
}) })
It("never returns an empty stream frame", func() {
// this one frame will be split off from again and again in this test. Therefore, it has to be large enough (checked again at the end)
origFrame := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
FinBit: false,
Data: bytes.Repeat([]byte{'f'}, 30*30),
}
framer.AddFrameForRetransmission(origFrame)
minFrameDataLen := protocol.MaxPacketSize
for i := 0; i < 30; i++ {
frames, currentLen := framer.maybePopFramesForRetransmission(protocol.ByteCount(i))
if len(frames) == 0 {
Expect(currentLen).To(BeZero())
} else {
Expect(frames).To(HaveLen(1))
Expect(currentLen).ToNot(BeZero())
dataLen := frames[0].DataLen()
Expect(dataLen).ToNot(BeZero())
if dataLen < minFrameDataLen {
minFrameDataLen = dataLen
}
}
}
Expect(framer.retransmissionQueue).To(HaveLen(1)) // check that origFrame was large enough for this test and didn't get used up completely
Expect(minFrameDataLen).To(Equal(protocol.ByteCount(1)))
})
It("only removes a frame from the framer after returning all split parts", func() { It("only removes a frame from the framer after returning all split parts", func() {
setNoData(stream1) frameHeaderLen := protocol.ByteCount(4)
setNoData(stream2) frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, int(501-frameHeaderLen))}
framer.AddFrameForRetransmission(retransmittedFrame2) framer.AddFrameForRetransmission(frame)
fs := framer.PopStreamFrames(6) fs := framer.PopStreamFrames(500)
Expect(fs).To(HaveLen(1)) Expect(fs).To(HaveLen(1))
Expect(framer.retransmissionQueue).ToNot(BeEmpty()) Expect(framer.retransmissionQueue).ToNot(BeEmpty())
fs = framer.PopStreamFrames(1000) fs = framer.PopStreamFrames(500)
Expect(fs).To(HaveLen(1)) Expect(fs).To(HaveLen(1))
Expect(fs[0].DataLen()).To(BeEquivalentTo(1))
Expect(framer.retransmissionQueue).To(BeEmpty()) Expect(framer.retransmissionQueue).To(BeEmpty())
}) })
}) })
Context("sending FINs", func() {
It("sends FINs when streams are closed", func() {
offset := protocol.ByteCount(42)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, true)
stream1.EXPECT().GetWriteOffset().Return(offset)
setNoData(stream2)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(1))
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
Expect(fs[0].Offset).To(Equal(offset))
Expect(fs[0].FinBit).To(BeTrue())
Expect(fs[0].Data).To(BeEmpty())
})
It("bundles FINs with data", func() {
offset := protocol.ByteCount(42)
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), true)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset().Return(offset)
setNoData(stream2)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(1))
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
Expect(fs[0].Data).To(Equal([]byte("foobar")))
Expect(fs[0].FinBit).To(BeTrue())
})
})
})
Context("BLOCKED frames", func() {
It("Pop returns nil if no frame is queued", func() {
Expect(framer.PopBlockedFrame()).To(BeNil())
})
It("queues and pops BLOCKED frames for individually blocked streams", func() {
connFC.EXPECT().IsBlocked()
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
stream1.EXPECT().IsFlowControlBlocked().Return(true)
setNoData(stream2)
frames := framer.PopStreamFrames(1000)
Expect(frames).To(HaveLen(1))
f := framer.PopBlockedFrame()
Expect(f).To(BeAssignableToTypeOf(&wire.StreamBlockedFrame{}))
bf := f.(*wire.StreamBlockedFrame)
Expect(bf.StreamID).To(Equal(stream1.StreamID()))
Expect(framer.PopBlockedFrame()).To(BeNil())
})
It("does not queue a stream-level BLOCKED frame after sending the FinBit frame", func() {
connFC.EXPECT().IsBlocked()
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), true)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
setNoData(stream2)
frames := framer.PopStreamFrames(1000)
Expect(frames).To(HaveLen(1))
Expect(frames[0].FinBit).To(BeTrue())
Expect(frames[0].DataLen()).To(Equal(protocol.ByteCount(3)))
blockedFrame := framer.PopBlockedFrame()
Expect(blockedFrame).To(BeNil())
})
It("queues and pops BLOCKED frames for connection blocked streams", func() {
connFC.EXPECT().IsBlocked().Return(true)
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
stream1.EXPECT().IsFlowControlBlocked().Return(false)
setNoData(stream2)
framer.PopStreamFrames(1000)
f := framer.PopBlockedFrame()
Expect(f).To(BeAssignableToTypeOf(&wire.BlockedFrame{}))
Expect(framer.PopBlockedFrame()).To(BeNil())
})
}) })
}) })

File diff suppressed because it is too large Load Diff

View File

@ -5,8 +5,9 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
) )
@ -16,11 +17,8 @@ type streamsMap struct {
perspective protocol.Perspective perspective protocol.Perspective
streams map[protocol.StreamID]streamI streams map[protocol.StreamID]streamI
// needed for round-robin scheduling
openStreams []protocol.StreamID
roundRobinIndex int
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond nextStreamOrErrCond sync.Cond
openStreamOrErrCond sync.Cond openStreamOrErrCond sync.Cond
@ -29,47 +27,32 @@ type streamsMap struct {
nextStreamToAccept protocol.StreamID nextStreamToAccept protocol.StreamID
newStream newStreamLambda newStream newStreamLambda
numOutgoingStreams uint32
numIncomingStreams uint32
maxIncomingStreams uint32
maxOutgoingStreams uint32
} }
type streamLambda func(streamI) (bool, error) var _ streamManager = &streamsMap{}
type newStreamLambda func(protocol.StreamID) streamI type newStreamLambda func(protocol.StreamID) streamI
var errMapAccess = errors.New("streamsMap: Error accessing the streams map") var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap { func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) streamManager {
// add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32(
maxStreams+protocol.MaxStreamsMinimumIncrement,
uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
)
sm := streamsMap{ sm := streamsMap{
perspective: pers, perspective: pers,
streams: make(map[protocol.StreamID]streamI), streams: make(map[protocol.StreamID]streamI),
openStreams: make([]protocol.StreamID, 0),
newStream: newStream, newStream: newStream,
maxIncomingStreams: maxIncomingStreams,
} }
sm.nextStreamOrErrCond.L = &sm.mutex sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex
nextOddStream := protocol.StreamID(1) nextClientInitiatedStream := protocol.StreamID(1)
if ver.CryptoStreamID() == protocol.StreamID(1) { nextServerInitiatedStream := protocol.StreamID(2)
nextOddStream = 3 if pers == protocol.PerspectiveServer {
} sm.nextStreamToOpen = nextServerInitiatedStream
if pers == protocol.PerspectiveClient { sm.nextStreamToAccept = nextClientInitiatedStream
sm.nextStream = nextOddStream
sm.nextStreamToAccept = 2
} else { } else {
sm.nextStream = 2 sm.nextStreamToOpen = nextClientInitiatedStream
sm.nextStreamToAccept = nextOddStream sm.nextStreamToAccept = nextServerInitiatedStream
} }
return &sm return &sm
} }
@ -81,6 +64,23 @@ func (m *streamsMap) streamInitiatedBy(id protocol.StreamID) protocol.Perspectiv
return protocol.PerspectiveClient return protocol.PerspectiveClient
} }
func (m *streamsMap) nextStreamID(id protocol.StreamID) protocol.StreamID {
if m.perspective == protocol.PerspectiveServer && id == 0 {
return 1
}
return id + 2
}
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
// every bidirectional stream is also a receive stream
return m.GetOrOpenStream(id)
}
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
// every bidirectional stream is also a send stream
return m.GetOrOpenStream(id)
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
@ -88,7 +88,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
s, ok := m.streams[id] s, ok := m.streams[id]
m.mutex.RUnlock() m.mutex.RUnlock()
if ok { if ok {
return s, nil // s may be nil return s, nil
} }
// ... we don't have an existing stream // ... we don't have an existing stream
@ -101,7 +101,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
} }
if m.perspective == m.streamInitiatedBy(id) { if m.perspective == m.streamInitiatedBy(id) {
if id <= m.nextStream { // this is a stream opened by us. Must have been closed already if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already
return nil, nil return nil, nil
} }
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
@ -110,14 +110,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
return nil, nil return nil, nil
} }
// sid is the next stream that will be opened for sid := m.nextStreamID(m.highestStreamOpenedByPeer); sid <= id; sid = m.nextStreamID(sid) {
sid := m.highestStreamOpenedByPeer + 2
// if there is no stream opened yet, and this is the server, stream 1 should be openend
if sid == 2 && m.perspective == protocol.PerspectiveServer {
sid = 1
}
for ; sid <= id; sid += 2 {
if _, err := m.openRemoteStream(sid); err != nil { if _, err := m.openRemoteStream(sid); err != nil {
return nil, err return nil, err
} }
@ -128,38 +121,26 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
} }
func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) { func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) {
if m.numIncomingStreams >= m.maxIncomingStreams {
return nil, qerr.TooManyOpenStreams
}
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
} }
m.numIncomingStreams++
if id > m.highestStreamOpenedByPeer { if id > m.highestStreamOpenedByPeer {
m.highestStreamOpenedByPeer = id m.highestStreamOpenedByPeer = id
} }
s := m.newStream(id) s := m.newStream(id)
m.putStream(s) m.putStream(s)
return s, nil return s, nil
} }
func (m *streamsMap) openStreamImpl() (streamI, error) { func (m *streamsMap) openStreamImpl() (streamI, error) {
id := m.nextStream s := m.newStream(m.nextStreamToOpen)
if m.numOutgoingStreams >= m.maxOutgoingStreams {
return nil, qerr.TooManyOpenStreams
}
m.numOutgoingStreams++
m.nextStream += 2
s := m.newStream(id)
m.putStream(s) m.putStream(s)
m.nextStreamToOpen = m.nextStreamID(m.nextStreamToOpen)
return s, nil return s, nil
} }
// OpenStream opens the next available stream // OpenStream opens the next available stream
func (m *streamsMap) OpenStream() (streamI, error) { func (m *streamsMap) OpenStream() (Stream, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -169,7 +150,7 @@ func (m *streamsMap) OpenStream() (streamI, error) {
return m.openStreamImpl() return m.openStreamImpl()
} }
func (m *streamsMap) OpenStreamSync() (streamI, error) { func (m *streamsMap) OpenStreamSync() (Stream, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -190,7 +171,7 @@ func (m *streamsMap) OpenStreamSync() (streamI, error) {
// AcceptStream returns the next stream opened by the peer // AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened // it blocks until a new stream is opened
func (m *streamsMap) AcceptStream() (streamI, error) { func (m *streamsMap) AcceptStream() (Stream, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var str streamI var str streamI
@ -209,104 +190,24 @@ func (m *streamsMap) AcceptStream() (streamI, error) {
return str, nil return str, nil
} }
func (m *streamsMap) DeleteClosedStreams() error { func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
_, ok := m.streams[id]
var numDeletedStreams int
// for every closed stream, the streamID is replaced by 0 in the openStreams slice
for i, streamID := range m.openStreams {
str, ok := m.streams[streamID]
if !ok { if !ok {
return errMapAccess return errMapAccess
} }
if !str.Finished() { delete(m.streams, id)
continue
}
numDeletedStreams++
m.openStreams[i] = 0
if m.streamInitiatedBy(streamID) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
delete(m.streams, streamID)
}
if numDeletedStreams == 0 {
return nil
}
// remove all 0s (representing closed streams) from the openStreams slice
// and adjust the roundRobinIndex
var j int
for i, id := range m.openStreams {
if i != j {
m.openStreams[j] = m.openStreams[i]
}
if id != 0 {
j++
} else if j < m.roundRobinIndex {
m.roundRobinIndex--
}
}
m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams]
m.openStreamOrErrCond.Signal() m.openStreamOrErrCond.Signal()
return nil return nil
} }
// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false
// It uses a round-robin-like scheduling to ensure that every stream is considered fairly
// It prioritizes the the header-stream (StreamID 3)
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
m.mutex.Lock()
defer m.mutex.Unlock()
numStreams := len(m.streams)
startIndex := m.roundRobinIndex
for i := 0; i < numStreams; i++ {
streamID := m.openStreams[(i+startIndex)%numStreams]
cont, err := m.iterateFunc(streamID, fn)
if err != nil {
return err
}
m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams
if !cont {
break
}
}
return nil
}
// Range executes a callback for all streams, in pseudo-random order
func (m *streamsMap) Range(cb func(s streamI)) {
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, s := range m.streams {
if s != nil {
cb(s)
}
}
}
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
str, ok := m.streams[streamID]
if !ok {
return true, errMapAccess
}
return fn(str)
}
func (m *streamsMap) putStream(s streamI) error { func (m *streamsMap) putStream(s streamI) error {
id := s.StreamID() id := s.StreamID()
if _, ok := m.streams[id]; ok { if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id) return fmt.Errorf("a stream with ID %d already exists", id)
} }
m.streams[id] = s m.streams[id] = s
m.openStreams = append(m.openStreams, id)
return nil return nil
} }
@ -316,14 +217,20 @@ func (m *streamsMap) CloseWithError(err error) {
m.closeErr = err m.closeErr = err
m.nextStreamOrErrCond.Broadcast() m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast() m.openStreamOrErrCond.Broadcast()
for _, s := range m.openStreams { for _, s := range m.streams {
m.streams[s].Cancel(err) s.closeForShutdown(err)
} }
} }
func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) { // TODO(#952): this won't be needed when gQUIC supports stateless handshakes
func (m *streamsMap) UpdateLimits(params *handshake.TransportParameters) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() for id, str := range m.streams {
m.maxOutgoingStreams = limit str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: id,
ByteOffset: params.StreamFlowControlWindow,
})
}
m.mutex.Unlock()
m.openStreamOrErrCond.Broadcast() m.openStreamOrErrCond.Broadcast()
} }

View File

@ -0,0 +1,257 @@
package quic
import (
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type streamsMapLegacy struct {
mutex sync.RWMutex
perspective protocol.Perspective
streams map[protocol.StreamID]streamI
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond
openStreamOrErrCond sync.Cond
closeErr error
nextStreamToAccept protocol.StreamID
newStream newStreamLambda
numOutgoingStreams uint32
numIncomingStreams uint32
maxIncomingStreams uint32
maxOutgoingStreams uint32
}
var _ streamManager = &streamsMapLegacy{}
func newStreamsMapLegacy(newStream newStreamLambda, pers protocol.Perspective) streamManager {
// add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32(
maxStreams+protocol.MaxStreamsMinimumIncrement,
uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
)
sm := streamsMapLegacy{
perspective: pers,
streams: make(map[protocol.StreamID]streamI),
newStream: newStream,
maxIncomingStreams: maxIncomingStreams,
}
sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamOrErrCond.L = &sm.mutex
nextServerInitiatedStream := protocol.StreamID(2)
nextClientInitiatedStream := protocol.StreamID(3)
if pers == protocol.PerspectiveServer {
sm.highestStreamOpenedByPeer = 1
}
if pers == protocol.PerspectiveServer {
sm.nextStreamToOpen = nextServerInitiatedStream
sm.nextStreamToAccept = nextClientInitiatedStream
} else {
sm.nextStreamToOpen = nextClientInitiatedStream
sm.nextStreamToAccept = nextServerInitiatedStream
}
return &sm
}
// getStreamPerspective says which side should initiate a stream
func (m *streamsMapLegacy) streamInitiatedBy(id protocol.StreamID) protocol.Perspective {
if id%2 == 0 {
return protocol.PerspectiveServer
}
return protocol.PerspectiveClient
}
func (m *streamsMapLegacy) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
// every bidirectional stream is also a receive stream
return m.GetOrOpenStream(id)
}
func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
// every bidirectional stream is also a send stream
return m.GetOrOpenStream(id)
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
func (m *streamsMapLegacy) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
m.mutex.RLock()
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil
}
// ... we don't have an existing stream
m.mutex.Lock()
defer m.mutex.Unlock()
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
s, ok = m.streams[id]
if ok {
return s, nil
}
if m.perspective == m.streamInitiatedBy(id) {
if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already
return nil, nil
}
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
}
if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already
return nil, nil
}
for sid := m.highestStreamOpenedByPeer + 2; sid <= id; sid += 2 {
if _, err := m.openRemoteStream(sid); err != nil {
return nil, err
}
}
m.nextStreamOrErrCond.Broadcast()
return m.streams[id], nil
}
func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, error) {
if m.numIncomingStreams >= m.maxIncomingStreams {
return nil, qerr.TooManyOpenStreams
}
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
}
m.numIncomingStreams++
if id > m.highestStreamOpenedByPeer {
m.highestStreamOpenedByPeer = id
}
s := m.newStream(id)
m.putStream(s)
return s, nil
}
func (m *streamsMapLegacy) openStreamImpl() (streamI, error) {
if m.numOutgoingStreams >= m.maxOutgoingStreams {
return nil, qerr.TooManyOpenStreams
}
m.numOutgoingStreams++
s := m.newStream(m.nextStreamToOpen)
m.putStream(s)
m.nextStreamToOpen += 2
return s, nil
}
// OpenStream opens the next available stream
func (m *streamsMapLegacy) OpenStream() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.closeErr != nil {
return nil, m.closeErr
}
return m.openStreamImpl()
}
func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
for {
if m.closeErr != nil {
return nil, m.closeErr
}
str, err := m.openStreamImpl()
if err == nil {
return str, err
}
if err != nil && err != qerr.TooManyOpenStreams {
return nil, err
}
m.openStreamOrErrCond.Wait()
}
}
// AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened
func (m *streamsMapLegacy) AcceptStream() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var str streamI
for {
var ok bool
if m.closeErr != nil {
return nil, m.closeErr
}
str, ok = m.streams[m.nextStreamToAccept]
if ok {
break
}
m.nextStreamOrErrCond.Wait()
}
m.nextStreamToAccept += 2
return str, nil
}
func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
_, ok := m.streams[id]
if !ok {
return errMapAccess
}
delete(m.streams, id)
if m.streamInitiatedBy(id) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
m.openStreamOrErrCond.Signal()
return nil
}
func (m *streamsMapLegacy) putStream(s streamI) error {
id := s.StreamID()
if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id)
}
m.streams[id] = s
return nil
}
func (m *streamsMapLegacy) CloseWithError(err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.closeErr = err
m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast()
for _, s := range m.streams {
s.closeForShutdown(err)
}
}
// TODO(#952): this won't be needed when gQUIC supports stateless handshakes
func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) {
m.mutex.Lock()
m.maxOutgoingStreams = params.MaxStreams
for id, str := range m.streams {
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: id,
ByteOffset: params.StreamFlowControlWindow,
})
}
m.mutex.Unlock()
m.openStreamOrErrCond.Broadcast()
}

Some files were not shown because too many files have changed in this diff Show More