diff --git a/Gopkg.lock b/Gopkg.lock index 317bcb1..24f8499 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -421,6 +421,7 @@ revision = "393af48d391698c6ae4219566bfbdfef67269997" [[projects]] + branch = "master" name = "github.com/lucas-clemente/quic-go" packages = [ ".", @@ -435,8 +436,7 @@ "internal/wire", "qerr" ] - revision = "ded0eb4f6f30a8049bd334a26ff7ff728648fe13" - version = "v0.6.0" + revision = "15bcc2579f7dab14c84f438741f2b535cf474df9" [[projects]] branch = "master" @@ -753,6 +753,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "a11e1692755a705514dbd401ba4795821d1ac221d6f9100124c38a29db98c568" + inputs-digest = "97c8282ef9b3abed71907d17ccf38379134714596610880b02d5ca03be634678" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index e69de29..8c5cdc6 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -0,0 +1,137 @@ +# Gopkg.toml example +# +# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md +# for detailed Gopkg.toml documentation. +# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" + + +[[constraint]] + branch = "master" + name = "github.com/Xe/gopreload" + +[[constraint]] + name = "github.com/Xe/ln" + version = "0.1.0" + +[[constraint]] + branch = "master" + name = "github.com/Xe/uuid" + +[[constraint]] + branch = "master" + name = "github.com/Xe/x" + +[[constraint]] + name = "github.com/asdine/storm" + version = "2.0.2" + +[[constraint]] + branch = "master" + name = "github.com/brandur/simplebox" + +[[constraint]] + name = "github.com/caarlos0/env" + version = "3.2.0" + +[[constraint]] + branch = "master" + name = "github.com/dgryski/go-failure" + +[[constraint]] + branch = "master" + name = "github.com/dickeyxxx/netrc" + +[[constraint]] + branch = "master" + name = "github.com/facebookgo/flagenv" + +[[constraint]] + branch = "master" + name = "github.com/golang/protobuf" + +[[constraint]] + name = "github.com/google/gops" + version = "0.3.2" + +[[constraint]] + name = "github.com/hashicorp/terraform" + version = "0.11.2" + +[[constraint]] + name = "github.com/joho/godotenv" + version = "1.2.0" + +[[constraint]] + branch = "master" + name = "github.com/jtolds/qod" + +[[constraint]] + branch = "master" + name = "github.com/kr/pretty" + +[[constraint]] + name = "github.com/lucas-clemente/quic-go" + branch = "master" + +[[constraint]] + name = "github.com/magefile/mage" + version = "2.0.1" + +[[constraint]] + branch = "master" + name = "github.com/mtneug/pkg" + +[[constraint]] + branch = "master" + name = "github.com/olekukonko/tablewriter" + +[[constraint]] + name = "github.com/pkg/errors" + version = "0.8.0" + +[[constraint]] + branch = "master" + name = "github.com/streamrail/concurrent-map" + +[[constraint]] + name = "github.com/xtaci/kcp-go" + version = "3.23.0" + +[[constraint]] + name = "github.com/xtaci/smux" + version = "1.0.6" + +[[constraint]] + name = "go.uber.org/atomic" + version = "1.3.1" + +[[constraint]] + branch = "master" + name = "golang.org/x/crypto" + +[[constraint]] + branch = "master" + name = "golang.org/x/net" + +[[constraint]] + name = "google.golang.org/grpc" + version = "1.9.2" + +[[constraint]] + name = "gopkg.in/alecthomas/kingpin.v2" + version = "2.2.6" diff --git a/vendor/github.com/lucas-clemente/quic-go/.travis.yml b/vendor/github.com/lucas-clemente/quic-go/.travis.yml index 304b377..dfb0d28 100644 --- a/vendor/github.com/lucas-clemente/quic-go/.travis.yml +++ b/vendor/github.com/lucas-clemente/quic-go/.travis.yml @@ -1,4 +1,5 @@ dist: trusty +group: travis_latest addons: hosts: @@ -8,6 +9,7 @@ language: go go: - 1.9.2 + - 1.10beta1 # first part of the GOARCH workaround # setting the GOARCH directly doesn't work, since the value will be overwritten later @@ -30,6 +32,7 @@ before_install: - export GOARCH=$TRAVIS_GOARCH - go env # for debugging - google-chrome --version + - "printf \"quic.clemente.io certificate valid until: \" && openssl x509 -in example/fullchain.pem -enddate -noout | cut -d = -f 2" - "export DISPLAY=:99.0" - "Xvfb $DISPLAY &> /dev/null &" diff --git a/vendor/github.com/lucas-clemente/quic-go/Changelog.md b/vendor/github.com/lucas-clemente/quic-go/Changelog.md index 8f65a9d..d7ff667 100644 --- a/vendor/github.com/lucas-clemente/quic-go/Changelog.md +++ b/vendor/github.com/lucas-clemente/quic-go/Changelog.md @@ -1,6 +1,10 @@ # Changelog -## v0.6.1 (unreleased) +## v0.7 (unreleased) + +- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored. +- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`. +- Expose the `ConnectionState` in the `Session` (experimental API). ## v0.6.0 (2017-12-12) diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/ackhandler_suite_test.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/ackhandler_suite_test.go index 53108c1..9e7e077 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/ackhandler_suite_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/ackhandler_suite_test.go @@ -1,6 +1,7 @@ package ackhandler import ( + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -11,3 +12,13 @@ func TestCrypto(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "AckHandler Suite") } + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go index 7b68faa..8bda958 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go @@ -16,6 +16,7 @@ type SentPacketHandler interface { SendingAllowed() bool GetStopWaitingFrame(force bool) *wire.StopWaitingFrame + GetLowestPacketNotConfirmedAcked() protocol.PacketNumber ShouldSendRetransmittablePacket() bool DequeuePacketForRetransmission() (packet *Packet) GetLeastUnacked() protocol.PacketNumber @@ -26,7 +27,7 @@ type SentPacketHandler interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { - ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error + ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error IgnoreBelow(protocol.PacketNumber) GetAlarmTimeout() time.Time diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go index 9c4ee30..e4213a0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go @@ -15,7 +15,8 @@ type Packet struct { Length protocol.ByteCount EncryptionLevel protocol.EncryptionLevel - SendTime time.Time + largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK + sendTime time.Time } // GetFramesForRetransmission gets all the frames for retransmission diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go index 97410f3..c316af4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go @@ -34,10 +34,10 @@ func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHand } } -func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { +func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error { if packetNumber > h.largestObserved { h.largestObserved = packetNumber - h.largestObservedReceivedTime = time.Now() + h.largestObservedReceivedTime = rcvTime } if packetNumber < h.ignoreBelow { @@ -47,7 +47,7 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { return err } - h.maybeQueueAck(packetNumber, shouldInstigateAck) + h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck) return nil } @@ -58,7 +58,7 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) { h.packetHistory.DeleteBelow(p) } -func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { +func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) { h.packetsReceivedSinceLastAck++ if shouldInstigateAck { @@ -86,7 +86,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber h.ackQueued = true } else { if h.ackAlarm.IsZero() { - h.ackAlarm = time.Now().Add(h.ackSendDelay) + h.ackAlarm = rcvTime.Add(h.ackSendDelay) } } } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler_test.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler_test.go index 4a87c72..10246bd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler_test.go @@ -21,34 +21,36 @@ var _ = Describe("receivedPacketHandler", func() { Context("accepting packets", func() { It("handles a packet that arrives late", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(1), true) + err := handler.ReceivedPacket(protocol.PacketNumber(1), time.Time{}, true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(protocol.PacketNumber(3), true) + err = handler.ReceivedPacket(protocol.PacketNumber(3), time.Time{}, true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(protocol.PacketNumber(2), true) + err = handler.ReceivedPacket(protocol.PacketNumber(2), time.Time{}, true) Expect(err).ToNot(HaveOccurred()) }) It("saves the time when each packet arrived", func() { - err := handler.ReceivedPacket(protocol.PacketNumber(3), true) + err := handler.ReceivedPacket(protocol.PacketNumber(3), time.Now(), true) Expect(err).ToNot(HaveOccurred()) Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) }) It("updates the largestObserved and the largestObservedReceivedTime", func() { + now := time.Now() handler.largestObserved = 3 - handler.largestObservedReceivedTime = time.Now().Add(-1 * time.Second) - err := handler.ReceivedPacket(5, true) + handler.largestObservedReceivedTime = now.Add(-1 * time.Second) + err := handler.ReceivedPacket(5, now, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5))) - Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) + Expect(handler.largestObservedReceivedTime).To(Equal(now)) }) It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() { - timestamp := time.Now().Add(-1 * time.Second) + now := time.Now() + timestamp := now.Add(-1 * time.Second) handler.largestObserved = 5 handler.largestObservedReceivedTime = timestamp - err := handler.ReceivedPacket(4, true) + err := handler.ReceivedPacket(4, now, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5))) Expect(handler.largestObservedReceivedTime).To(Equal(timestamp)) @@ -57,7 +59,7 @@ var _ = Describe("receivedPacketHandler", func() { It("passes on errors from receivedPacketHistory", func() { var err error for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ { - err = handler.ReceivedPacket(2*i+1, true) + err = handler.ReceivedPacket(2*i+1, time.Time{}, true) // this will eventually return an error // details about when exactly the receivedPacketHistory errors are tested there if err != nil { @@ -72,7 +74,7 @@ var _ = Describe("receivedPacketHandler", func() { Context("queueing ACKs", func() { receiveAndAck10Packets := func() { for i := 1; i <= 10; i++ { - err := handler.ReceivedPacket(protocol.PacketNumber(i), true) + err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) Expect(err).ToNot(HaveOccurred()) } Expect(handler.GetAckFrame()).ToNot(BeNil()) @@ -80,14 +82,14 @@ var _ = Describe("receivedPacketHandler", func() { } It("always queues an ACK for the first packet", func() { - err := handler.ReceivedPacket(1, false) + err := handler.ReceivedPacket(1, time.Time{}, false) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeTrue()) Expect(handler.GetAlarmTimeout()).To(BeZero()) }) It("works with packet number 0", func() { - err := handler.ReceivedPacket(0, false) + err := handler.ReceivedPacket(0, time.Time{}, false) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeTrue()) Expect(handler.GetAlarmTimeout()).To(BeZero()) @@ -95,11 +97,11 @@ var _ = Describe("receivedPacketHandler", func() { It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() { receiveAndAck10Packets() - err := handler.ReceivedPacket(11, true) + err := handler.ReceivedPacket(11, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeFalse()) Expect(handler.GetAlarmTimeout()).NotTo(BeZero()) - err = handler.ReceivedPacket(12, true) + err = handler.ReceivedPacket(12, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeTrue()) Expect(handler.GetAlarmTimeout()).To(BeZero()) @@ -107,11 +109,11 @@ var _ = Describe("receivedPacketHandler", func() { It("only sets the timer when receiving a retransmittable packets", func() { receiveAndAck10Packets() - err := handler.ReceivedPacket(11, false) + err := handler.ReceivedPacket(11, time.Time{}, false) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeFalse()) Expect(handler.ackAlarm).To(BeZero()) - err = handler.ReceivedPacket(12, true) + err = handler.ReceivedPacket(12, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeFalse()) Expect(handler.ackAlarm).ToNot(BeZero()) @@ -120,15 +122,15 @@ var _ = Describe("receivedPacketHandler", func() { It("queues an ACK if it was reported missing before", func() { receiveAndAck10Packets() - err := handler.ReceivedPacket(11, true) + err := handler.ReceivedPacket(11, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(13, true) + err = handler.ReceivedPacket(13, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) ack := handler.GetAckFrame() // ACK: 1 and 3, missing: 2 Expect(ack).ToNot(BeNil()) Expect(ack.HasMissingRanges()).To(BeTrue()) Expect(handler.ackQueued).To(BeFalse()) - err = handler.ReceivedPacket(12, false) + err = handler.ReceivedPacket(12, time.Time{}, false) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeTrue()) }) @@ -136,10 +138,10 @@ var _ = Describe("receivedPacketHandler", func() { It("queues an ACK if it creates a new missing range", func() { receiveAndAck10Packets() for i := 11; i < 16; i++ { - err := handler.ReceivedPacket(protocol.PacketNumber(i), true) + err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) Expect(err).ToNot(HaveOccurred()) } - err := handler.ReceivedPacket(20, true) // we now know that packets 16 to 19 are missing + err := handler.ReceivedPacket(20, time.Time{}, true) // we now know that packets 16 to 19 are missing Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeTrue()) ack := handler.GetAckFrame() @@ -154,9 +156,9 @@ var _ = Describe("receivedPacketHandler", func() { }) It("generates a simple ACK frame", func() { - err := handler.ReceivedPacket(1, true) + err := handler.ReceivedPacket(1, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(2, true) + err = handler.ReceivedPacket(2, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) ack := handler.GetAckFrame() Expect(ack).ToNot(BeNil()) @@ -166,7 +168,7 @@ var _ = Describe("receivedPacketHandler", func() { }) It("generates an ACK for packet number 0", func() { - err := handler.ReceivedPacket(0, true) + err := handler.ReceivedPacket(0, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) ack := handler.GetAckFrame() Expect(ack).ToNot(BeNil()) @@ -176,12 +178,12 @@ var _ = Describe("receivedPacketHandler", func() { }) It("saves the last sent ACK", func() { - err := handler.ReceivedPacket(1, true) + err := handler.ReceivedPacket(1, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) ack := handler.GetAckFrame() Expect(ack).ToNot(BeNil()) Expect(handler.lastAck).To(Equal(ack)) - err = handler.ReceivedPacket(2, true) + err = handler.ReceivedPacket(2, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) handler.ackQueued = true ack = handler.GetAckFrame() @@ -190,9 +192,9 @@ var _ = Describe("receivedPacketHandler", func() { }) It("generates an ACK frame with missing packets", func() { - err := handler.ReceivedPacket(1, true) + err := handler.ReceivedPacket(1, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(4, true) + err = handler.ReceivedPacket(4, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) ack := handler.GetAckFrame() Expect(ack).ToNot(BeNil()) @@ -205,11 +207,11 @@ var _ = Describe("receivedPacketHandler", func() { }) It("generates an ACK for packet number 0 and other packets", func() { - err := handler.ReceivedPacket(0, true) + err := handler.ReceivedPacket(0, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(1, true) + err = handler.ReceivedPacket(1, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(3, true) + err = handler.ReceivedPacket(3, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) ack := handler.GetAckFrame() Expect(ack).ToNot(BeNil()) @@ -223,15 +225,15 @@ var _ = Describe("receivedPacketHandler", func() { It("accepts packets below the lower limit", func() { handler.IgnoreBelow(6) - err := handler.ReceivedPacket(2, true) + err := handler.ReceivedPacket(2, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) }) It("doesn't add delayed packets to the packetHistory", func() { handler.IgnoreBelow(7) - err := handler.ReceivedPacket(4, true) + err := handler.ReceivedPacket(4, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedPacket(10, true) + err = handler.ReceivedPacket(10, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) ack := handler.GetAckFrame() Expect(ack).ToNot(BeNil()) @@ -241,7 +243,7 @@ var _ = Describe("receivedPacketHandler", func() { It("deletes packets from the packetHistory when a lower limit is set", func() { for i := 1; i <= 12; i++ { - err := handler.ReceivedPacket(protocol.PacketNumber(i), true) + err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) Expect(err).ToNot(HaveOccurred()) } handler.IgnoreBelow(7) @@ -256,7 +258,7 @@ var _ = Describe("receivedPacketHandler", func() { // TODO: remove this test when dropping support for STOP_WAITINGs It("handles a lower limit of 0", func() { handler.IgnoreBelow(0) - err := handler.ReceivedPacket(1337, true) + err := handler.ReceivedPacket(1337, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) ack := handler.GetAckFrame() Expect(ack).ToNot(BeNil()) @@ -264,7 +266,7 @@ var _ = Describe("receivedPacketHandler", func() { }) It("resets all counters needed for the ACK queueing decision when sending an ACK", func() { - err := handler.ReceivedPacket(1, true) + err := handler.ReceivedPacket(1, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) handler.ackAlarm = time.Now().Add(-time.Minute) Expect(handler.GetAckFrame()).ToNot(BeNil()) @@ -275,7 +277,7 @@ var _ = Describe("receivedPacketHandler", func() { }) It("doesn't generate an ACK when none is queued and the timer is not set", func() { - err := handler.ReceivedPacket(1, true) + err := handler.ReceivedPacket(1, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) handler.ackQueued = false handler.ackAlarm = time.Time{} @@ -283,7 +285,7 @@ var _ = Describe("receivedPacketHandler", func() { }) It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() { - err := handler.ReceivedPacket(1, true) + err := handler.ReceivedPacket(1, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) handler.ackQueued = false handler.ackAlarm = time.Now().Add(time.Minute) @@ -291,7 +293,7 @@ var _ = Describe("receivedPacketHandler", func() { }) It("generates an ACK when the timer has expired", func() { - err := handler.ReceivedPacket(1, true) + err := handler.ReceivedPacket(1, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) handler.ackQueued = false handler.ackAlarm = time.Now().Add(-time.Minute) diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go index 4fe6681..e0d9b08 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go @@ -40,6 +40,10 @@ type sentPacketHandler struct { largestAcked protocol.PacketNumber largestReceivedPacketWithAck protocol.PacketNumber + // lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived + // example: we send an ACK for packets 90-100 with packet number 20 + // once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101 + lowestPacketNotConfirmedAcked protocol.PacketNumber packetHistory *PacketList stopWaitingManager stopWaitingManager @@ -95,6 +99,13 @@ func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool { } func (h *sentPacketHandler) SetHandshakeComplete() { + var queue []*Packet + for _, packet := range h.retransmissionQueue { + if packet.EncryptionLevel == protocol.EncryptionForwardSecure { + queue = append(queue, packet) + } + } + h.retransmissionQueue = queue h.handshakeComplete = true } @@ -114,11 +125,19 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { h.lastSentPacketNumber = packet.PacketNumber now := time.Now() + var largestAcked protocol.PacketNumber + if len(packet.Frames) > 0 { + if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok { + largestAcked = ackFrame.LargestAcked + } + } + packet.Frames = stripNonRetransmittableFrames(packet.Frames) isRetransmittable := len(packet.Frames) != 0 if isRetransmittable { - packet.SendTime = now + packet.sendTime = now + packet.largestAcked = largestAcked h.bytesInFlight += packet.Length h.packetHistory.PushBack(*packet) h.numNonRetransmittablePackets = 0 @@ -134,7 +153,7 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { isRetransmittable, ) - h.updateLossDetectionAlarm() + h.updateLossDetectionAlarm(now) return nil } @@ -146,14 +165,12 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe // duplicate or out-of-order ACK // if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 { if withPacketNumber <= h.largestReceivedPacketWithAck { - utils.Debugf("ignoring ack because duplicate") return ErrDuplicateOrOutOfOrderAck } h.largestReceivedPacketWithAck = withPacketNumber // ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK) if ackFrame.LargestAcked < h.lowestUnacked() { - utils.Debugf("ignoring ack because repeated") return nil } h.largestAcked = ackFrame.LargestAcked @@ -178,13 +195,19 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe if encLevel < p.Value.EncryptionLevel { return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel) } + // largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0 + // It is safe to ignore the corner case of packets that just acked packet 0, because + // the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send. + if p.Value.largestAcked != 0 { + h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1) + } h.onPacketAcked(p) h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) } } - h.detectLostPackets() - h.updateLossDetectionAlarm() + h.detectLostPackets(rcvTime) + h.updateLossDetectionAlarm(rcvTime) h.garbageCollectSkippedPackets() h.stopWaitingManager.ReceivedAck(ackFrame) @@ -192,6 +215,10 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe return nil } +func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + return h.lowestPacketNotConfirmedAcked +} + func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) { var ackedPackets []*PacketElement ackRangeIndex := 0 @@ -233,7 +260,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a for el := h.packetHistory.Front(); el != nil; el = el.Next() { packet := el.Value if packet.PacketNumber == largestAcked { - h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now()) + h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime) return true } // Packets are sorted by number, so we can stop searching @@ -244,7 +271,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a return false } -func (h *sentPacketHandler) updateLossDetectionAlarm() { +func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) { // Cancel the alarm if no packets are outstanding if h.packetHistory.Len() == 0 { h.alarm = time.Time{} @@ -253,19 +280,18 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() { // TODO(#497): TLP if !h.handshakeComplete { - h.alarm = time.Now().Add(h.computeHandshakeTimeout()) + h.alarm = now.Add(h.computeHandshakeTimeout()) } else if !h.lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = h.lossTime } else { // RTO - h.alarm = time.Now().Add(h.computeRTOTimeout()) + h.alarm = now.Add(h.computeRTOTimeout()) } } -func (h *sentPacketHandler) detectLostPackets() { +func (h *sentPacketHandler) detectLostPackets(now time.Time) { h.lossTime = time.Time{} - now := time.Now() maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT) @@ -278,7 +304,7 @@ func (h *sentPacketHandler) detectLostPackets() { break } - timeSinceSent := now.Sub(packet.SendTime) + timeSinceSent := now.Sub(packet.sendTime) if timeSinceSent > delayUntilLost { lostPackets = append(lostPackets, el) } else if h.lossTime.IsZero() { @@ -296,20 +322,22 @@ func (h *sentPacketHandler) detectLostPackets() { } func (h *sentPacketHandler) OnAlarm() { + now := time.Now() + // TODO(#497): TLP if !h.handshakeComplete { h.queueHandshakePacketsForRetransmission() h.handshakeCount++ } else if !h.lossTime.IsZero() { // Early retransmit or time loss detection - h.detectLostPackets() + h.detectLostPackets(now) } else { // RTO h.retransmitOldestTwoPackets() h.rtoCount++ } - h.updateLossDetectionAlarm() + h.updateLossDetectionAlarm(now) } func (h *sentPacketHandler) GetAlarmTimeout() time.Time { @@ -345,12 +373,11 @@ func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFra } func (h *sentPacketHandler) SendingAllowed() bool { - congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow() + cwnd := h.congestion.GetCongestionWindow() + congestionLimited := h.bytesInFlight > cwnd maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets if congestionLimited { - utils.Debugf("Congestion limited: bytes in flight %d, window %d", - h.bytesInFlight, - h.congestion.GetCongestionWindow()) + utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd) } // Workaround for #555: // Always allow sending of retransmissions. This should probably be limited diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler_test.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler_test.go index 6648d55..e3f22c6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler_test.go @@ -3,60 +3,15 @@ package ackhandler import ( "time" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/congestion" + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -type mockCongestion struct { - argsOnPacketSent []interface{} - maybeExitSlowStart bool - onRetransmissionTimeout bool - getCongestionWindow bool - packetsAcked [][]interface{} - packetsLost [][]interface{} -} - -func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration { - panic("not implemented") -} - -func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool { - m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable} - return false -} - -func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount { - m.getCongestionWindow = true - return protocol.DefaultTCPMSS -} - -func (m *mockCongestion) MaybeExitSlowStart() { - m.maybeExitSlowStart = true -} - -func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) { - m.onRetransmissionTimeout = true -} - -func (m *mockCongestion) RetransmissionDelay() time.Duration { - return defaultRTOTimeout -} - -func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not implemented") } -func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") } -func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") } - -func (m *mockCongestion) OnPacketAcked(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) { - m.packetsAcked = append(m.packetsAcked, []interface{}{n, l, bif}) -} - -func (m *mockCongestion) OnPacketLost(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) { - m.packetsLost = append(m.packetsLost, []interface{}{n, l, bif}) -} - func retransmittablePacket(num protocol.PacketNumber) *Packet { return &Packet{ PacketNumber: num, @@ -143,7 +98,7 @@ var _ = Describe("SentPacketHandler", func() { packet := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} err := handler.SentPacket(&packet) Expect(err).ToNot(HaveOccurred()) - Expect(handler.packetHistory.Front().Value.SendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) + Expect(handler.packetHistory.Front().Value.sendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) }) It("does not store non-retransmittable packets", func() { @@ -553,9 +508,9 @@ var _ = Describe("SentPacketHandler", func() { It("computes the RTT", func() { now := time.Now() // First, fake the sent times of the first, second and last packet - getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute) - getPacketElement(2).Value.SendTime = now.Add(-5 * time.Minute) - getPacketElement(6).Value.SendTime = now.Add(-1 * time.Minute) + getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute) + getPacketElement(2).Value.sendTime = now.Add(-5 * time.Minute) + getPacketElement(6).Value.sendTime = now.Add(-1 * time.Minute) // Now, check that the proper times are used when calculating the deltas err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(err).NotTo(HaveOccurred()) @@ -570,12 +525,50 @@ var _ = Describe("SentPacketHandler", func() { It("uses the DelayTime in the ack frame", func() { now := time.Now() - getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute) + getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, DelayTime: 5 * time.Minute}, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(err).NotTo(HaveOccurred()) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) }) }) + + Context("determinining, which ACKs we have received an ACK for", func() { + BeforeEach(func() { + morePackets := []*Packet{ + &Packet{PacketNumber: 13, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 80, LargestAcked: 100}, &streamFrame}, Length: 1}, + &Packet{PacketNumber: 14, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 50, LargestAcked: 200}, &streamFrame}, Length: 1}, + &Packet{PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1}, + } + for _, packet := range morePackets { + err := handler.SentPacket(packet) + Expect(err).NotTo(HaveOccurred()) + } + }) + + It("determines which ACK we have received an ACK for", func() { + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 12}, 1, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) + }) + + It("doesn't do anything when the acked packet didn't contain an ACK", func() { + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 1, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) + err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 15}, 2, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) + }) + + It("doesn't decrease the value", func() { + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 14, LowestAcked: 14}, 1, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) + err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 2, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) + }) + }) }) Context("Retransmission handling", func() { @@ -583,13 +576,13 @@ var _ = Describe("SentPacketHandler", func() { BeforeEach(func() { packets = []*Packet{ - {PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1}, + {PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted}, + {PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted}, + {PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted}, + {PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure}, + {PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure}, + {PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure}, + {PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure}, } for _, packet := range packets { handler.SentPacket(packet) @@ -597,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() { // Increase RTT, because the tests would be flaky otherwise handler.rttStats.UpdateRTT(time.Minute, 0, time.Now()) // Ack a single packet so that we have non-RTO timings - handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionUnencrypted, time.Now()) + handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now()) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) }) @@ -606,7 +599,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("dequeues a packet for retransmission", func() { - getPacketElement(1).Value.SendTime = time.Now().Add(-time.Hour) + getPacketElement(1).Value.sendTime = time.Now().Add(-time.Hour) handler.OnAlarm() Expect(getPacketElement(1)).To(BeNil()) Expect(handler.retransmissionQueue).To(HaveLen(1)) @@ -617,15 +610,33 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) }) - Context("StopWaitings", func() { - It("gets a StopWaitingFrame", func() { + It("deletes non forward-secure packets when the handshake completes", func() { + for i := protocol.PacketNumber(1); i <= 7; i++ { + if i == 2 { // packet 2 was already acked in BeforeEach + continue + } + handler.queuePacketForRetransmission(getPacketElement(i)) + } + Expect(handler.retransmissionQueue).To(HaveLen(6)) + handler.SetHandshakeComplete() + packet := handler.DequeuePacketForRetransmission() + Expect(packet).ToNot(BeNil()) + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(6))) + packet = handler.DequeuePacketForRetransmission() + Expect(packet).ToNot(BeNil()) + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(7))) + Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) + }) + + Context("STOP_WAITINGs", func() { + It("gets a STOP_WAITING frame", func() { ack := wire.AckFrame{LargestAcked: 5, LowestAcked: 5} - err := handler.ReceivedAck(&ack, 2, protocol.EncryptionUnencrypted, time.Now()) + err := handler.ReceivedAck(&ack, 2, protocol.EncryptionForwardSecure, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6})) }) - It("gets a StopWaitingFrame after queueing a retransmission", func() { + It("gets a STOP_WAITING frame after queueing a retransmission", func() { handler.queuePacketForRetransmission(getPacketElement(5)) Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6})) }) @@ -662,7 +673,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).NotTo(HaveOccurred()) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - handler.packetHistory.Front().Value.SendTime = time.Now().Add(-time.Hour) + handler.packetHistory.Front().Value.sendTime = time.Now().Add(-time.Hour) handler.OnAlarm() Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(0))) @@ -670,15 +681,23 @@ var _ = Describe("SentPacketHandler", func() { Context("congestion", func() { var ( - cong *mockCongestion + cong *mocks.MockSendAlgorithm ) BeforeEach(func() { - cong = &mockCongestion{} + cong = mocks.NewMockSendAlgorithm(mockCtrl) + cong.EXPECT().RetransmissionDelay().AnyTimes() handler.congestion = cong }) It("should call OnSent", func() { + cong.EXPECT().OnPacketSent( + gomock.Any(), + protocol.ByteCount(42), + protocol.PacketNumber(1), + protocol.ByteCount(42), + true, + ) p := &Packet{ PacketNumber: 1, Length: 42, @@ -686,62 +705,60 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.SentPacket(p) Expect(err).NotTo(HaveOccurred()) - Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(42))) - Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1))) - Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(42))) - Expect(cong.argsOnPacketSent[4]).To(BeTrue()) }) It("should call MaybeExitSlowStart and OnPacketAcked", func() { + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + cong.EXPECT().MaybeExitSlowStart() + cong.EXPECT().OnPacketAcked( + protocol.PacketNumber(1), + protocol.ByteCount(1), + protocol.ByteCount(1), + ) handler.SentPacket(retransmittablePacket(1)) handler.SentPacket(retransmittablePacket(2)) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionForwardSecure, time.Now()) Expect(err).NotTo(HaveOccurred()) - Expect(cong.maybeExitSlowStart).To(BeTrue()) - Expect(cong.packetsAcked).To(BeEquivalentTo([][]interface{}{ - {protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(1)}, - })) - Expect(cong.packetsLost).To(BeEmpty()) }) It("should call MaybeExitSlowStart and OnPacketLost", func() { + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) + cong.EXPECT().OnRetransmissionTimeout(true).Times(2) + cong.EXPECT().OnPacketLost( + protocol.PacketNumber(1), + protocol.ByteCount(1), + protocol.ByteCount(2), + ) + cong.EXPECT().OnPacketLost( + protocol.PacketNumber(2), + protocol.ByteCount(1), + protocol.ByteCount(1), + ) handler.SentPacket(retransmittablePacket(1)) handler.SentPacket(retransmittablePacket(2)) handler.SentPacket(retransmittablePacket(3)) handler.OnAlarm() // RTO, meaning 2 lost packets - Expect(cong.maybeExitSlowStart).To(BeFalse()) - Expect(cong.onRetransmissionTimeout).To(BeTrue()) - Expect(cong.packetsAcked).To(BeEmpty()) - Expect(cong.packetsLost).To(BeEquivalentTo([][]interface{}{ - {protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)}, - {protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(1)}, - })) }) It("allows or denies sending based on congestion", func() { + Expect(handler.retransmissionQueue).To(BeEmpty()) + handler.bytesInFlight = 100 + cong.EXPECT().GetCongestionWindow().Return(protocol.MaxByteCount) Expect(handler.SendingAllowed()).To(BeTrue()) - err := handler.SentPacket(&Packet{ - PacketNumber: 1, - Frames: []wire.Frame{&wire.PingFrame{}}, - Length: protocol.DefaultTCPMSS + 1, - }) - Expect(err).NotTo(HaveOccurred()) + cong.EXPECT().GetCongestionWindow().Return(protocol.ByteCount(0)) Expect(handler.SendingAllowed()).To(BeFalse()) }) It("allows or denies sending based on the number of tracked packets", func() { + cong.EXPECT().GetCongestionWindow().Return(protocol.MaxByteCount).AnyTimes() Expect(handler.SendingAllowed()).To(BeTrue()) handler.retransmissionQueue = make([]*Packet, protocol.MaxTrackedSentPackets) Expect(handler.SendingAllowed()).To(BeFalse()) }) It("allows sending if there are retransmisisons outstanding", func() { - err := handler.SentPacket(&Packet{ - PacketNumber: 1, - Frames: []wire.Frame{&wire.PingFrame{}}, - Length: protocol.DefaultTCPMSS + 1, - }) - Expect(err).NotTo(HaveOccurred()) + handler.bytesInFlight = 100 + cong.EXPECT().GetCongestionWindow().Return(protocol.ByteCount(0)).AnyTimes() Expect(handler.SendingAllowed()).To(BeFalse()) handler.retransmissionQueue = []*Packet{nil} Expect(handler.SendingAllowed()).To(BeTrue()) @@ -799,7 +816,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.lossTime.Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute)) Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute)) - handler.packetHistory.Front().Value.SendTime = time.Now().Add(-2 * time.Hour) + handler.packetHistory.Front().Value.sendTime = time.Now().Add(-2 * time.Hour) handler.OnAlarm() Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil()) }) @@ -843,7 +860,7 @@ var _ = Describe("SentPacketHandler", func() { err = handler.SentPacket(handshakePacket(4)) Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now().Add(time.Hour)) + err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now()) Expect(err).NotTo(HaveOccurred()) Expect(handler.lossTime.IsZero()).To(BeTrue()) handshakeTimeout := handler.computeHandshakeTimeout() diff --git a/vendor/github.com/lucas-clemente/quic-go/client.go b/vendor/github.com/lucas-clemente/quic-go/client.go index 7101a67..2180853 100644 --- a/vendor/github.com/lucas-clemente/quic-go/client.go +++ b/vendor/github.com/lucas-clemente/quic-go/client.go @@ -60,33 +60,15 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) return Dial(udpConn, udpAddr, addr, tlsConf, config) } -// DialAddrNonFWSecure establishes a new QUIC connection to a server. -// The hostname for SNI is taken from the given address. -func DialAddrNonFWSecure( - addr string, - tlsConf *tls.Config, - config *Config, -) (NonFWSession, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config) -} - -// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn. +// Dial establishes a new QUIC connection to a server using a net.PacketConn. // The host parameter is used for SNI. -func DialNonFWSecure( +func Dial( pconn net.PacketConn, remoteAddr net.Addr, host string, tlsConf *tls.Config, config *Config, -) (NonFWSession, error) { +) (Session, error) { connID, err := generateConnectionID() if err != nil { return nil, err @@ -115,31 +97,11 @@ func DialNonFWSecure( } utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) - go c.listen() if err := c.dial(); err != nil { return nil, err } - return c.session.(NonFWSession), nil -} - -// Dial establishes a new QUIC connection to a server using a net.PacketConn. -// The host parameter is used for SNI. -func Dial( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (Session, error) { - sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config) - if err != nil { - return nil, err - } - if err := sess.WaitUntilHandshakeComplete(); err != nil { - return nil, err - } - return sess, nil + return c.session, nil } // populateClientConfig populates fields in the quic.Config with their default values, if none are set @@ -199,6 +161,7 @@ func (c *client) dialGQUIC() error { if err := c.createNewGQUICSession(); err != nil { return err } + go c.listen() return c.establishSecureConnection() } @@ -224,6 +187,7 @@ func (c *client) dialTLS() error { if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil { return err } + go c.listen() if err := c.establishSecureConnection(); err != nil { if err != handshake.ErrCloseSessionForRetry { return err @@ -267,14 +231,8 @@ func (c *client) establishSecureConnection() error { select { case <-errorChan: return runErr - case ev := <-c.session.handshakeStatus(): - if ev.err != nil { - return ev.err - } - if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure { - return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel) - } - return nil + case err := <-c.session.handshakeStatus(): + return err } } diff --git a/vendor/github.com/lucas-clemente/quic-go/client_test.go b/vendor/github.com/lucas-clemente/quic-go/client_test.go index ac7ae6d..a955b5e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/client_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/client_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "net" + "os" "sync/atomic" "time" @@ -100,57 +101,7 @@ var _ = Describe("Client", func() { generateConnectionID = origGenerateConnectionID }) - It("dials non-forward-secure", func() { - packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config) - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - close(dialed) - }() - Consistently(dialed).ShouldNot(BeClosed()) - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} - Eventually(dialed).Should(BeClosed()) - }) - - It("dials a non-forward-secure address", func() { - serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") - Expect(err).ToNot(HaveOccurred()) - server, err := net.ListenUDP("udp", serverAddr) - Expect(err).ToNot(HaveOccurred()) - defer server.Close() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - for { - _, clientAddr, err := server.ReadFromUDP(make([]byte, 200)) - if err != nil { - return - } - _, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr) - Expect(err).ToNot(HaveOccurred()) - } - }() - - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config) - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - close(dialed) - }() - Consistently(dialed).ShouldNot(BeClosed()) - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} - Eventually(dialed).Should(BeClosed()) - server.Close() - Eventually(done).Should(BeClosed()) - }) - - It("Dial only returns after the handshake is complete", func() { + It("returns after the handshake is complete", func() { packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) dialed := make(chan struct{}) go func() { @@ -160,13 +111,14 @@ var _ = Describe("Client", func() { Expect(s).ToNot(BeNil()) close(dialed) }() - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} - Consistently(dialed).ShouldNot(BeClosed()) - close(sess.handshakeComplete) + close(sess.handshakeChan) Eventually(dialed).Should(BeClosed()) }) It("resolves the address", func() { + if os.Getenv("APPVEYOR") == "True" { + Skip("This test is flaky on AppVeyor.") + } closeErr := errors.New("peer doesn't reply") remoteAddrChan := make(chan string) newClientSession = func( @@ -245,22 +197,7 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(testErr)) close(done) }() - sess.handshakeChan <- handshakeEvent{err: testErr} - Eventually(done).Should(BeClosed()) - }) - - It("returns an error that occurs while waiting for the handshake to complete", func() { - testErr := errors.New("late handshake error") - packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) - Expect(err).To(MatchError(testErr)) - close(done) - }() - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} - sess.handshakeComplete <- testErr + sess.handshakeChan <- testErr Eventually(done).Should(BeClosed()) }) @@ -305,7 +242,7 @@ var _ = Describe("Client", func() { ) (packetHandler, error) { return nil, testErr } - _, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(err).To(MatchError(testErr)) }) @@ -331,7 +268,7 @@ var _ = Describe("Client", func() { Expect(newVersion).ToNot(Equal(cl.version)) Expect(config.Versions).To(ContainElement(newVersion)) sessionChan := make(chan *mockSession) - handshakeChan := make(chan handshakeEvent) + handshakeChan := make(chan error) newClientSession = func( _ connection, _ string, @@ -382,7 +319,7 @@ var _ = Describe("Client", func() { Expect(negotiatedVersions).To(ContainElement(newVersion)) Expect(initialVersion).To(Equal(actualInitialVersion)) - handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} + close(handshakeChan) Eventually(established).Should(BeClosed()) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go new file mode 100644 index 0000000..8e96ec1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go @@ -0,0 +1,41 @@ +package quic + +import ( + "io" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type cryptoStreamI interface { + StreamID() protocol.StreamID + io.Reader + io.Writer + handleStreamFrame(*wire.StreamFrame) error + popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool) + closeForShutdown(error) + setReadOffset(protocol.ByteCount) + // methods needed for flow control + getWindowUpdate() protocol.ByteCount + handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) +} + +type cryptoStream struct { + *stream +} + +var _ cryptoStreamI = &cryptoStream{} + +func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI { + str := newStream(version.CryptoStreamID(), sender, flowController, version) + return &cryptoStream{str} +} + +// SetReadOffset sets the read offset. +// It is only needed for the crypto stream. +// It must not be called concurrently with any other stream methods, especially Read and Write. +func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) { + s.receiveStream.readOffset = offset + s.receiveStream.frameQueue.readPosition = offset +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto_stream_test.go b/vendor/github.com/lucas-clemente/quic-go/crypto_stream_test.go new file mode 100644 index 0000000..d5ec3be --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/crypto_stream_test.go @@ -0,0 +1,26 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Crypto Stream", func() { + var ( + str *cryptoStream + mockSender *MockStreamSender + ) + + BeforeEach(func() { + mockSender = NewMockStreamSender(mockCtrl) + str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStream) + }) + + It("sets the read offset", func() { + str.setReadOffset(0x42) + Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42))) + Expect(str.receiveStream.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42))) + }) +}) diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go index 9d845ec..1bcc68a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go @@ -97,45 +97,44 @@ func (c *client) handleHeaderStream() { decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) h2framer := http2.NewFramer(nil, c.headerStream) - var lastStream protocol.StreamID + var err error + for err == nil { + err = c.readResponse(h2framer, decoder) + } + utils.Debugf("Error handling header stream: %s", err) + c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error()) + // stop all running request + close(c.headerErrored) +} - for { - frame, err := h2framer.ReadFrame() - if err != nil { - c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") - break - } - lastStream = protocol.StreamID(frame.Header().StreamID) - hframe, ok := frame.(*http2.HeadersFrame) - if !ok { - c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") - break - } - mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} - mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) - if err != nil { - c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") - break - } - - c.mutex.RLock() - responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] - c.mutex.RUnlock() - if !ok { - c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) - break - } - - rsp, err := responseFromHeaders(mhframe) - if err != nil { - c.headerErr = qerr.Error(qerr.InternalError, err.Error()) - } - responseChan <- rsp +func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error { + frame, err := h2framer.ReadFrame() + if err != nil { + return err + } + hframe, ok := frame.(*http2.HeadersFrame) + if !ok { + return errors.New("not a headers frame") + } + mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} + mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) + if err != nil { + return fmt.Errorf("cannot read header fields: %s", err.Error()) } - // stop all running request - utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) - close(c.headerErrored) + c.mutex.RLock() + responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] + c.mutex.RUnlock() + if !ok { + return fmt.Errorf("response channel for stream %d not found", hframe.StreamID) + } + + rsp, err := responseFromHeaders(mhframe) + if err != nil { + return err + } + responseChan <- rsp + return nil } // Roundtrip executes a request and returns a response diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/client_test.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/client_test.go index 24737e1..49d9b6d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/client_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/client_test.go @@ -188,41 +188,31 @@ var _ = Describe("Client", func() { close(done) }) - It("closes the quic client when encountering an error on the header stream", func(done Done) { + It("closes the quic client when encountering an error on the header stream", func() { headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) - var doReturned bool + done := make(chan struct{}) go func() { defer GinkgoRecover() - var err error rsp, err := client.RoundTrip(request) Expect(err).To(MatchError(client.headerErr)) Expect(rsp).To(BeNil()) - doReturned = true + close(done) }() - Eventually(func() bool { return doReturned }).Should(BeTrue()) - Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame"))) + Eventually(done).Should(BeClosed()) + Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData)) Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr)) - close(done) - }, 2) + }) - It("returns subsequent request if there was an error on the header stream before", func(done Done) { - expectedErr := qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") + It("returns subsequent request if there was an error on the header stream before", func() { session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)} headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) - var firstReqReturned bool - go func() { - defer GinkgoRecover() - _, err := client.RoundTrip(request) - Expect(err).To(MatchError(expectedErr)) - firstReqReturned = true - }() - - Eventually(func() bool { return firstReqReturned }).Should(BeTrue()) - // now that the first request failed due to an error on the header stream, try another request _, err := client.RoundTrip(request) - Expect(err).To(MatchError(expectedErr)) - close(done) + Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) + Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData)) + // now that the first request failed due to an error on the header stream, try another request + _, nextErr := client.RoundTrip(request) + Expect(nextErr).To(MatchError(err)) }) It("blocks if no stream is available", func() { @@ -479,16 +469,9 @@ var _ = Describe("Client", func() { It("errors if the H2 frame is not a HeadersFrame", func() { h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0}) - - var handlerReturned bool - go func() { - client.handleHeaderStream() - handlerReturned = true - }() - + client.handleHeaderStream() Eventually(client.headerErrored).Should(BeClosed()) Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame"))) - Eventually(func() bool { return handlerReturned }).Should(BeTrue()) }) It("errors if it can't read the HPACK encoded header fields", func() { @@ -497,16 +480,26 @@ var _ = Describe("Client", func() { EndHeaders: true, BlockFragment: []byte("invalid HPACK data"), }) - - var handlerReturned bool - go func() { - client.handleHeaderStream() - handlerReturned = true - }() - + client.handleHeaderStream() Eventually(client.headerErrored).Should(BeClosed()) - Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields"))) - Eventually(func() bool { return handlerReturned }).Should(BeTrue()) + Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData)) + Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields")) + }) + + It("errors if the stream cannot be found", func() { + var headers bytes.Buffer + enc := hpack.NewEncoder(&headers) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + err := h2framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: 1337, + EndHeaders: true, + BlockFragment: headers.Bytes(), + }) + Expect(err).ToNot(HaveOccurred()) + client.handleHeaderStream() + Eventually(client.headerErrored).Should(BeClosed()) + Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData)) + Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found")) }) }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_test.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_test.go index e3e3e27..4ea0701 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_test.go @@ -11,6 +11,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -29,6 +30,8 @@ type mockStream struct { ctxCancel context.CancelFunc } +var _ quic.Stream = &mockStream{} + func newMockStream(id protocol.StreamID) *mockStream { s := &mockStream{ id: id, @@ -39,7 +42,8 @@ func newMockStream(id protocol.StreamID) *mockStream { } func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil } -func (s *mockStream) Reset(error) { s.reset = true } +func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil } +func (s *mockStream) CancelWrite(quic.ErrorCode) error { panic("not implemented") } func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() } func (s mockStream) StreamID() protocol.StreamID { return s.id } func (s *mockStream) Context() context.Context { return s.ctx } diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go index 0d0cecf..329edfd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go @@ -50,6 +50,7 @@ type Server struct { listenerMutex sync.Mutex listener quic.Listener + closed bool supportedVersionsAsString string } @@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { return errors.New("use of h2quic.Server without http.Server") } s.listenerMutex.Lock() + if s.closed { + s.listenerMutex.Unlock() + return errors.New("Server is already closed") + } if s.listener != nil { s.listenerMutex.Unlock() return errors.New("ListenAndServe may only be called once") @@ -223,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, } if responseWriter.dataStream != nil { if !streamEnded && !reqBody.requestRead { - responseWriter.dataStream.Reset(nil) + // in gQUIC, the error code doesn't matter, so just use 0 here + responseWriter.dataStream.CancelRead(0) } responseWriter.dataStream.Close() } @@ -241,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, func (s *Server) Close() error { s.listenerMutex.Lock() defer s.listenerMutex.Unlock() + s.closed = true if s.listener != nil { err := s.listener.Close() s.listener = nil diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/server_test.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/server_test.go index 242652c..55ffd33 100644 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/server_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/h2quic/server_test.go @@ -70,6 +70,7 @@ func (s *mockSession) RemoteAddr() net.Addr { func (s *mockSession) Context() context.Context { return s.ctx } +func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") } var _ = Describe("H2 server", func() { var ( @@ -410,6 +411,13 @@ var _ = Describe("H2 server", func() { Expect(err).NotTo(HaveOccurred()) }) + It("errors when ListenAndServer is called after Close", func() { + serv := &Server{Server: &http.Server{}} + Expect(serv.Close()).To(Succeed()) + err := serv.ListenAndServe() + Expect(err).To(MatchError("Server is already closed")) + }) + Context("ListenAndServe", func() { BeforeEach(func() { s.Server.Addr = "localhost:0" diff --git a/vendor/github.com/lucas-clemente/quic-go/interface.go b/vendor/github.com/lucas-clemente/quic-go/interface.go index 87bf9ea..b0a1829 100644 --- a/vendor/github.com/lucas-clemente/quic-go/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/interface.go @@ -19,20 +19,42 @@ type VersionNumber = protocol.VersionNumber // A Cookie can be used to verify the ownership of the client address. type Cookie = handshake.Cookie +// ConnectionState records basic details about the QUIC connection. +type ConnectionState = handshake.ConnectionState + +// An ErrorCode is an application-defined error code. +type ErrorCode = protocol.ApplicationErrorCode + // Stream is the interface implemented by QUIC streams type Stream interface { + // StreamID returns the stream ID. + StreamID() StreamID // Read reads data from the stream. // Read can be made to time out and return a net.Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetReadDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. io.Reader // Write writes data to the stream. // Write can be made to time out and return a net.Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetWriteDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. io.Writer + // Close closes the write-direction of the stream. + // Future calls to Write are not permitted after calling Close. + // It must not be called concurrently with Write. + // It must not be called after calling CancelWrite. io.Closer - StreamID() StreamID - // Reset closes the stream with an error. - Reset(error) + // CancelWrite aborts sending on this stream. + // It must not be called after Close. + // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. + // Write will unblock immediately, and future calls to Write will fail. + CancelWrite(ErrorCode) error + // CancelRead aborts receiving on this stream. + // It will ask the peer to stop transmitting stream data. + // Read will unblock immediately, and future Read calls will fail. + CancelRead(ErrorCode) error // The context is canceled as soon as the write-side of the stream is closed. // This happens when Close() is called, or when the stream is reset (either locally or remotely). // Warning: This API should not be considered stable and might change soon. @@ -53,6 +75,41 @@ type Stream interface { SetDeadline(t time.Time) error } +// A ReceiveStream is a unidirectional Receive Stream. +type ReceiveStream interface { + // see Stream.StreamID + StreamID() StreamID + // see Stream.Read + io.Reader + // see Stream.CancelRead + CancelRead(ErrorCode) error + // see Stream.SetReadDealine + SetReadDeadline(t time.Time) error +} + +// A SendStream is a unidirectional Send Stream. +type SendStream interface { + // see Stream.StreamID + StreamID() StreamID + // see Stream.Write + io.Writer + // see Stream.Close + io.Closer + // see Stream.CancelWrite + CancelWrite(ErrorCode) error + // see Stream.Context + Context() context.Context + // see Stream.SetWriteDeadline + SetWriteDeadline(t time.Time) error +} + +// StreamError is returned by Read and Write when the peer cancels the stream. +type StreamError interface { + error + Canceled() bool + ErrorCode() ErrorCode +} + // A Session is a QUIC connection between two peers. type Session interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. @@ -74,13 +131,9 @@ type Session interface { // The context is cancelled when the session is closed. // Warning: This API should not be considered stable and might change soon. Context() context.Context -} - -// A NonFWSession is a QUIC connection between two peers half-way through the handshake. -// The communication is encrypted, but not yet forward secure. -type NonFWSession interface { - Session - WaitUntilHandshakeComplete() error + // ConnectionState returns basic details about the QUIC connection. + // Warning: This API should not be considered stable and might change soon. + ConnectionState() ConnectionState } // Config contains all configuration data needed for a QUIC server or client. diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go index 5aaa187..8b8c9fa 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go @@ -18,6 +18,7 @@ type CertManager interface { GetLeafCertHash() (uint64, error) VerifyServerProof(proof, chlo, serverConfigData []byte) bool Verify(hostname string) error + GetChain() []*x509.Certificate } type certManager struct { @@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error { return nil } +func (c *certManager) GetChain() []*x509.Certificate { + return c.chain +} + func (c *certManager) GetCommonCertificateHashes() []byte { return getCommonCertificateHashes() } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go index e74c1d1..393f487 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go @@ -10,35 +10,30 @@ import ( ) type baseFlowController struct { - mutex sync.RWMutex - - rttStats *congestion.RTTStats - + // for sending data bytesSent protocol.ByteCount sendWindow protocol.ByteCount - lastWindowUpdateTime time.Time + // for receiving data + mutex sync.RWMutex + bytesRead protocol.ByteCount + highestReceived protocol.ByteCount + receiveWindow protocol.ByteCount + receiveWindowSize protocol.ByteCount + maxReceiveWindowSize protocol.ByteCount - bytesRead protocol.ByteCount - highestReceived protocol.ByteCount - receiveWindow protocol.ByteCount - receiveWindowIncrement protocol.ByteCount - maxReceiveWindowIncrement protocol.ByteCount + epochStartTime time.Time + epochStartOffset protocol.ByteCount + rttStats *congestion.RTTStats } func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { - c.mutex.Lock() - defer c.mutex.Unlock() - c.bytesSent += n } // UpdateSendWindow should be called after receiving a WindowUpdateFrame // it returns true if the window was actually updated func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { - c.mutex.Lock() - defer c.mutex.Unlock() - if offset > c.sendWindow { c.sendWindow = offset } @@ -57,52 +52,55 @@ func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) { defer c.mutex.Unlock() // pretend we sent a WindowUpdate when reading the first byte - // this way auto-tuning of the window increment already works for the first WindowUpdate + // this way auto-tuning of the window size already works for the first WindowUpdate if c.bytesRead == 0 { - c.lastWindowUpdateTime = time.Now() + c.startNewAutoTuningEpoch() } c.bytesRead += n } +func (c *baseFlowController) hasWindowUpdate() bool { + bytesRemaining := c.receiveWindow - c.bytesRead + // update the window when more than the threshold was consumed + return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold)))) +} + // getWindowUpdate updates the receive window, if necessary // it returns the new offset func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { - diff := c.receiveWindow - c.bytesRead - // update the window when more than half of it was already consumed - if diff >= (c.receiveWindowIncrement / 2) { + if !c.hasWindowUpdate() { return 0 } - c.maybeAdjustWindowIncrement() - c.receiveWindow = c.bytesRead + c.receiveWindowIncrement - c.lastWindowUpdateTime = time.Now() + c.maybeAdjustWindowSize() + c.receiveWindow = c.bytesRead + c.receiveWindowSize return c.receiveWindow } -func (c *baseFlowController) IsBlocked() bool { - c.mutex.RLock() - defer c.mutex.RUnlock() - - return c.sendWindowSize() == 0 -} - -// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often -func (c *baseFlowController) maybeAdjustWindowIncrement() { - if c.lastWindowUpdateTime.IsZero() { +// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often. +// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing. +func (c *baseFlowController) maybeAdjustWindowSize() { + bytesReadInEpoch := c.bytesRead - c.epochStartOffset + // don't do anything if less than half the window has been consumed + if bytesReadInEpoch <= c.receiveWindowSize/2 { return } - rtt := c.rttStats.SmoothedRTT() if rtt == 0 { return } - timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime) - // interval between the window updates is sufficiently large, no need to increase the increment - if timeSinceLastWindowUpdate >= 2*rtt { - return + fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize) + if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { + // window is consumed too fast, try to increase the window size + c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize) } - c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement) + c.startNewAutoTuningEpoch() +} + +func (c *baseFlowController) startNewAutoTuningEpoch() { + c.epochStartTime = time.Now() + c.epochStartOffset = c.bytesRead } func (c *baseFlowController) checkFlowControlViolation() bool { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller_test.go index 0ac218b..f996a28 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller_test.go @@ -1,6 +1,8 @@ package flowcontrol import ( + "os" + "strconv" "time" "github.com/lucas-clemente/quic-go/congestion" @@ -9,6 +11,16 @@ import ( . "github.com/onsi/gomega" ) +// on the CIs, the timing is a lot less precise, so scale every duration by this factor +func scaleDuration(t time.Duration) time.Duration { + scaleFactor := 1 + if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set + scaleFactor = f + } + Expect(scaleFactor).ToNot(BeZero()) + return time.Duration(scaleFactor) * t +} + var _ = Describe("Base Flow controller", func() { var controller *baseFlowController @@ -49,22 +61,18 @@ var _ = Describe("Base Flow controller", func() { controller.UpdateSendWindow(10) Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) }) - - It("says when it's blocked", func() { - controller.UpdateSendWindow(100) - Expect(controller.IsBlocked()).To(BeFalse()) - controller.AddBytesSent(100) - Expect(controller.IsBlocked()).To(BeTrue()) - }) }) Context("receive flow control", func() { - var receiveWindow protocol.ByteCount = 10000 - var receiveWindowIncrement protocol.ByteCount = 600 + var ( + receiveWindow protocol.ByteCount = 10000 + receiveWindowSize protocol.ByteCount = 1000 + ) BeforeEach(func() { + controller.bytesRead = receiveWindow - receiveWindowSize controller.receiveWindow = receiveWindow - controller.receiveWindowIncrement = receiveWindowIncrement + controller.receiveWindowSize = receiveWindowSize }) It("adds bytes read", func() { @@ -74,31 +82,30 @@ var _ = Describe("Base Flow controller", func() { }) It("triggers a window update when necessary", func() { - controller.lastWindowUpdateTime = time.Now().Add(-time.Hour) - readPosition := receiveWindow - receiveWindowIncrement/2 + 1 + bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold + 1 // consumed 1 byte more than the threshold + bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed) + readPosition := receiveWindow - bytesRemaining controller.bytesRead = readPosition offset := controller.getWindowUpdate() - Expect(offset).To(Equal(readPosition + receiveWindowIncrement)) - Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowIncrement)) - Expect(controller.lastWindowUpdateTime).To(BeTemporally("~", time.Now(), 20*time.Millisecond)) + Expect(offset).To(Equal(readPosition + receiveWindowSize)) + Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowSize)) }) It("doesn't trigger a window update when not necessary", func() { - lastWindowUpdateTime := time.Now().Add(-time.Hour) - controller.lastWindowUpdateTime = lastWindowUpdateTime - readPosition := receiveWindow - receiveWindow/2 - 1 + bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold - 1 // consumed 1 byte less than the threshold + bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed) + readPosition := receiveWindow - bytesRemaining controller.bytesRead = readPosition offset := controller.getWindowUpdate() Expect(offset).To(BeZero()) - Expect(controller.lastWindowUpdateTime).To(Equal(lastWindowUpdateTime)) }) - Context("receive window increment auto-tuning", func() { - var oldIncrement protocol.ByteCount + Context("receive window size auto-tuning", func() { + var oldWindowSize protocol.ByteCount BeforeEach(func() { - oldIncrement = controller.receiveWindowIncrement - controller.maxReceiveWindowIncrement = 3000 + oldWindowSize = controller.receiveWindowSize + controller.maxReceiveWindowSize = 5000 }) // update the congestion such that it returns a given value for the smoothed RTT @@ -107,72 +114,98 @@ var _ = Describe("Base Flow controller", func() { Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked } - It("doesn't increase the increment for a new stream", func() { - controller.maybeAdjustWindowIncrement() - Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) + It("doesn't increase the window size for a new stream", func() { + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) }) - It("doesn't increase the increment when no RTT estimate is available", func() { + It("doesn't increase the window size when no RTT estimate is available", func() { setRtt(0) - controller.lastWindowUpdateTime = time.Now() - controller.maybeAdjustWindowIncrement() - Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) + controller.startNewAutoTuningEpoch() + controller.AddBytesRead(400) + offset := controller.getWindowUpdate() + Expect(offset).ToNot(BeZero()) // make sure a window update is sent + Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) }) - It("increases the increment when the last WindowUpdate was sent less than two RTTs ago", func() { - setRtt(20 * time.Millisecond) - controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) - controller.maybeAdjustWindowIncrement() - Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) - }) - - It("doesn't increase the increase increment when the last WindowUpdate was sent more than two RTTs ago", func() { - setRtt(20 * time.Millisecond) - controller.lastWindowUpdateTime = time.Now().Add(-45 * time.Millisecond) - controller.maybeAdjustWindowIncrement() - Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) - }) - - It("doesn't increase the increment to a value higher than the maxReceiveWindowIncrement", func() { - setRtt(20 * time.Millisecond) - controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) - controller.maybeAdjustWindowIncrement() - Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) // 1200 - // because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowIncrement() multiple times and get an increase of the increment every time - controller.maybeAdjustWindowIncrement() - Expect(controller.receiveWindowIncrement).To(Equal(2 * 2 * oldIncrement)) // 2400 - controller.maybeAdjustWindowIncrement() - Expect(controller.receiveWindowIncrement).To(Equal(controller.maxReceiveWindowIncrement)) // 3000 - controller.maybeAdjustWindowIncrement() - Expect(controller.receiveWindowIncrement).To(Equal(controller.maxReceiveWindowIncrement)) // 3000 - }) - - It("returns the new increment when updating the window", func() { - setRtt(20 * time.Millisecond) - controller.AddBytesRead(9900) // receive window is 10000 - controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) + It("increases the window size if read so fast that the window would be consumed in less than 4 RTTs", func() { + bytesRead := controller.bytesRead + rtt := scaleDuration(20 * time.Millisecond) + setRtt(rtt) + // consume more than 2/3 of the window... + dataRead := receiveWindowSize*2/3 + 1 + // ... in 4*2/3 of the RTT + controller.epochStartOffset = controller.bytesRead + controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3) + controller.AddBytesRead(dataRead) offset := controller.getWindowUpdate() Expect(offset).ToNot(BeZero()) - newIncrement := controller.receiveWindowIncrement - Expect(newIncrement).To(Equal(2 * oldIncrement)) - Expect(offset).To(Equal(protocol.ByteCount(9900 + newIncrement))) + // check that the window size was increased + newWindowSize := controller.receiveWindowSize + Expect(newWindowSize).To(Equal(2 * oldWindowSize)) + // check that the new window size was used to increase the offset + Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize))) }) - It("increases the increment sent in the first WindowUpdate, if data is read fast enough", func() { - setRtt(20 * time.Millisecond) - controller.AddBytesRead(9900) + It("doesn't increase the window size if data is read so fast that the window would be consumed in less than 4 RTTs, but less than half the window has been read", func() { + // this test only makes sense if a window update is triggered before half of the window has been consumed + Expect(protocol.WindowUpdateThreshold).To(BeNumerically(">", 1/3)) + bytesRead := controller.bytesRead + rtt := scaleDuration(20 * time.Millisecond) + setRtt(rtt) + // consume more than 2/3 of the window... + dataRead := receiveWindowSize*1/3 + 1 + // ... in 4*2/3 of the RTT + controller.epochStartOffset = controller.bytesRead + controller.epochStartTime = time.Now().Add(-rtt * 4 * 1 / 3) + controller.AddBytesRead(dataRead) offset := controller.getWindowUpdate() Expect(offset).ToNot(BeZero()) - Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) + // check that the window size was not increased + newWindowSize := controller.receiveWindowSize + Expect(newWindowSize).To(Equal(oldWindowSize)) + // check that the new window size was used to increase the offset + Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize))) }) - It("doesn't increamse the increment sent in the first WindowUpdate, if data is read slowly", func() { - setRtt(5 * time.Millisecond) - controller.AddBytesRead(9900) - time.Sleep(15 * time.Millisecond) // more than 2x RTT + It("doesn't increase the window size if read too slowly", func() { + bytesRead := controller.bytesRead + rtt := scaleDuration(20 * time.Millisecond) + setRtt(rtt) + // consume less than 2/3 of the window... + dataRead := receiveWindowSize*2/3 - 1 + // ... in 4*2/3 of the RTT + controller.epochStartOffset = controller.bytesRead + controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3) + controller.AddBytesRead(dataRead) offset := controller.getWindowUpdate() Expect(offset).ToNot(BeZero()) - Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) + // check that the window size was not increased + Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) + // check that the new window size was used to increase the offset + Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + oldWindowSize))) + }) + + It("doesn't increase the window size to a value higher than the maxReceiveWindowSize", func() { + resetEpoch := func() { + // make sure the next call to maybeAdjustWindowSize will increase the window + controller.epochStartTime = time.Now().Add(-time.Millisecond) + controller.epochStartOffset = controller.bytesRead + controller.AddBytesRead(controller.receiveWindowSize/2 + 1) + } + setRtt(scaleDuration(20 * time.Millisecond)) + resetEpoch() + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) // 2000 + // because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowSize() multiple times and get an increase of the window size every time + resetEpoch() + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(2 * 2 * oldWindowSize)) // 4000 + resetEpoch() + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000 + controller.maybeAdjustWindowSize() + Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000 }) }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go index 934d646..ff9c7f2 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go @@ -2,7 +2,6 @@ package flowcontrol import ( "fmt" - "time" "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -11,6 +10,7 @@ import ( ) type connectionFlowController struct { + lastBlockedAt protocol.ByteCount baseFlowController } @@ -25,21 +25,29 @@ func NewConnectionFlowController( ) ConnectionFlowController { return &connectionFlowController{ baseFlowController: baseFlowController{ - rttStats: rttStats, - receiveWindow: receiveWindow, - receiveWindowIncrement: receiveWindow, - maxReceiveWindowIncrement: maxReceiveWindow, + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, }, } } func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { - c.mutex.RLock() - defer c.mutex.RUnlock() - return c.baseFlowController.sendWindowSize() } +// IsNewlyBlocked says if it is newly blocked by flow control. +// For every offset, it only returns true once. +// If it is blocked, the offset is returned. +func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { + return false, 0 + } + c.lastBlockedAt = c.sendWindow + return true, c.sendWindow +} + // IncrementHighestReceived adds an increment to the highestReceived value func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { c.mutex.Lock() @@ -54,24 +62,22 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { c.mutex.Lock() - defer c.mutex.Unlock() - - oldWindowIncrement := c.receiveWindowIncrement + oldWindowSize := c.receiveWindowSize offset := c.baseFlowController.getWindowUpdate() - if oldWindowIncrement < c.receiveWindowIncrement { - utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) + if oldWindowSize < c.receiveWindowSize { + utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) } + c.mutex.Unlock() return offset } -// EnsureMinimumWindowIncrement sets a minimum window increment +// EnsureMinimumWindowSize sets a minimum window size // it should make sure that the connection-level window is increased when a stream-level window grows -func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) { +func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) { c.mutex.Lock() - defer c.mutex.Unlock() - - if inc > c.receiveWindowIncrement { - c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement) - c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update + if inc > c.receiveWindowSize { + c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize) + c.startNewAutoTuningEpoch() } + c.mutex.Unlock() } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller_test.go index dc400e1..056daf3 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller_test.go @@ -32,12 +32,12 @@ var _ = Describe("Connection Flow controller", func() { fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, rttStats).(*connectionFlowController) Expect(fc.receiveWindow).To(Equal(receiveWindow)) - Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow)) + Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) }) }) Context("receive flow control", func() { - It("increases the highestReceived by a given increment", func() { + It("increases the highestReceived by a given window size", func() { controller.highestReceived = 1337 controller.IncrementHighestReceived(123) Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123))) @@ -46,64 +46,96 @@ var _ = Describe("Connection Flow controller", func() { Context("getting window updates", func() { BeforeEach(func() { controller.receiveWindow = 100 - controller.receiveWindowIncrement = 60 - controller.maxReceiveWindowIncrement = 1000 + controller.receiveWindowSize = 60 + controller.maxReceiveWindowSize = 1000 + controller.bytesRead = 100 - 60 }) It("gets a window update", func() { - controller.AddBytesRead(80) + windowSize := controller.receiveWindowSize + oldOffset := controller.bytesRead + dataRead := windowSize/2 - 1 // make sure not to trigger auto-tuning + controller.AddBytesRead(dataRead) offset := controller.GetWindowUpdate() - Expect(offset).To(Equal(protocol.ByteCount(80 + 60))) + Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + 60))) }) It("autotunes the window", func() { - controller.AddBytesRead(80) - setRtt(20 * time.Millisecond) - controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) + oldOffset := controller.bytesRead + oldWindowSize := controller.receiveWindowSize + rtt := scaleDuration(20 * time.Millisecond) + setRtt(rtt) + controller.epochStartTime = time.Now().Add(-time.Millisecond) + controller.epochStartOffset = oldOffset + dataRead := oldWindowSize/2 + 1 + controller.AddBytesRead(dataRead) offset := controller.GetWindowUpdate() - Expect(offset).To(Equal(protocol.ByteCount(80 + 2*60))) + newWindowSize := controller.receiveWindowSize + Expect(newWindowSize).To(Equal(2 * oldWindowSize)) + Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + newWindowSize))) }) }) }) - Context("setting the minimum increment", func() { + Context("send flow control", func() { + It("says when it's blocked", func() { + controller.UpdateSendWindow(100) + Expect(controller.IsNewlyBlocked()).To(BeFalse()) + controller.AddBytesSent(100) + blocked, offset := controller.IsNewlyBlocked() + Expect(blocked).To(BeTrue()) + Expect(offset).To(Equal(protocol.ByteCount(100))) + }) + + It("doesn't say that it's newly blocked multiple times for the same offset", func() { + controller.UpdateSendWindow(100) + controller.AddBytesSent(100) + newlyBlocked, offset := controller.IsNewlyBlocked() + Expect(newlyBlocked).To(BeTrue()) + Expect(offset).To(Equal(protocol.ByteCount(100))) + newlyBlocked, _ = controller.IsNewlyBlocked() + Expect(newlyBlocked).To(BeFalse()) + controller.UpdateSendWindow(150) + controller.AddBytesSent(150) + newlyBlocked, offset = controller.IsNewlyBlocked() + Expect(newlyBlocked).To(BeTrue()) + }) + }) + + Context("setting the minimum window size", func() { var ( - oldIncrement protocol.ByteCount - receiveWindow protocol.ByteCount = 10000 - receiveWindowIncrement protocol.ByteCount = 600 + oldWindowSize protocol.ByteCount + receiveWindow protocol.ByteCount = 10000 + receiveWindowSize protocol.ByteCount = 1000 ) BeforeEach(func() { + controller.bytesRead = receiveWindowSize - receiveWindowSize controller.receiveWindow = receiveWindow - controller.receiveWindowIncrement = receiveWindowIncrement - oldIncrement = controller.receiveWindowIncrement - controller.maxReceiveWindowIncrement = 3000 + controller.receiveWindowSize = receiveWindowSize + oldWindowSize = controller.receiveWindowSize + controller.maxReceiveWindowSize = 3000 }) - It("sets the minimum window increment", func() { - controller.EnsureMinimumWindowIncrement(1000) - Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(1000))) + It("sets the minimum window window size", func() { + controller.EnsureMinimumWindowSize(1800) + Expect(controller.receiveWindowSize).To(Equal(protocol.ByteCount(1800))) }) - It("doesn't reduce the window increment", func() { - controller.EnsureMinimumWindowIncrement(1) - Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) + It("doesn't reduce the window window size", func() { + controller.EnsureMinimumWindowSize(1) + Expect(controller.receiveWindowSize).To(Equal(oldWindowSize)) }) - It("doens't increase the increment beyond the maxReceiveWindowIncrement", func() { - max := controller.maxReceiveWindowIncrement - controller.EnsureMinimumWindowIncrement(2 * max) - Expect(controller.receiveWindowIncrement).To(Equal(max)) + It("doens't increase the window size beyond the maxReceiveWindowSize", func() { + max := controller.maxReceiveWindowSize + controller.EnsureMinimumWindowSize(2 * max) + Expect(controller.receiveWindowSize).To(Equal(max)) }) - It("doesn't auto-tune the window after the increment was increased", func() { - setRtt(20 * time.Millisecond) - controller.bytesRead = 9900 // receive window is 10000 - controller.lastWindowUpdateTime = time.Now().Add(-20 * time.Millisecond) - controller.EnsureMinimumWindowIncrement(912) - offset := controller.getWindowUpdate() - Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(912))) // no auto-tuning - Expect(offset).To(Equal(protocol.ByteCount(9900 + 912))) + It("starts a new epoch after the window size was increased", func() { + controller.EnsureMinimumWindowSize(1912) + Expect(controller.epochStartTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) }) }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go index 75ec6fa..61d57e3 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go @@ -5,7 +5,6 @@ import "github.com/lucas-clemente/quic-go/internal/protocol" type flowController interface { // for sending SendWindowSize() protocol.ByteCount - IsBlocked() bool UpdateSendWindow(protocol.ByteCount) AddBytesSent(protocol.ByteCount) // for receiving @@ -16,22 +15,28 @@ type flowController interface { // A StreamFlowController is a flow controller for a QUIC stream. type StreamFlowController interface { flowController + // for sending + IsBlocked() (bool, protocol.ByteCount) // for receiving // UpdateHighestReceived should be called when a new highest offset is received // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame UpdateHighestReceived(offset protocol.ByteCount, final bool) error + // HasWindowUpdate says if it is necessary to update the window + HasWindowUpdate() bool } // The ConnectionFlowController is the flow controller for the connection. type ConnectionFlowController interface { flowController + // for sending + IsNewlyBlocked() (bool, protocol.ByteCount) } type connectionFlowControllerI interface { ConnectionFlowController // The following two methods are not supposed to be called from outside this packet, but are needed internally // for sending - EnsureMinimumWindowIncrement(protocol.ByteCount) + EnsureMinimumWindowSize(protocol.ByteCount) // for receiving IncrementHighestReceived(protocol.ByteCount) error } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go index dadba72..824139f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go @@ -37,11 +37,11 @@ func NewStreamFlowController( contributesToConnection: contributesToConnection, connection: cfc.(connectionFlowControllerI), baseFlowController: baseFlowController{ - rttStats: rttStats, - receiveWindow: receiveWindow, - receiveWindowIncrement: receiveWindow, - maxReceiveWindowIncrement: maxReceiveWindow, - sendWindow: initialSendWindow, + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, + sendWindow: initialSendWindow, }, } } @@ -102,9 +102,6 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { } func (c *streamFlowController) SendWindowSize() protocol.ByteCount { - c.mutex.Lock() - defer c.mutex.Unlock() - window := c.baseFlowController.sendWindowSize() if c.contributesToConnection { window = utils.MinByteCount(window, c.connection.SendWindowSize()) @@ -112,22 +109,39 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount { return window } -func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { - c.mutex.Lock() - defer c.mutex.Unlock() +// IsBlocked says if it is blocked by stream-level flow control. +// If it is blocked, the offset is returned. +func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) { + if c.sendWindowSize() != 0 { + return false, 0 + } + return true, c.sendWindow +} +func (c *streamFlowController) HasWindowUpdate() bool { + c.mutex.Lock() + hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate() + c.mutex.Unlock() + return hasWindowUpdate +} + +func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { + // don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler + c.mutex.Lock() // if we already received the final offset for this stream, the peer won't need any additional flow control credit if c.receivedFinalOffset { + c.mutex.Unlock() return 0 } - oldWindowIncrement := c.receiveWindowIncrement + oldWindowSize := c.receiveWindowSize offset := c.baseFlowController.getWindowUpdate() - if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment - utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) + if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size + utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) if c.contributesToConnection { - c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier)) + c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) } } + c.mutex.Unlock() return offset } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller_test.go index 76c1e9d..a3ef9dc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller_test.go @@ -19,7 +19,7 @@ var _ = Describe("Stream Flow controller", func() { streamID: 10, connection: NewConnectionFlowController(1000, 1000, rttStats).(*connectionFlowController), } - controller.maxReceiveWindowIncrement = 10000 + controller.maxReceiveWindowSize = 10000 controller.rttStats = rttStats }) @@ -35,7 +35,7 @@ var _ = Describe("Stream Flow controller", func() { fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats).(*streamFlowController) Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.receiveWindow).To(Equal(receiveWindow)) - Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow)) + Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) Expect(fc.sendWindow).To(Equal(sendWindow)) Expect(fc.contributesToConnection).To(BeTrue()) }) @@ -44,11 +44,11 @@ var _ = Describe("Stream Flow controller", func() { Context("receiving data", func() { Context("registering received offsets", func() { var receiveWindow protocol.ByteCount = 10000 - var receiveWindowIncrement protocol.ByteCount = 600 + var receiveWindowSize protocol.ByteCount = 600 BeforeEach(func() { controller.receiveWindow = receiveWindow - controller.receiveWindowIncrement = receiveWindowIncrement + controller.receiveWindowSize = receiveWindowSize }) It("updates the highestReceived", func() { @@ -157,7 +157,7 @@ var _ = Describe("Stream Flow controller", func() { }) Context("generating window updates", func() { - var oldIncrement protocol.ByteCount + var oldWindowSize protocol.ByteCount // update the congestion such that it returns a given value for the smoothed RTT setRtt := func(t time.Duration) { @@ -167,37 +167,51 @@ var _ = Describe("Stream Flow controller", func() { BeforeEach(func() { controller.receiveWindow = 100 - controller.receiveWindowIncrement = 60 - controller.connection.(*connectionFlowController).receiveWindowIncrement = 120 - oldIncrement = controller.receiveWindowIncrement + controller.receiveWindowSize = 60 + controller.bytesRead = 100 - 60 + controller.connection.(*connectionFlowController).receiveWindowSize = 120 + oldWindowSize = controller.receiveWindowSize + }) + + It("tells if it has window updates", func() { + Expect(controller.HasWindowUpdate()).To(BeFalse()) + controller.AddBytesRead(30) + Expect(controller.HasWindowUpdate()).To(BeTrue()) + Expect(controller.GetWindowUpdate()).ToNot(BeZero()) + Expect(controller.HasWindowUpdate()).To(BeFalse()) }) It("tells the connection flow controller when the window was autotuned", func() { + oldOffset := controller.bytesRead controller.contributesToConnection = true - controller.AddBytesRead(75) - setRtt(20 * time.Millisecond) - controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) + setRtt(scaleDuration(20 * time.Millisecond)) + controller.epochStartOffset = oldOffset + controller.epochStartTime = time.Now().Add(-time.Millisecond) + controller.AddBytesRead(55) offset := controller.GetWindowUpdate() - Expect(offset).To(Equal(protocol.ByteCount(75 + 2*60))) - Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) - Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(float64(controller.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier))) + Expect(offset).To(Equal(protocol.ByteCount(oldOffset + 55 + 2*oldWindowSize))) + Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) + Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))) }) It("doesn't tell the connection flow controller if it doesn't contribute", func() { + oldOffset := controller.bytesRead controller.contributesToConnection = false - controller.AddBytesRead(75) - setRtt(20 * time.Millisecond) - controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) + setRtt(scaleDuration(20 * time.Millisecond)) + controller.epochStartOffset = oldOffset + controller.epochStartTime = time.Now().Add(-time.Millisecond) + controller.AddBytesRead(55) offset := controller.GetWindowUpdate() Expect(offset).ToNot(BeZero()) - Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) - Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(120))) // unchanged + Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) + Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(2 * oldWindowSize))) // unchanged }) It("doesn't increase the window after a final offset was already received", func() { - controller.AddBytesRead(80) + controller.AddBytesRead(30) err := controller.UpdateHighestReceived(90, true) Expect(err).ToNot(HaveOccurred()) + Expect(controller.HasWindowUpdate()).To(BeFalse()) offset := controller.GetWindowUpdate() Expect(offset).To(BeZero()) }) @@ -231,7 +245,8 @@ var _ = Describe("Stream Flow controller", func() { controller.connection.UpdateSendWindow(50) controller.UpdateSendWindow(100) controller.AddBytesSent(50) - Expect(controller.connection.IsBlocked()).To(BeTrue()) + blocked, _ := controller.connection.IsNewlyBlocked() + Expect(blocked).To(BeTrue()) Expect(controller.IsBlocked()).To(BeFalse()) }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go index 2df6d6b..cb500b5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go @@ -51,8 +51,8 @@ type cryptoSetupClient struct { secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD - paramsChan chan<- TransportParameters - aeadChanged chan<- protocol.EncryptionLevel + paramsChan chan<- TransportParameters + handshakeEvent chan<- struct{} params *TransportParameters } @@ -74,7 +74,7 @@ func NewCryptoSetupClient( tlsConfig *tls.Config, params *TransportParameters, paramsChan chan<- TransportParameters, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, ) (CryptoSetup, error) { @@ -93,7 +93,7 @@ func NewCryptoSetupClient( keyExchange: getEphermalKEX, nullAEAD: nullAEAD, paramsChan: paramsChan, - aeadChanged: aeadChanged, + handshakeEvent: handshakeEvent, initialVersion: initialVersion, negotiatedVersions: negotiatedVersions, divNonceChan: make(chan []byte), @@ -159,8 +159,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { } // blocks until the session has received the parameters h.paramsChan <- *params - h.aeadChanged <- protocol.EncryptionForwardSecure - close(h.aeadChanged) + h.handshakeEvent <- struct{}{} + close(h.handshakeEvent) default: return qerr.InvalidCryptoMessageType } @@ -381,6 +381,15 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) { h.divNonceChan <- data } +func (h *cryptoSetupClient) ConnectionState() ConnectionState { + h.mutex.Lock() + defer h.mutex.Unlock() + return ConnectionState{ + HandshakeComplete: h.forwardSecureAEAD != nil, + PeerCertificates: h.certManager.GetChain(), + } +} + func (h *cryptoSetupClient) sendCHLO() error { h.clientHelloCounter++ if h.clientHelloCounter > protocol.MaxClientHellos { @@ -496,10 +505,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error { if err != nil { return err } - - h.aeadChanged <- protocol.EncryptionSecure + h.handshakeEvent <- struct{}{} } - return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client_test.go index 709acb6..695fef6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client_test.go @@ -2,6 +2,7 @@ package handshake import ( "bytes" + "crypto/x509" "encoding/binary" "errors" "fmt" @@ -10,6 +11,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/mocks/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -34,6 +36,8 @@ type mockCertManager struct { commonCertificateHashes []byte + chain []*x509.Certificate + leafCert []byte leafCertHash uint64 leafCertHashError error @@ -45,6 +49,8 @@ type mockCertManager struct { verifyCalled bool } +var _ crypto.CertManager = &mockCertManager{} + func (m *mockCertManager) SetData(data []byte) error { m.setDataCalledWith = data return m.setDataError @@ -72,6 +78,10 @@ func (m *mockCertManager) Verify(hostname string) error { return m.verifyError } +func (m *mockCertManager) GetChain() []*x509.Certificate { + return m.chain +} + var _ = Describe("Client Crypto Setup", func() { var ( cs *cryptoSetupClient @@ -79,7 +89,7 @@ var _ = Describe("Client Crypto Setup", func() { stream *mockStream keyDerivationCalledWith *keyDerivationValues shloMap map[Tag][]byte - aeadChanged chan protocol.EncryptionLevel + handshakeEvent chan struct{} paramsChan chan TransportParameters ) @@ -108,7 +118,7 @@ var _ = Describe("Client Crypto Setup", func() { version := protocol.Version39 // use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking paramsChan = make(chan TransportParameters, 1) - aeadChanged = make(chan protocol.EncryptionLevel, 2) + handshakeEvent = make(chan struct{}, 2) csInt, err := NewCryptoSetupClient( stream, "hostname", @@ -117,7 +127,7 @@ var _ = Describe("Client Crypto Setup", func() { nil, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, paramsChan, - aeadChanged, + handshakeEvent, protocol.Version39, nil, ) @@ -130,10 +140,6 @@ var _ = Describe("Client Crypto Setup", func() { cs.cryptoStream = stream }) - AfterEach(func() { - close(stream.unblockRead) - }) - Context("Reading REJ", func() { var tagMap map[Tag][]byte @@ -158,8 +164,17 @@ var _ = Describe("Client Crypto Setup", func() { stk := []byte("foobar") tagMap[TagSTK] = stk HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead) - go cs.HandleCryptoStream() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) + }() Eventually(func() []byte { return cs.stk }).Should(Equal(stk)) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) It("saves the proof", func() { @@ -380,22 +395,22 @@ var _ = Describe("Client Crypto Setup", func() { cs.receivedSecurePacket = false _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message"))) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("rejects SHLOs without a PUBS", func() { delete(shloMap, TagPUBS) _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS"))) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("rejects SHLOs without a version list", func() { delete(shloMap, TagVER) _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list"))) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("accepts a SHLO after a version negotiation", func() { @@ -430,28 +445,38 @@ var _ = Describe("Client Crypto Setup", func() { Expect(params.IdleTimeout).To(Equal(13 * time.Second)) }) - It("closes the aeadChanged when receiving an SHLO", func() { + It("closes the handshakeEvent chan when receiving an SHLO", func() { HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) + done := make(chan struct{}) go func() { defer GinkgoRecover() err := cs.HandleCryptoStream() - Expect(err).ToNot(HaveOccurred()) + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) }() - Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure))) - Eventually(aeadChanged).Should(BeClosed()) + Eventually(handshakeEvent).Should(Receive()) + Eventually(handshakeEvent).Should(BeClosed()) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) It("passes the transport parameters on the channel", func() { shloMap[TagSFCW] = []byte{0x0d, 0x00, 0xdf, 0xba} HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) + done := make(chan struct{}) go func() { defer GinkgoRecover() err := cs.HandleCryptoStream() - Expect(err).ToNot(HaveOccurred()) + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) }() var params TransportParameters Eventually(paramsChan).Should(Receive(¶ms)) Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xbadf000d))) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) It("errors if it can't read a connection parameter", func() { @@ -637,9 +662,9 @@ var _ = Describe("Client Crypto Setup", func() { Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert)) Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce)) Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient)) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("uses the server nonce, if the server sent one", func() { @@ -649,51 +674,64 @@ var _ = Describe("Client Crypto Setup", func() { Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).ToNot(BeNil()) Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...))) - Expect(aeadChanged).To(Receive()) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() { err := cs.maybeUpgradeCrypto() Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).To(BeNil()) - Expect(aeadChanged).ToNot(Receive()) + Expect(handshakeEvent).ToNot(Receive()) cs.serverVerified = true // make sure we really had all necessary values before, and only serverVerified was missing err = cs.maybeUpgradeCrypto() Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("tries to escalate before reading a handshake message", func() { Expect(cs.secureAEAD).To(BeNil()) cs.serverVerified = true - go cs.HandleCryptoStream() - Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) - Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) - }) - - It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) { + done := make(chan struct{}) go func() { defer GinkgoRecover() - cs.HandleCryptoStream() - Fail("HandleCryptoStream should not have returned") + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) + }() + Eventually(handshakeEvent).Should(Receive()) + Expect(cs.secureAEAD).ToNot(BeNil()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) + }) + + It("tries to escalate the crypto after receiving a diversification nonce", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) }() cs.diversificationNonce = nil cs.serverVerified = true Expect(cs.secureAEAD).To(BeNil()) cs.SetDiversificationNonce([]byte("div")) - Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) + Eventually(handshakeEvent).Should(Receive()) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) - close(done) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) Context("null encryption", func() { @@ -813,6 +851,22 @@ var _ = Describe("Client Crypto Setup", func() { }) }) + Context("reporting the connection state", func() { + It("reports the connection state before the handshake completes", func() { + chain := []*x509.Certificate{testdata.GetCertificate().Leaf} + certManager.chain = chain + state := cs.ConnectionState() + Expect(state.HandshakeComplete).To(BeFalse()) + Expect(state.PeerCertificates).To(Equal(chain)) + }) + + It("reports the connection state after the handshake completes", func() { + doSHLO() + state := cs.ConnectionState() + Expect(state.HandshakeComplete).To(BeTrue()) + }) + }) + Context("forcing encryption levels", func() { It("forces null encryption", func() { cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(4), []byte{}).Return([]byte("foobar unencrypted")) @@ -862,32 +916,51 @@ var _ = Describe("Client Crypto Setup", func() { Context("Diversification Nonces", func() { It("sets a diversification nonce", func() { - go cs.HandleCryptoStream() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) + }() nonce := []byte("foobar") cs.SetDiversificationNonce(nonce) Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce)) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) - It("doesn't do anything when called multiple times with the same nonce", func(done Done) { - go cs.HandleCryptoStream() + It("doesn't do anything when called multiple times with the same nonce", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) + }() nonce := []byte("foobar") cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce) Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce)) - close(done) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) It("rejects a different diversification nonce", func() { - var err error + done := make(chan struct{}) go func() { - err = cs.HandleCryptoStream() + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(errConflictingDiversificationNonces)) + close(done) }() - nonce1 := []byte("foobar") nonce2 := []byte("raboof") cs.SetDiversificationNonce(nonce1) cs.SetDiversificationNonce(nonce2) - Eventually(func() error { return err }).Should(MatchError(errConflictingDiversificationNonces)) + Eventually(done).Should(BeClosed()) }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go index 6ff11ab..7d5f32e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go @@ -23,6 +23,8 @@ type KeyExchangeFunction func() crypto.KeyExchange // The CryptoSetupServer handles all things crypto for the Session type cryptoSetupServer struct { + mutex sync.RWMutex + connID protocol.ConnectionID remoteAddr net.Addr scfg *ServerConfig @@ -42,7 +44,7 @@ type cryptoSetupServer struct { receivedParams bool paramsChan chan<- TransportParameters - aeadChanged chan<- protocol.EncryptionLevel + handshakeEvent chan<- struct{} keyDerivation QuicCryptoKeyDerivationFunction keyExchange KeyExchangeFunction @@ -51,7 +53,7 @@ type cryptoSetupServer struct { params *TransportParameters - mutex sync.RWMutex + sni string // need to fill out the ConnectionState } var _ CryptoSetup = &cryptoSetupServer{} @@ -76,7 +78,7 @@ func NewCryptoSetup( supportedVersions []protocol.VersionNumber, acceptSTK func(net.Addr, *Cookie) bool, paramsChan chan<- TransportParameters, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, ) (CryptoSetup, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) if err != nil { @@ -96,7 +98,7 @@ func NewCryptoSetup( acceptSTKCallback: acceptSTK, sentSHLO: make(chan struct{}), paramsChan: paramsChan, - aeadChanged: aeadChanged, + handshakeEvent: handshakeEvent, }, nil } @@ -139,6 +141,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if sni == "" { return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") } + h.sni = sni // prevent version downgrade attacks // see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples @@ -182,7 +185,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if _, err := h.cryptoStream.Write(reply); err != nil { return false, err } - h.aeadChanged <- protocol.EncryptionForwardSecure + h.handshakeEvent <- struct{}{} close(h.sentSHLO) return true, nil } @@ -206,9 +209,9 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu if err == nil { if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client h.receivedForwardSecurePacket = true - // wait until protocol.EncryptionForwardSecure was sent on the aeadChan + // wait for the send on the handshakeEvent chan <-h.sentSHLO - close(h.aeadChanged) + close(h.handshakeEvent) } return res, protocol.EncryptionForwardSecure, nil } @@ -396,8 +399,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T if err != nil { return nil, err } - - h.aeadChanged <- protocol.EncryptionSecure + h.handshakeEvent <- struct{}{} // Generate a new curve instance to derive the forward secure key var fsNonce bytes.Buffer @@ -454,6 +456,15 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) { panic("not needed for cryptoSetupServer") } +func (h *cryptoSetupServer) ConnectionState() ConnectionState { + h.mutex.Lock() + defer h.mutex.Unlock() + return ConnectionState{ + ServerName: h.sni, + HandshakeComplete: h.receivedForwardSecurePacket, + } +} + func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { if len(nonce) != 32 { return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server_test.go index 99caded..9c855c0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "io" "net" "time" @@ -63,35 +64,36 @@ func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte } type mockStream struct { - unblockRead chan struct{} // close this chan to unblock Read + unblockRead chan struct{} dataToRead bytes.Buffer dataWritten bytes.Buffer } +var _ io.ReadWriter = &mockStream{} + +var errMockStreamClosing = errors.New("mock stream closing") + func newMockStream() *mockStream { return &mockStream{unblockRead: make(chan struct{})} } +// call Close to make Read return func (s *mockStream) Read(p []byte) (int, error) { n, _ := s.dataToRead.Read(p) if n == 0 { // block if there's no data <-s.unblockRead + return 0, errMockStreamClosing } return n, nil // never return an EOF } -func (s *mockStream) ReadByte() (byte, error) { - return s.dataToRead.ReadByte() -} - func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) } -func (s *mockStream) Close() error { panic("not implemented") } -func (s *mockStream) Reset(error) { panic("not implemented") } -func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") } -func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") } +func (s *mockStream) close() { + close(s.unblockRead) +} type mockCookieProtector struct { data []byte @@ -122,7 +124,7 @@ var _ = Describe("Server Crypto Setup", func() { cs *cryptoSetupServer stream *mockStream paramsChan chan TransportParameters - aeadChanged chan protocol.EncryptionLevel + handshakeEvent chan struct{} nonce32 []byte versionTag []byte validSTK []byte @@ -144,7 +146,7 @@ var _ = Describe("Server Crypto Setup", func() { // use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking paramsChan = make(chan TransportParameters, 1) - aeadChanged = make(chan protocol.EncryptionLevel, 2) + handshakeEvent = make(chan struct{}, 2) stream = newMockStream() kex = &mockKEX{} signer = &mockSigner{} @@ -168,7 +170,7 @@ var _ = Describe("Server Crypto Setup", func() { supportedVersions, nil, paramsChan, - aeadChanged, + handshakeEvent, ) Expect(err).NotTo(HaveOccurred()) cs = csInt.(*cryptoSetupServer) @@ -183,10 +185,6 @@ var _ = Describe("Server Crypto Setup", func() { cs.cryptoStream = stream }) - AfterEach(func() { - close(stream.unblockRead) - }) - Context("diversification nonce", func() { BeforeEach(func() { cs.secureAEAD = mockcrypto.NewMockAEAD(mockCtrl) @@ -345,10 +343,10 @@ var _ = Describe("Server Crypto Setup", func() { err := cs.HandleCryptoStream() Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + Expect(handshakeEvent).To(Receive()) // for the switch to secure Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO")) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) // for the switch to forward secure + Expect(handshakeEvent).ToNot(BeClosed()) }) It("rejects client nonces that have the wrong length", func() { @@ -379,9 +377,9 @@ var _ = Describe("Server Crypto Setup", func() { Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) // for the switch to secure + Expect(handshakeEvent).To(Receive()) // for the switch to forward secure + Expect(handshakeEvent).ToNot(BeClosed()) }) It("recognizes inchoate CHLOs missing SCID", func() { @@ -537,7 +535,7 @@ var _ = Describe("Server Crypto Setup", func() { TagKEXS: kexs, }) Expect(err).ToNot(HaveOccurred()) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + Expect(handshakeEvent).To(Receive()) // for the switch to secure close(cs.sentSHLO) } @@ -659,7 +657,26 @@ var _ = Describe("Server Crypto Setup", func() { cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{}) _, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{}) Expect(err).ToNot(HaveOccurred()) - Expect(aeadChanged).To(BeClosed()) + Expect(handshakeEvent).To(BeClosed()) + }) + }) + + Context("reporting the connection state", func() { + It("reports before the handshake completes", func() { + cs.sni = "server name" + state := cs.ConnectionState() + Expect(state.HandshakeComplete).To(BeFalse()) + Expect(state.ServerName).To(Equal("server name")) + }) + + It("reports after the handshake completes", func() { + doCHLO() + // receive a forward secure packet + cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{}) + _, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{}) + Expect(err).ToNot(HaveOccurred()) + state := cs.ConnectionState() + Expect(state.HandshakeComplete).To(BeTrue()) }) }) @@ -723,6 +740,7 @@ var _ = Describe("Server Crypto Setup", func() { Expect(err).ToNot(HaveOccurred()) Expect(done).To(BeFalse()) Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK))) + Expect(cs.sni).To(Equal("foo")) }) It("works with proper STK", func() { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go index 041c0b4..54dfe1c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go @@ -26,9 +26,9 @@ type cryptoSetupTLS struct { nullAEAD crypto.AEAD aead crypto.AEAD - tls MintTLS - cryptoStream *CryptoStreamConn - aeadChanged chan<- protocol.EncryptionLevel + tls MintTLS + cryptoStream *CryptoStreamConn + handshakeEvent chan<- struct{} } // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server @@ -36,16 +36,16 @@ func NewCryptoSetupTLSServer( tls MintTLS, cryptoStream *CryptoStreamConn, nullAEAD crypto.AEAD, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, version protocol.VersionNumber, ) CryptoSetup { return &cryptoSetupTLS{ - tls: tls, - cryptoStream: cryptoStream, - nullAEAD: nullAEAD, - perspective: protocol.PerspectiveServer, - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, + tls: tls, + cryptoStream: cryptoStream, + nullAEAD: nullAEAD, + perspective: protocol.PerspectiveServer, + keyDerivation: crypto.DeriveAESKeys, + handshakeEvent: handshakeEvent, } } @@ -54,7 +54,7 @@ func NewCryptoSetupTLSClient( cryptoStream io.ReadWriter, connID protocol.ConnectionID, hostname string, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, tls MintTLS, version protocol.VersionNumber, ) (CryptoSetup, error) { @@ -64,11 +64,11 @@ func NewCryptoSetupTLSClient( } return &cryptoSetupTLS{ - perspective: protocol.PerspectiveClient, - tls: tls, - nullAEAD: nullAEAD, - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, + perspective: protocol.PerspectiveClient, + tls: tls, + nullAEAD: nullAEAD, + keyDerivation: crypto.DeriveAESKeys, + handshakeEvent: handshakeEvent, }, nil } @@ -102,9 +102,8 @@ handshakeLoop: h.aead = aead h.mutex.Unlock() - // signal to the outside world that the handshake completed - h.aeadChanged <- protocol.EncryptionForwardSecure - close(h.aeadChanged) + h.handshakeEvent <- struct{}{} + close(h.handshakeEvent) return nil } @@ -165,3 +164,14 @@ func (h *cryptoSetupTLS) DiversificationNonce() []byte { func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) { panic("diversification nonce not needed for TLS") } + +func (h *cryptoSetupTLS) ConnectionState() ConnectionState { + h.mutex.Lock() + defer h.mutex.Unlock() + mintConnState := h.tls.ConnectionState() + return ConnectionState{ + // TODO: set the ServerName, once mint exports it + HandshakeComplete: h.aead != nil, + PeerCertificates: mintConnState.PeerCertificates, + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls_test.go index 03b486e..f0293ba 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls_test.go @@ -20,17 +20,17 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e var _ = Describe("TLS Crypto Setup", func() { var ( - cs *cryptoSetupTLS - aeadChanged chan protocol.EncryptionLevel + cs *cryptoSetupTLS + handshakeEvent chan struct{} ) BeforeEach(func() { - aeadChanged = make(chan protocol.EncryptionLevel, 2) + handshakeEvent = make(chan struct{}, 2) cs = NewCryptoSetupTLSServer( nil, NewCryptoStreamConn(nil), nil, // AEAD - aeadChanged, + handshakeEvent, protocol.VersionTLS, ).(*cryptoSetupTLS) cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl) @@ -51,8 +51,8 @@ var _ = Describe("TLS Crypto Setup", func() { cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) - Expect(aeadChanged).To(BeClosed()) + Expect(handshakeEvent).To(Receive()) + Expect(handshakeEvent).To(BeClosed()) }) It("handshakes until it is connected", func() { @@ -63,7 +63,30 @@ var _ = Describe("TLS Crypto Setup", func() { cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) - Expect(aeadChanged).To(Receive()) + Expect(handshakeEvent).To(Receive()) + }) + + Context("reporting the handshake state", func() { + It("reports before the handshake compeletes", func() { + cs.tls = mockhandshake.NewMockMintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{}) + state := cs.ConnectionState() + Expect(state.HandshakeComplete).To(BeFalse()) + Expect(state.PeerCertificates).To(BeNil()) + }) + + It("reports after the handshake completes", func() { + cs.tls = mockhandshake.NewMockMintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{}) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) + cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected) + cs.keyDerivation = mockKeyDerivation + err := cs.HandleCryptoStream() + Expect(err).ToNot(HaveOccurred()) + state := cs.ConnectionState() + Expect(state.HandshakeComplete).To(BeTrue()) + Expect(state.PeerCertificates).To(BeNil()) + }) }) Context("escalating crypto", func() { @@ -180,17 +203,17 @@ var _ = Describe("TLS Crypto Setup", func() { var _ = Describe("TLS Crypto Setup, for the client", func() { var ( - cs *cryptoSetupTLS - aeadChanged chan protocol.EncryptionLevel + cs *cryptoSetupTLS + handshakeEvent chan struct{} ) BeforeEach(func() { - aeadChanged = make(chan protocol.EncryptionLevel, 2) + handshakeEvent = make(chan struct{}) csInt, err := NewCryptoSetupTLSClient( nil, 0, "quic.clemente.io", - aeadChanged, + handshakeEvent, nil, // mintTLS protocol.VersionTLS, ) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go index fbb7006..34b9553 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go @@ -1,6 +1,7 @@ package handshake import ( + "crypto/x509" "io" "github.com/bifurcation/mint" @@ -29,6 +30,7 @@ type MintTLS interface { // additional methods Handshake() mint.Alert State() mint.State + ConnectionState() mint.ConnectionState SetCryptoStream(io.ReadWriter) SetExtensionHandler(mint.AppExtensionHandler) error @@ -41,8 +43,17 @@ type CryptoSetup interface { // TODO: clean up this interface DiversificationNonce() []byte // only needed for cryptoSetupServer SetDiversificationNonce([]byte) // only needed for cryptoSetupClient + ConnectionState() ConnectionState GetSealer() (protocol.EncryptionLevel, Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) } + +// ConnectionState records basic details about the QUIC connection. +// Warning: This API should not be considered stable and might change soon. +type ConnectionState struct { + HandshakeComplete bool // handshake is complete + ServerName string // server name requested by client, if any (server side only) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go index 6d64d39..20d2d06 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go @@ -24,6 +24,7 @@ type extensionHandlerClient struct { var _ mint.AppExtensionHandler = &extensionHandlerClient{} var _ TLSExtensionHandler = &extensionHandlerClient{} +// NewExtensionHandlerClient creates a new extension handler for the client. func NewExtensionHandlerClient( params *TransportParameters, initialVersion protocol.VersionNumber, @@ -57,7 +58,10 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { ext := &tlsExtensionBody{} - found, _ := el.Find(ext) + found, err := el.Find(ext) + if err != nil { + return err + } if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket { if found { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client_test.go index 52822b8..05cfae5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client_test.go @@ -39,7 +39,8 @@ var _ = Describe("TLS Extension Handler, for the client", func() { Expect(err).ToNot(HaveOccurred()) Expect(el).To(HaveLen(1)) ext := &tlsExtensionBody{} - found := el.Find(ext) + found, err := el.Find(ext) + Expect(err).ToNot(HaveOccurred()) Expect(found).To(BeTrue()) chtp := &clientHelloTransportParameters{} _, err = syntax.Unmarshal(ext.data, chtp) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go index b1e157a..313751c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go @@ -24,6 +24,7 @@ type extensionHandlerServer struct { var _ mint.AppExtensionHandler = &extensionHandlerServer{} var _ TLSExtensionHandler = &extensionHandlerServer{} +// NewExtensionHandlerServer creates a new extension handler for the server func NewExtensionHandlerServer( params *TransportParameters, supportedVersions []protocol.VersionNumber, @@ -66,7 +67,10 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { ext := &tlsExtensionBody{} - found, _ := el.Find(ext) + found, err := el.Find(ext) + if err != nil { + return err + } if hType != mint.HandshakeTypeClientHello { if found { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server_test.go index ceab29b..8bb8d10 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server_test.go @@ -48,7 +48,8 @@ var _ = Describe("TLS Extension Handler, for the server", func() { Expect(err).ToNot(HaveOccurred()) Expect(el).To(HaveLen(1)) ext := &tlsExtensionBody{} - found := el.Find(ext) + found, err := el.Find(ext) + Expect(err).ToNot(HaveOccurred()) Expect(found).To(BeTrue()) eetp := &encryptedExtensionsTransportParameters{} _, err = syntax.Unmarshal(ext.data, eetp) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go index 4701d7d..1622983 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go @@ -64,6 +64,9 @@ type ByteCount uint64 // MaxByteCount is the maximum value of a ByteCount const MaxByteCount = ByteCount(1<<62 - 1) +// An ApplicationErrorCode is an application-defined error code. +type ApplicationErrorCode uint16 + // MaxReceivePacketSize maximum packet size of any QUIC packet, based on // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, // UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go index 2846566..7886482 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go @@ -56,6 +56,9 @@ const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15 // This is the value that Chromium is using const ConnectionFlowControlMultiplier = 1.5 +// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client +const WindowUpdateThreshold = 0.25 + // MaxIncomingStreams is the maximum number of streams that a peer may open const MaxIncomingStreams = 100 @@ -122,3 +125,9 @@ const ClosedSessionDeleteTimeout = time.Minute // NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space const NumCachedCertificates = 128 + +// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame. +// This avoids splitting up STREAM frames into small pieces, which has 2 advantages: +// 1. it reduces the framing overhead +// 2. it reduces the head-of-line blocking, when a packet is lost +const MinStreamFrameSize ByteCount = 128 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go index 5f0bc97..4f37b0a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go @@ -139,7 +139,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error } // MinLength of a written frame -func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *AckFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { if !version.UsesIETFFrameFormat() { return f.minLengthLegacy(version) } @@ -157,7 +157,7 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount length += utils.VarIntLen(uint64(f.LargestAcked - lowestInFirstRange)) if !f.HasMissingRanges() { - return length, nil + return length } var lowest protocol.PacketNumber for i, ackRange := range f.AckRanges { @@ -169,7 +169,7 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount length += utils.VarIntLen(uint64(ackRange.Last - ackRange.First)) lowest = ackRange.First } - return length, nil + return length } // HasMissingRanges returns if this frame reports any missing packets diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go index 3bef540..42eaf24 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go @@ -308,7 +308,7 @@ func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error return nil } -func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount { length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked)) @@ -320,7 +320,7 @@ func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount length += missingSequenceNumberDeltaLen } // we don't write - return length, nil + return length } // numWritableNackRanges calculates the number of ACK blocks that are about to be written diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go index cc6a016..04dd29d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go @@ -4,17 +4,26 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) // A BlockedFrame is a BLOCKED frame -type BlockedFrame struct{} +type BlockedFrame struct { + Offset protocol.ByteCount +} // ParseBlockedFrame parses a BLOCKED frame -func ParseBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*BlockedFrame, error) { +func ParseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) { if _, err := r.ReadByte(); err != nil { return nil, err } - return &BlockedFrame{}, nil + offset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &BlockedFrame{ + Offset: protocol.ByteCount(offset), + }, nil } func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { @@ -23,13 +32,14 @@ func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er } typeByte := uint8(0x08) b.WriteByte(typeByte) + utils.WriteVarInt(b, uint64(f.Offset)) return nil } // MinLength of a written frame -func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - if !version.UsesIETFFrameFormat() { // writing this frame would result in a legacy BLOCKED being written, which is longer - return 1 + 4, nil +func (f *BlockedFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return 1 + 4 } - return 1, nil + return 1 + utils.VarIntLen(uint64(f.Offset)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_test.go index 9a3e2dd..ce58820 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_test.go @@ -2,8 +2,10 @@ package wire import ( "bytes" + "io" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -12,30 +14,41 @@ import ( var _ = Describe("BLOCKED frame", func() { Context("when parsing", func() { It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x08}) - _, err := ParseBlockedFrame(b, protocol.VersionWhatever) + data := []byte{0x08} + data = append(data, encodeVarInt(0x12345678)...) + b := bytes.NewReader(data) + frame, err := ParseBlockedFrame(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0x12345678))) Expect(b.Len()).To(BeZero()) }) It("errors on EOFs", func() { - _, err := ParseBlockedFrame(bytes.NewReader(nil), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) + data := []byte{0x08} + data = append(data, encodeVarInt(0x12345678)...) + _, err := ParseBlockedFrame(bytes.NewReader(data), versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + for i := range data { + _, err := ParseBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames) + Expect(err).To(MatchError(io.EOF)) + } }) }) Context("when writing", func() { It("writes a sample frame", func() { b := &bytes.Buffer{} - frame := BlockedFrame{} + frame := BlockedFrame{Offset: 0xdeadbeef} err := frame.Write(b, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x08})) + expected := []byte{0x08} + expected = append(expected, encodeVarInt(0xdeadbeef)...) + Expect(b.Bytes()).To(Equal(expected)) }) It("has the correct min length", func() { - frame := BlockedFrame{} - Expect(frame.MinLength(versionIETFFrames)).To(Equal(protocol.ByteCount(1))) + frame := BlockedFrame{Offset: 0x12345} + Expect(frame.MinLength(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0x12345))) }) }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go index 2cad865..ccc3a71 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go @@ -68,11 +68,11 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) } // MinLength of a written frame -func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { if version.UsesIETFFrameFormat() { - return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)), nil + return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) } - return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil + return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)) } // Write writes an CONNECTION_CLOSE frame. diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go index f31f5bf..d9f0cea 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go @@ -9,5 +9,5 @@ import ( // A Frame in QUIC type Frame interface { Write(b *bytes.Buffer, version protocol.VersionNumber) error - MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) + MinLength(version protocol.VersionNumber) protocol.ByteCount } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go index 44a613c..fa5585a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go @@ -63,6 +63,6 @@ func (f *GoawayFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { } // MinLength of a written frame -func (f *GoawayFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)), nil +func (f *GoawayFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { + return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go index 19585bc..945d11a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go @@ -43,9 +43,9 @@ func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er } // MinLength of a written frame -func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { if !version.UsesIETFFrameFormat() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer - return 1 + 4 + 8, nil + return 1 + 4 + 8 } - return 1 + utils.VarIntLen(uint64(f.ByteOffset)), nil + return 1 + utils.VarIntLen(uint64(f.ByteOffset)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go index 810dc92..5488876 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go @@ -51,10 +51,10 @@ func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumb } // MinLength of a written frame -func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which has a different length if !version.UsesIETFFrameFormat() { - return 1 + 4 + 8, nil + return 1 + 4 + 8 } - return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset)), nil + return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go new file mode 100644 index 0000000..6d1aeae --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go @@ -0,0 +1,37 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A MaxStreamIDFrame is a MAX_STREAM_ID frame +type MaxStreamIDFrame struct { + StreamID protocol.StreamID +} + +// ParseMaxStreamIDFrame parses a MAX_STREAM_ID frame +func ParseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) { + // read the Type byte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil +} + +func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x6) + utils.WriteVarInt(b, uint64(f.StreamID)) + return nil +} + +// MinLength of a written frame +func (f *MaxStreamIDFrame) MinLength(protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(uint64(f.StreamID)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame_test.go new file mode 100644 index 0000000..33a70bd --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame_test.go @@ -0,0 +1,51 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MAX_STREAM_ID frame", func() { + Context("parsing", func() { + It("accepts sample frame", func() { + data := []byte{0x6} + data = append(data, encodeVarInt(0xdecafbad)...) + b := bytes.NewReader(data) + f, err := ParseMaxStreamIDFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x06} + data = append(data, encodeVarInt(0xdeadbeefcafe13)...) + _, err := ParseMaxStreamIDFrame(bytes.NewReader(data), protocol.VersionWhatever) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseMaxStreamIDFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := MaxStreamIDFrame{StreamID: 0x12345678} + frame.Write(b, protocol.VersionWhatever) + expected := []byte{0x6} + expected = append(expected, encodeVarInt(0x12345678)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := MaxStreamIDFrame{StreamID: 0x1337} + Expect(frame.MinLength(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337))) + }) + }) +}) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go index 2a09c33..c7fdc40 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go @@ -28,6 +28,6 @@ func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error } // MinLength of a written frame -func (f *PingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1, nil +func (f *PingFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { + return 1 } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go index 05a4cad..3f65a63 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go @@ -7,10 +7,12 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -// A RstStreamFrame in QUIC +// A RstStreamFrame is a RST_STREAM frame in QUIC type RstStreamFrame struct { - StreamID protocol.StreamID - ErrorCode uint32 + StreamID protocol.StreamID + // The error code is a uint32 in gQUIC, but a uint16 in IETF QUIC. + // protocol.ApplicaitonErrorCode is a uint16, so larger values in gQUIC frames will be truncated. + ErrorCode protocol.ApplicationErrorCode ByteOffset protocol.ByteCount } @@ -21,7 +23,7 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS } var streamID protocol.StreamID - var errorCode uint32 + var errorCode uint16 var byteOffset protocol.ByteCount if version.UsesIETFFrameFormat() { sid, err := utils.ReadVarInt(r) @@ -29,11 +31,10 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS return nil, err } streamID = protocol.StreamID(sid) - ec, err := utils.BigEndian.ReadUint16(r) + errorCode, err = utils.BigEndian.ReadUint16(r) if err != nil { return nil, err } - errorCode = uint32(ec) bo, err := utils.ReadVarInt(r) if err != nil { return nil, err @@ -54,12 +55,12 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS if err != nil { return nil, err } - errorCode = uint32(ec) + errorCode = uint16(ec) } return &RstStreamFrame{ StreamID: streamID, - ErrorCode: errorCode, + ErrorCode: protocol.ApplicationErrorCode(errorCode), ByteOffset: byteOffset, }, nil } @@ -74,15 +75,15 @@ func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) } else { utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) utils.BigEndian.WriteUint64(b, uint64(f.ByteOffset)) - utils.BigEndian.WriteUint32(b, f.ErrorCode) + utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode)) } return nil } // MinLength of a written frame -func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { if version.UsesIETFFrameFormat() { - return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset)), nil + return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset)) } - return 1 + 4 + 8 + 4, nil + return 1 + 4 + 8 + 4 } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame_test.go index 380c055..207e337 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame_test.go @@ -22,7 +22,7 @@ var _ = Describe("RST_STREAM frame", func() { Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x987654321))) - Expect(frame.ErrorCode).To(Equal(uint32(0x1337))) + Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0x1337))) }) It("errors on EOFs", func() { @@ -44,13 +44,13 @@ var _ = Describe("RST_STREAM frame", func() { b := bytes.NewReader([]byte{0x1, 0xde, 0xad, 0xbe, 0xef, // stream id 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, // byte offset - 0x34, 0x12, 0x37, 0x13, // error code + 0x0, 0x0, 0xca, 0xfe, // error code }) frame, err := ParseRstStreamFrame(b, versionBigEndian) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x8877665544332211))) - Expect(frame.ErrorCode).To(Equal(uint32(0x34123713))) + Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0xcafe))) }) It("errors on EOFs", func() { @@ -103,7 +103,7 @@ var _ = Describe("RST_STREAM frame", func() { frame := RstStreamFrame{ StreamID: 0x1337, ByteOffset: 0x11223344decafbad, - ErrorCode: 0xdeadbeef, + ErrorCode: 0xcafe, } b := &bytes.Buffer{} err := frame.Write(b, versionBigEndian) @@ -111,7 +111,7 @@ var _ = Describe("RST_STREAM frame", func() { Expect(b.Bytes()).To(Equal([]byte{0x01, 0x0, 0x0, 0x13, 0x37, // stream id 0x11, 0x22, 0x33, 0x44, 0xde, 0xca, 0xfb, 0xad, // byte offset - 0xde, 0xad, 0xbe, 0xef, // error code + 0x0, 0x0, 0xca, 0xfe, // error code })) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go new file mode 100644 index 0000000..4cbbce9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go @@ -0,0 +1,47 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A StopSendingFrame is a STOP_SENDING frame +type StopSendingFrame struct { + StreamID protocol.StreamID + ErrorCode protocol.ApplicationErrorCode +} + +// ParseStopSendingFrame parses a STOP_SENDING frame +func ParseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + errorCode, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + + return &StopSendingFrame{ + StreamID: protocol.StreamID(streamID), + ErrorCode: protocol.ApplicationErrorCode(errorCode), + }, nil +} + +// MinLength of a written frame +func (f *StopSendingFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 +} + +func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x0c) + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode)) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame_test.go new file mode 100644 index 0000000..ab942a0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame_test.go @@ -0,0 +1,63 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STOP_SENDING frame", func() { + Context("when parsing", func() { + It("parses a sample frame", func() { + data := []byte{0x0c} + data = append(data, encodeVarInt(0xdecafbad)...) // stream ID + data = append(data, []byte{0x13, 0x37}...) // error code + b := bytes.NewReader(data) + frame, err := ParseStopSendingFrame(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) + Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0x1337))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x0c} + data = append(data, encodeVarInt(0xdecafbad)...) // stream ID + data = append(data, []byte{0x13, 0x37}...) // error code + _, err := ParseStopSendingFrame(bytes.NewReader(data), versionIETFFrames) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseStopSendingFrame(bytes.NewReader(data[:i]), versionIETFFrames) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("when writing", func() { + It("writes", func() { + frame := &StopSendingFrame{ + StreamID: 0xdeadbeefcafe, + ErrorCode: 0x10, + } + buf := &bytes.Buffer{} + err := frame.Write(buf, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x0c} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + expected = append(expected, []byte{0x0, 0x10}...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := &StopSendingFrame{ + StreamID: 0xdeadbeef, + ErrorCode: 0x10, + } + Expect(frame.MinLength(versionIETFFrames)).To(Equal(1 + 2 + utils.VarIntLen(0xdeadbeef))) + }) + }) +}) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go index 1f46688..48fbd44 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go @@ -22,7 +22,10 @@ var ( errPacketNumberLenNotSet = errors.New("StopWaitingFrame: PacketNumberLen not set") ) -func (f *StopWaitingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { +func (f *StopWaitingFrame) Write(b *bytes.Buffer, v protocol.VersionNumber) error { + if v.UsesIETFFrameFormat() { + return errors.New("STOP_WAITING not defined in IETF QUIC") + } // make sure the PacketNumber was set if f.PacketNumber == protocol.PacketNumber(0) { return errPacketNumberNotSet @@ -49,14 +52,8 @@ func (f *StopWaitingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) erro } // MinLength of a written frame -func (f *StopWaitingFrame) MinLength(_ protocol.VersionNumber) (protocol.ByteCount, error) { - minLength := protocol.ByteCount(1) // typeByte - - if f.PacketNumberLen == protocol.PacketNumberLenInvalid { - return 0, errPacketNumberLenNotSet - } - minLength += protocol.ByteCount(f.PacketNumberLen) - return minLength, nil +func (f *StopWaitingFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(f.PacketNumberLen) } // ParseStopWaitingFrame parses a StopWaiting frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame_test.go index ec22c30..a46ddd9 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame_test.go @@ -84,7 +84,7 @@ var _ = Describe("StopWaitingFrame", func() { LeastUnacked: 10, PacketNumberLen: protocol.PacketNumberLen1, } - err := frame.Write(b, protocol.VersionWhatever) + err := frame.Write(b, versionBigEndian) Expect(err).To(MatchError(errPacketNumberNotSet)) }) @@ -94,7 +94,7 @@ var _ = Describe("StopWaitingFrame", func() { LeastUnacked: 10, PacketNumber: 13, } - err := frame.Write(b, protocol.VersionWhatever) + err := frame.Write(b, versionBigEndian) Expect(err).To(MatchError(errPacketNumberLenNotSet)) }) @@ -105,10 +105,21 @@ var _ = Describe("StopWaitingFrame", func() { PacketNumber: 5, PacketNumberLen: protocol.PacketNumberLen1, } - err := frame.Write(b, protocol.VersionWhatever) + err := frame.Write(b, versionBigEndian) Expect(err).To(MatchError(errLeastUnackedHigherThanPacketNumber)) }) + It("refuses to write for IETF QUIC", func() { + b := &bytes.Buffer{} + frame := &StopWaitingFrame{ + LeastUnacked: 10, + PacketNumber: 13, + PacketNumberLen: protocol.PacketNumberLen6, + } + err := frame.Write(b, versionIETFFrames) + Expect(err).To(MatchError("STOP_WAITING not defined in IETF QUIC")) + }) + Context("LeastUnackedDelta length", func() { Context("in big endian", func() { It("writes a 1-byte LeastUnackedDelta", func() { @@ -176,18 +187,10 @@ var _ = Describe("StopWaitingFrame", func() { Expect(frame.MinLength(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(length + 1))) } }) - - It("errors when packetNumberLen is not set", func() { - frame := &StopWaitingFrame{ - LeastUnacked: 10, - } - _, err := frame.MinLength(0) - Expect(err).To(MatchError(errPacketNumberLenNotSet)) - }) }) Context("self consistency", func() { - It("reads a stop waiting frame that it wrote", func() { + It("reads a STOP_WAITING frame that it wrote", func() { packetNumber := protocol.PacketNumber(13) frame := &StopWaitingFrame{ LeastUnacked: 10, @@ -195,9 +198,9 @@ var _ = Describe("StopWaitingFrame", func() { PacketNumberLen: protocol.PacketNumberLen4, } b := &bytes.Buffer{} - err := frame.Write(b, protocol.VersionWhatever) + err := frame.Write(b, versionBigEndian) Expect(err).ToNot(HaveOccurred()) - readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, protocol.VersionWhatever) + readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, versionBigEndian) Expect(err).ToNot(HaveOccurred()) Expect(readframe.LeastUnacked).To(Equal(frame.LeastUnacked)) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go index 510de50..b67bd24 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go @@ -10,10 +10,11 @@ import ( // A StreamBlockedFrame in QUIC type StreamBlockedFrame struct { StreamID protocol.StreamID + Offset protocol.ByteCount } // ParseStreamBlockedFrame parses a STREAM_BLOCKED frame -func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamBlockedFrame, error) { +func ParseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) { if _, err := r.ReadByte(); err != nil { // read the TypeByte return nil, err } @@ -21,7 +22,14 @@ func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (* if err != nil { return nil, err } - return &StreamBlockedFrame{StreamID: protocol.StreamID(sid)}, nil + offset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &StreamBlockedFrame{ + StreamID: protocol.StreamID(sid), + Offset: protocol.ByteCount(offset), + }, nil } // Write writes a STREAM_BLOCKED frame @@ -31,13 +39,14 @@ func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumb } b.WriteByte(0x09) utils.WriteVarInt(b, uint64(f.StreamID)) + utils.WriteVarInt(b, uint64(f.Offset)) return nil } // MinLength of a written frame -func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { if !version.UsesIETFFrameFormat() { - return 1 + 4, nil + return 1 + 4 } - return 1 + utils.VarIntLen(uint64(f.StreamID)), nil + return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.Offset)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame_test.go index d31ce78..42b2046 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame_test.go @@ -14,17 +14,20 @@ var _ = Describe("STREAM_BLOCKED frame", func() { Context("parsing", func() { It("accepts sample frame", func() { data := []byte{0x9} - data = append(data, encodeVarInt(0xdeadbeef)...) + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset b := bytes.NewReader(data) frame, err := ParseStreamBlockedFrame(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) Expect(b.Len()).To(BeZero()) }) It("errors on EOFs", func() { data := []byte{0x9} data = append(data, encodeVarInt(0xdeadbeef)...) + data = append(data, encodeVarInt(0xc0010ff)...) _, err := ParseStreamBlockedFrame(bytes.NewReader(data), versionIETFFrames) Expect(err).NotTo(HaveOccurred()) for i := range data { @@ -38,19 +41,22 @@ var _ = Describe("STREAM_BLOCKED frame", func() { It("has proper min length", func() { f := &StreamBlockedFrame{ StreamID: 0x1337, + Offset: 0xdeadbeef, } - Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337))) + Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337) + utils.VarIntLen(0xdeadbeef))) }) It("writes a sample frame", func() { b := &bytes.Buffer{} f := &StreamBlockedFrame{ StreamID: 0xdecafbad, + Offset: 0x1337, } err := f.Write(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x9} expected = append(expected, encodeVarInt(uint64(f.StreamID))...) + expected = append(expected, encodeVarInt(uint64(f.Offset))...) Expect(b.Bytes()).To(Equal(expected)) }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go index fc38acd..6be0065 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go @@ -117,7 +117,7 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err // MinLength returns the length of the header of a StreamFrame // the total length of the frame is frame.MinLength() + frame.DataLen() -func (f *StreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *StreamFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { if !version.UsesIETFFrameFormat() { return f.minLengthLegacy(version) } @@ -128,5 +128,5 @@ func (f *StreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCo if f.DataLenPresent { length += utils.VarIntLen(uint64(f.DataLen())) } - return length, nil + return length } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go index e3687cb..c44c255 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go @@ -183,12 +183,12 @@ func (f *StreamFrame) getOffsetLength() protocol.ByteCount { return 8 } -func (f *StreamFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount, error) { +func (f *StreamFrame) minLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount { length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength() if f.DataLenPresent { length += 2 } - return length, nil + return length } // DataLen gives the length of data in bytes diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy_test.go index b7b8d25..b179a0d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy_test.go @@ -210,7 +210,7 @@ var _ = Describe("STREAM frame (for gQUIC)", func() { } err := f.Write(b, versionBigEndian) Expect(err).ToNot(HaveOccurred()) - minLength, _ := f.MinLength(0) + minLength := f.MinLength(0) Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0x20))) Expect(b.Bytes()[minLength-2 : minLength]).To(Equal([]byte{0x13, 0x37})) }) @@ -229,9 +229,9 @@ var _ = Describe("STREAM frame (for gQUIC)", func() { Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0))) Expect(b.Bytes()[1 : b.Len()-dataLen]).ToNot(ContainSubstring(string([]byte{0x37, 0x13}))) - minLength, _ := f.MinLength(versionBigEndian) + minLength := f.MinLength(versionBigEndian) f.DataLenPresent = true - minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian) + minLengthWithoutDataLen := f.MinLength(versionBigEndian) Expect(minLength).To(Equal(minLengthWithoutDataLen - 2)) }) @@ -242,7 +242,7 @@ var _ = Describe("STREAM frame (for gQUIC)", func() { DataLenPresent: false, Offset: 0xdeadbeef, } - minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian) + minLengthWithoutDataLen := f.MinLength(versionBigEndian) f.DataLenPresent = true Expect(f.MinLength(versionBigEndian)).To(Equal(minLengthWithoutDataLen + 2)) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go new file mode 100644 index 0000000..06e6743 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go @@ -0,0 +1,37 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A StreamIDBlockedFrame is a STREAM_ID_BLOCKED frame +type StreamIDBlockedFrame struct { + StreamID protocol.StreamID +} + +// ParseStreamIDBlockedFrame parses a STREAM_ID_BLOCKED frame +func ParseStreamIDBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamIDBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &StreamIDBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil +} + +func (f *StreamIDBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x0a) + b.WriteByte(typeByte) + utils.WriteVarInt(b, uint64(f.StreamID)) + return nil +} + +// MinLength of a written frame +func (f *StreamIDBlockedFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(uint64(f.StreamID)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame_test.go new file mode 100644 index 0000000..9057d05 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame_test.go @@ -0,0 +1,53 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM_ID_BLOCKED frame", func() { + Context("parsing", func() { + It("accepts sample frame", func() { + expected := []byte{0xa} + expected = append(expected, encodeVarInt(0xdecafbad)...) + b := bytes.NewReader(expected) + frame, err := ParseStreamIDBlockedFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0xa} + data = append(data, encodeVarInt(0x12345678)...) + _, err := ParseStreamIDBlockedFrame(bytes.NewReader(data), versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + for i := range data { + _, err := ParseStreamIDBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := StreamIDBlockedFrame{StreamID: 0xdeadbeefcafe} + err := frame.Write(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0xa} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := StreamIDBlockedFrame{StreamID: 0x123456} + Expect(frame.MinLength(0)).To(Equal(protocol.ByteCount(1) + utils.VarIntLen(0x123456))) + }) + }) +}) diff --git a/vendor/github.com/lucas-clemente/quic-go/mint_utils.go b/vendor/github.com/lucas-clemente/quic-go/mint_utils.go index 02bb3ae..1ddba7f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/mint_utils.go +++ b/vendor/github.com/lucas-clemente/quic-go/mint_utils.go @@ -56,6 +56,10 @@ func (mc *mintController) State() mint.State { return mc.conn.State().HandshakeState } +func (mc *mintController) ConnectionState() mint.ConnectionState { + return mc.conn.State() +} + func (mc *mintController) SetCryptoStream(stream io.ReadWriter) { mc.csc.SetStream(stream) } @@ -73,6 +77,7 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf }, } if tlsConf != nil { + mconf.ServerName = tlsConf.ServerName mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) for i, certChain := range tlsConf.Certificates { mconf.Certificates[i] = &mint.Certificate{ @@ -87,6 +92,13 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf mconf.Certificates[i].Chain[j] = c } } + switch tlsConf.ClientAuth { + case tls.NoClientCert: + case tls.RequireAnyClientCert: + mconf.RequireClientAuth = true + default: + return nil, errors.New("mint currently only support ClientAuthType RequireAnyClientCert") + } } if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil { return nil, err @@ -128,14 +140,14 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio // packUnencryptedPacket provides a low-overhead way to pack a packet. // It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. -func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) { +func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) { raw := getPacketBuffer() buffer := bytes.NewBuffer(raw) if err := hdr.Write(buffer, pers, hdr.Version); err != nil { return nil, err } payloadStartIndex := buffer.Len() - if err := sf.Write(buffer, hdr.Version); err != nil { + if err := f.Write(buffer, hdr.Version); err != nil { return nil, err } raw = raw[0:buffer.Len()] @@ -144,7 +156,7 @@ func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFr if utils.Debug() { utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) hdr.Log() - wire.LogFrame(sf, true) + wire.LogFrame(f, true) } return raw, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/mint_utils_test.go b/vendor/github.com/lucas-clemente/quic-go/mint_utils_test.go index e538cad..3398299 100644 --- a/vendor/github.com/lucas-clemente/quic-go/mint_utils_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/mint_utils_test.go @@ -2,9 +2,11 @@ package quic import ( "bytes" + "crypto/tls" "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -33,6 +35,45 @@ var _ = Describe("Packing and unpacking Initial packets", func() { hdr.Raw = buf.Bytes() }) + Context("generating a mint.Config", func() { + It("sets non-blocking mode", func() { + mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(mintConf.NonBlocking).To(BeTrue()) + }) + + It("sets the server name", func() { + conf := &tls.Config{ServerName: "www.example.com"} + mintConf, err := tlsToMintConfig(conf, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(mintConf.ServerName).To(Equal("www.example.com")) + }) + + It("sets the certificate chain", func() { + tlsConf := testdata.GetTLSConfig() + mintConf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(mintConf.Certificates).ToNot(BeEmpty()) + Expect(mintConf.Certificates).To(HaveLen(len(tlsConf.Certificates))) + }) + + It("requires client authentication", func() { + mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(mintConf.RequireClientAuth).To(BeFalse()) + conf := &tls.Config{ClientAuth: tls.RequireAnyClientCert} + mintConf, err = tlsToMintConfig(conf, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(mintConf.RequireClientAuth).To(BeTrue()) + }) + + It("rejects unsupported client auth types", func() { + conf := &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert} + _, err := tlsToMintConfig(conf, protocol.PerspectiveClient) + Expect(err).To(MatchError("mint currently only support ClientAuthType RequireAnyClientCert")) + }) + }) + Context("unpacking", func() { packPacket := func(frames []wire.Frame) []byte { buf := &bytes.Buffer{} diff --git a/vendor/github.com/lucas-clemente/quic-go/mock_crypto_stream_test.go b/vendor/github.com/lucas-clemente/quic-go/mock_crypto_stream_test.go new file mode 100644 index 0000000..68e47c1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mock_crypto_stream_test.go @@ -0,0 +1,141 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: CryptoStream) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockCryptoStream is a mock of CryptoStream interface +type MockCryptoStream struct { + ctrl *gomock.Controller + recorder *MockCryptoStreamMockRecorder +} + +// MockCryptoStreamMockRecorder is the mock recorder for MockCryptoStream +type MockCryptoStreamMockRecorder struct { + mock *MockCryptoStream +} + +// NewMockCryptoStream creates a new mock instance +func NewMockCryptoStream(ctrl *gomock.Controller) *MockCryptoStream { + mock := &MockCryptoStream{ctrl: ctrl} + mock.recorder = &MockCryptoStreamMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder { + return m.recorder +} + +// Read mocks base method +func (m *MockCryptoStream) Read(arg0 []byte) (int, error) { + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockCryptoStreamMockRecorder) Read(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockCryptoStream)(nil).Read), arg0) +} + +// StreamID mocks base method +func (m *MockCryptoStream) StreamID() protocol.StreamID { + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(protocol.StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID +func (mr *MockCryptoStreamMockRecorder) StreamID() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockCryptoStream)(nil).StreamID)) +} + +// Write mocks base method +func (m *MockCryptoStream) Write(arg0 []byte) (int, error) { + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write +func (mr *MockCryptoStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0) +} + +// closeForShutdown mocks base method +func (m *MockCryptoStream) closeForShutdown(arg0 error) { + m.ctrl.Call(m, "closeForShutdown", arg0) +} + +// closeForShutdown indicates an expected call of closeForShutdown +func (mr *MockCryptoStreamMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockCryptoStream)(nil).closeForShutdown), arg0) +} + +// getWindowUpdate mocks base method +func (m *MockCryptoStream) getWindowUpdate() protocol.ByteCount { + ret := m.ctrl.Call(m, "getWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// getWindowUpdate indicates an expected call of getWindowUpdate +func (mr *MockCryptoStreamMockRecorder) getWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockCryptoStream)(nil).getWindowUpdate)) +} + +// handleMaxStreamDataFrame mocks base method +func (m *MockCryptoStream) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) { + m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0) +} + +// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame +func (mr *MockCryptoStreamMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleMaxStreamDataFrame), arg0) +} + +// handleStreamFrame mocks base method +func (m *MockCryptoStream) handleStreamFrame(arg0 *wire.StreamFrame) error { + ret := m.ctrl.Call(m, "handleStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleStreamFrame indicates an expected call of handleStreamFrame +func (mr *MockCryptoStreamMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0) +} + +// popStreamFrame mocks base method +func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { + ret := m.ctrl.Call(m, "popStreamFrame", arg0) + ret0, _ := ret[0].(*wire.StreamFrame) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// popStreamFrame indicates an expected call of popStreamFrame +func (mr *MockCryptoStreamMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).popStreamFrame), arg0) +} + +// setReadOffset mocks base method +func (m *MockCryptoStream) setReadOffset(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "setReadOffset", arg0) +} + +// setReadOffset indicates an expected call of setReadOffset +func (mr *MockCryptoStreamMockRecorder) setReadOffset(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setReadOffset", reflect.TypeOf((*MockCryptoStream)(nil).setReadOffset), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mock_receive_stream_internal_test.go b/vendor/github.com/lucas-clemente/quic-go/mock_receive_stream_internal_test.go new file mode 100644 index 0000000..c41bfa7 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mock_receive_stream_internal_test.go @@ -0,0 +1,132 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: ReceiveStreamI) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockReceiveStreamI is a mock of ReceiveStreamI interface +type MockReceiveStreamI struct { + ctrl *gomock.Controller + recorder *MockReceiveStreamIMockRecorder +} + +// MockReceiveStreamIMockRecorder is the mock recorder for MockReceiveStreamI +type MockReceiveStreamIMockRecorder struct { + mock *MockReceiveStreamI +} + +// NewMockReceiveStreamI creates a new mock instance +func NewMockReceiveStreamI(ctrl *gomock.Controller) *MockReceiveStreamI { + mock := &MockReceiveStreamI{ctrl: ctrl} + mock.recorder = &MockReceiveStreamIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockReceiveStreamI) EXPECT() *MockReceiveStreamIMockRecorder { + return m.recorder +} + +// CancelRead mocks base method +func (m *MockReceiveStreamI) CancelRead(arg0 protocol.ApplicationErrorCode) error { + ret := m.ctrl.Call(m, "CancelRead", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CancelRead indicates an expected call of CancelRead +func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0) +} + +// Read mocks base method +func (m *MockReceiveStreamI) Read(arg0 []byte) (int, error) { + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockReceiveStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), arg0) +} + +// SetReadDeadline mocks base method +func (m *MockReceiveStreamI) SetReadDeadline(arg0 time.Time) error { + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline +func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), arg0) +} + +// StreamID mocks base method +func (m *MockReceiveStreamI) StreamID() protocol.StreamID { + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(protocol.StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID +func (mr *MockReceiveStreamIMockRecorder) StreamID() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockReceiveStreamI)(nil).StreamID)) +} + +// closeForShutdown mocks base method +func (m *MockReceiveStreamI) closeForShutdown(arg0 error) { + m.ctrl.Call(m, "closeForShutdown", arg0) +} + +// closeForShutdown indicates an expected call of closeForShutdown +func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0) +} + +// getWindowUpdate mocks base method +func (m *MockReceiveStreamI) getWindowUpdate() protocol.ByteCount { + ret := m.ctrl.Call(m, "getWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// getWindowUpdate indicates an expected call of getWindowUpdate +func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate)) +} + +// handleRstStreamFrame mocks base method +func (m *MockReceiveStreamI) handleRstStreamFrame(arg0 *wire.RstStreamFrame) error { + ret := m.ctrl.Call(m, "handleRstStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleRstStreamFrame indicates an expected call of handleRstStreamFrame +func (mr *MockReceiveStreamIMockRecorder) handleRstStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleRstStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleRstStreamFrame), arg0) +} + +// handleStreamFrame mocks base method +func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { + ret := m.ctrl.Call(m, "handleStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleStreamFrame indicates an expected call of handleStreamFrame +func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mock_send_stream_internal_test.go b/vendor/github.com/lucas-clemente/quic-go/mock_send_stream_internal_test.go new file mode 100644 index 0000000..f1e68a0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mock_send_stream_internal_test.go @@ -0,0 +1,154 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: SendStreamI) + +// Package quic is a generated GoMock package. +package quic + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockSendStreamI is a mock of SendStreamI interface +type MockSendStreamI struct { + ctrl *gomock.Controller + recorder *MockSendStreamIMockRecorder +} + +// MockSendStreamIMockRecorder is the mock recorder for MockSendStreamI +type MockSendStreamIMockRecorder struct { + mock *MockSendStreamI +} + +// NewMockSendStreamI creates a new mock instance +func NewMockSendStreamI(ctrl *gomock.Controller) *MockSendStreamI { + mock := &MockSendStreamI{ctrl: ctrl} + mock.recorder = &MockSendStreamIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder { + return m.recorder +} + +// CancelWrite mocks base method +func (m *MockSendStreamI) CancelWrite(arg0 protocol.ApplicationErrorCode) error { + ret := m.ctrl.Call(m, "CancelWrite", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CancelWrite indicates an expected call of CancelWrite +func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0) +} + +// Close mocks base method +func (m *MockSendStreamI) Close() error { + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockSendStreamIMockRecorder) Close() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close)) +} + +// Context mocks base method +func (m *MockSendStreamI) Context() context.Context { + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context +func (mr *MockSendStreamIMockRecorder) Context() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context)) +} + +// SetWriteDeadline mocks base method +func (m *MockSendStreamI) SetWriteDeadline(arg0 time.Time) error { + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline +func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), arg0) +} + +// StreamID mocks base method +func (m *MockSendStreamI) StreamID() protocol.StreamID { + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(protocol.StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID +func (mr *MockSendStreamIMockRecorder) StreamID() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID)) +} + +// Write mocks base method +func (m *MockSendStreamI) Write(arg0 []byte) (int, error) { + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write +func (mr *MockSendStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), arg0) +} + +// closeForShutdown mocks base method +func (m *MockSendStreamI) closeForShutdown(arg0 error) { + m.ctrl.Call(m, "closeForShutdown", arg0) +} + +// closeForShutdown indicates an expected call of closeForShutdown +func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0) +} + +// handleMaxStreamDataFrame mocks base method +func (m *MockSendStreamI) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) { + m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0) +} + +// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame +func (mr *MockSendStreamIMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleMaxStreamDataFrame), arg0) +} + +// handleStopSendingFrame mocks base method +func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { + m.ctrl.Call(m, "handleStopSendingFrame", arg0) +} + +// handleStopSendingFrame indicates an expected call of handleStopSendingFrame +func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0) +} + +// popStreamFrame mocks base method +func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { + ret := m.ctrl.Call(m, "popStreamFrame", arg0) + ret0, _ := ret[0].(*wire.StreamFrame) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// popStreamFrame indicates an expected call of popStreamFrame +func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mock_stream_frame_source_test.go b/vendor/github.com/lucas-clemente/quic-go/mock_stream_frame_source_test.go new file mode 100644 index 0000000..9b36580 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mock_stream_frame_source_test.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: StreamFrameSource) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockStreamFrameSource is a mock of StreamFrameSource interface +type MockStreamFrameSource struct { + ctrl *gomock.Controller + recorder *MockStreamFrameSourceMockRecorder +} + +// MockStreamFrameSourceMockRecorder is the mock recorder for MockStreamFrameSource +type MockStreamFrameSourceMockRecorder struct { + mock *MockStreamFrameSource +} + +// NewMockStreamFrameSource creates a new mock instance +func NewMockStreamFrameSource(ctrl *gomock.Controller) *MockStreamFrameSource { + mock := &MockStreamFrameSource{ctrl: ctrl} + mock.recorder = &MockStreamFrameSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStreamFrameSource) EXPECT() *MockStreamFrameSourceMockRecorder { + return m.recorder +} + +// HasCryptoStreamData mocks base method +func (m *MockStreamFrameSource) HasCryptoStreamData() bool { + ret := m.ctrl.Call(m, "HasCryptoStreamData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasCryptoStreamData indicates an expected call of HasCryptoStreamData +func (mr *MockStreamFrameSourceMockRecorder) HasCryptoStreamData() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCryptoStreamData", reflect.TypeOf((*MockStreamFrameSource)(nil).HasCryptoStreamData)) +} + +// PopCryptoStreamFrame mocks base method +func (m *MockStreamFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { + ret := m.ctrl.Call(m, "PopCryptoStreamFrame", arg0) + ret0, _ := ret[0].(*wire.StreamFrame) + return ret0 +} + +// PopCryptoStreamFrame indicates an expected call of PopCryptoStreamFrame +func (mr *MockStreamFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockStreamFrameSource)(nil).PopCryptoStreamFrame), arg0) +} + +// PopStreamFrames mocks base method +func (m *MockStreamFrameSource) PopStreamFrames(arg0 protocol.ByteCount) []*wire.StreamFrame { + ret := m.ctrl.Call(m, "PopStreamFrames", arg0) + ret0, _ := ret[0].([]*wire.StreamFrame) + return ret0 +} + +// PopStreamFrames indicates an expected call of PopStreamFrames +func (mr *MockStreamFrameSourceMockRecorder) PopStreamFrames(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopStreamFrames", reflect.TypeOf((*MockStreamFrameSource)(nil).PopStreamFrames), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mock_stream_getter_test.go b/vendor/github.com/lucas-clemente/quic-go/mock_stream_getter_test.go new file mode 100644 index 0000000..8dfa2d8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mock_stream_getter_test.go @@ -0,0 +1,61 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: StreamGetter) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockStreamGetter is a mock of StreamGetter interface +type MockStreamGetter struct { + ctrl *gomock.Controller + recorder *MockStreamGetterMockRecorder +} + +// MockStreamGetterMockRecorder is the mock recorder for MockStreamGetter +type MockStreamGetterMockRecorder struct { + mock *MockStreamGetter +} + +// NewMockStreamGetter creates a new mock instance +func NewMockStreamGetter(ctrl *gomock.Controller) *MockStreamGetter { + mock := &MockStreamGetter{ctrl: ctrl} + mock.recorder = &MockStreamGetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder { + return m.recorder +} + +// GetOrOpenReceiveStream mocks base method +func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) { + ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0) + ret0, _ := ret[0].(receiveStreamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream +func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0) +} + +// GetOrOpenSendStream mocks base method +func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) { + ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0) + ret0, _ := ret[0].(sendStreamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream +func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mock_stream_internal_test.go b/vendor/github.com/lucas-clemente/quic-go/mock_stream_internal_test.go new file mode 100644 index 0000000..6cbc8a9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mock_stream_internal_test.go @@ -0,0 +1,239 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: StreamI) + +// Package quic is a generated GoMock package. +package quic + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockStreamI is a mock of StreamI interface +type MockStreamI struct { + ctrl *gomock.Controller + recorder *MockStreamIMockRecorder +} + +// MockStreamIMockRecorder is the mock recorder for MockStreamI +type MockStreamIMockRecorder struct { + mock *MockStreamI +} + +// NewMockStreamI creates a new mock instance +func NewMockStreamI(ctrl *gomock.Controller) *MockStreamI { + mock := &MockStreamI{ctrl: ctrl} + mock.recorder = &MockStreamIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStreamI) EXPECT() *MockStreamIMockRecorder { + return m.recorder +} + +// CancelRead mocks base method +func (m *MockStreamI) CancelRead(arg0 protocol.ApplicationErrorCode) error { + ret := m.ctrl.Call(m, "CancelRead", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CancelRead indicates an expected call of CancelRead +func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0) +} + +// CancelWrite mocks base method +func (m *MockStreamI) CancelWrite(arg0 protocol.ApplicationErrorCode) error { + ret := m.ctrl.Call(m, "CancelWrite", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CancelWrite indicates an expected call of CancelWrite +func (mr *MockStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0) +} + +// Close mocks base method +func (m *MockStreamI) Close() error { + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockStreamIMockRecorder) Close() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStreamI)(nil).Close)) +} + +// Context mocks base method +func (m *MockStreamI) Context() context.Context { + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context +func (mr *MockStreamIMockRecorder) Context() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStreamI)(nil).Context)) +} + +// Read mocks base method +func (m *MockStreamI) Read(arg0 []byte) (int, error) { + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0) +} + +// SetDeadline mocks base method +func (m *MockStreamI) SetDeadline(arg0 time.Time) error { + ret := m.ctrl.Call(m, "SetDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (mr *MockStreamIMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), arg0) +} + +// SetReadDeadline mocks base method +func (m *MockStreamI) SetReadDeadline(arg0 time.Time) error { + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline +func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), arg0) +} + +// SetWriteDeadline mocks base method +func (m *MockStreamI) SetWriteDeadline(arg0 time.Time) error { + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline +func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0) +} + +// StreamID mocks base method +func (m *MockStreamI) StreamID() protocol.StreamID { + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(protocol.StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID +func (mr *MockStreamIMockRecorder) StreamID() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID)) +} + +// Write mocks base method +func (m *MockStreamI) Write(arg0 []byte) (int, error) { + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write +func (mr *MockStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), arg0) +} + +// closeForShutdown mocks base method +func (m *MockStreamI) closeForShutdown(arg0 error) { + m.ctrl.Call(m, "closeForShutdown", arg0) +} + +// closeForShutdown indicates an expected call of closeForShutdown +func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0) +} + +// getWindowUpdate mocks base method +func (m *MockStreamI) getWindowUpdate() protocol.ByteCount { + ret := m.ctrl.Call(m, "getWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// getWindowUpdate indicates an expected call of getWindowUpdate +func (mr *MockStreamIMockRecorder) getWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).getWindowUpdate)) +} + +// handleMaxStreamDataFrame mocks base method +func (m *MockStreamI) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) { + m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0) +} + +// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame +func (mr *MockStreamIMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockStreamI)(nil).handleMaxStreamDataFrame), arg0) +} + +// handleRstStreamFrame mocks base method +func (m *MockStreamI) handleRstStreamFrame(arg0 *wire.RstStreamFrame) error { + ret := m.ctrl.Call(m, "handleRstStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleRstStreamFrame indicates an expected call of handleRstStreamFrame +func (mr *MockStreamIMockRecorder) handleRstStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleRstStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleRstStreamFrame), arg0) +} + +// handleStopSendingFrame mocks base method +func (m *MockStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { + m.ctrl.Call(m, "handleStopSendingFrame", arg0) +} + +// handleStopSendingFrame indicates an expected call of handleStopSendingFrame +func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0) +} + +// handleStreamFrame mocks base method +func (m *MockStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { + ret := m.ctrl.Call(m, "handleStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleStreamFrame indicates an expected call of handleStreamFrame +func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0) +} + +// popStreamFrame mocks base method +func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { + ret := m.ctrl.Call(m, "popStreamFrame", arg0) + ret0, _ := ret[0].(*wire.StreamFrame) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// popStreamFrame indicates an expected call of popStreamFrame +func (mr *MockStreamIMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mock_stream_manager_test.go b/vendor/github.com/lucas-clemente/quic-go/mock_stream_manager_test.go new file mode 100644 index 0000000..05993e9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mock_stream_manager_test.go @@ -0,0 +1,146 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: StreamManager) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + handshake "github.com/lucas-clemente/quic-go/internal/handshake" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockStreamManager is a mock of StreamManager interface +type MockStreamManager struct { + ctrl *gomock.Controller + recorder *MockStreamManagerMockRecorder +} + +// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager +type MockStreamManagerMockRecorder struct { + mock *MockStreamManager +} + +// NewMockStreamManager creates a new mock instance +func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager { + mock := &MockStreamManager{ctrl: ctrl} + mock.recorder = &MockStreamManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder { + return m.recorder +} + +// AcceptStream mocks base method +func (m *MockStreamManager) AcceptStream() (Stream, error) { + ret := m.ctrl.Call(m, "AcceptStream") + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptStream indicates an expected call of AcceptStream +func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream)) +} + +// CloseWithError mocks base method +func (m *MockStreamManager) CloseWithError(arg0 error) { + m.ctrl.Call(m, "CloseWithError", arg0) +} + +// CloseWithError indicates an expected call of CloseWithError +func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0) +} + +// DeleteStream mocks base method +func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error { + ret := m.ctrl.Call(m, "DeleteStream", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteStream indicates an expected call of DeleteStream +func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0) +} + +// GetOrOpenReceiveStream mocks base method +func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) { + ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0) + ret0, _ := ret[0].(receiveStreamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream +func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0) +} + +// GetOrOpenSendStream mocks base method +func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) { + ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0) + ret0, _ := ret[0].(sendStreamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream +func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) +} + +// GetOrOpenStream mocks base method +func (m *MockStreamManager) GetOrOpenStream(arg0 protocol.StreamID) (streamI, error) { + ret := m.ctrl.Call(m, "GetOrOpenStream", arg0) + ret0, _ := ret[0].(streamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenStream indicates an expected call of GetOrOpenStream +func (mr *MockStreamManagerMockRecorder) GetOrOpenStream(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenStream), arg0) +} + +// OpenStream mocks base method +func (m *MockStreamManager) OpenStream() (Stream, error) { + ret := m.ctrl.Call(m, "OpenStream") + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStream indicates an expected call of OpenStream +func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream)) +} + +// OpenStreamSync mocks base method +func (m *MockStreamManager) OpenStreamSync() (Stream, error) { + ret := m.ctrl.Call(m, "OpenStreamSync") + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStreamSync indicates an expected call of OpenStreamSync +func (mr *MockStreamManagerMockRecorder) OpenStreamSync() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync)) +} + +// UpdateLimits mocks base method +func (m *MockStreamManager) UpdateLimits(arg0 *handshake.TransportParameters) { + m.ctrl.Call(m, "UpdateLimits", arg0) +} + +// UpdateLimits indicates an expected call of UpdateLimits +func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mock_stream_sender_test.go b/vendor/github.com/lucas-clemente/quic-go/mock_stream_sender_test.go new file mode 100644 index 0000000..da3ad8d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mock_stream_sender_test.go @@ -0,0 +1,76 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: StreamSender) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockStreamSender is a mock of StreamSender interface +type MockStreamSender struct { + ctrl *gomock.Controller + recorder *MockStreamSenderMockRecorder +} + +// MockStreamSenderMockRecorder is the mock recorder for MockStreamSender +type MockStreamSenderMockRecorder struct { + mock *MockStreamSender +} + +// NewMockStreamSender creates a new mock instance +func NewMockStreamSender(ctrl *gomock.Controller) *MockStreamSender { + mock := &MockStreamSender{ctrl: ctrl} + mock.recorder = &MockStreamSenderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { + return m.recorder +} + +// onHasStreamData mocks base method +func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) { + m.ctrl.Call(m, "onHasStreamData", arg0) +} + +// onHasStreamData indicates an expected call of onHasStreamData +func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0) +} + +// onHasWindowUpdate mocks base method +func (m *MockStreamSender) onHasWindowUpdate(arg0 protocol.StreamID) { + m.ctrl.Call(m, "onHasWindowUpdate", arg0) +} + +// onHasWindowUpdate indicates an expected call of onHasWindowUpdate +func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0) +} + +// onStreamCompleted mocks base method +func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { + m.ctrl.Call(m, "onStreamCompleted", arg0) +} + +// onStreamCompleted indicates an expected call of onStreamCompleted +func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0) +} + +// queueControlFrame mocks base method +func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) { + m.ctrl.Call(m, "queueControlFrame", arg0) +} + +// queueControlFrame indicates an expected call of queueControlFrame +func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/mockgen.go new file mode 100644 index 0000000..3802a86 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mockgen.go @@ -0,0 +1,12 @@ +package quic + +//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI ReceiveStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI SendStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager" +//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go" +//go:generate sh -c "goimports -w mock*_test.go" diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh b/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh new file mode 100755 index 0000000..7fbe68d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Mockgen refuses to generate mocks private types. +# This script copies the quic package to a temporary directory, and adds an public alias for the private type. +# It then creates a mock for this public (alias) type. + +TEMP_DIR=$(mktemp -d) +mkdir -p $TEMP_DIR/src/github.com/lucas-clemente/quic-go/ + +# copy all .go files to a temporary directory +# golang.org/x/crypto/curve25519/ uses Go compiler directives, which is confusing to mockgen +rsync -r --exclude 'vendor/golang.org/x/crypto/curve25519/' --include='*.go' --include '*/' --exclude '*' $GOPATH/src/github.com/lucas-clemente/quic-go/ $TEMP_DIR/src/github.com/lucas-clemente/quic-go/ +echo "type $5 = $4" >> $TEMP_DIR/src/github.com/lucas-clemente/quic-go/interface.go + +export GOPATH="$TEMP_DIR:$GOPATH" + +mockgen -package $1 -self_package $1 -destination $2 $3 $5 + +rm -r "$TEMP_DIR" diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go index aabddee..74e46e3 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "sync" "github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/internal/handshake" @@ -18,6 +19,12 @@ type packedPacket struct { encryptionLevel protocol.EncryptionLevel } +type streamFrameSource interface { + HasCryptoStreamData() bool + PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame + PopStreamFrames(protocol.ByteCount) []*wire.StreamFrame +} + type packetPacker struct { connectionID protocol.ConnectionID perspective protocol.Perspective @@ -25,20 +32,23 @@ type packetPacker struct { cryptoSetup handshake.CryptoSetup packetNumberGenerator *packetNumberGenerator - streamFramer *streamFramer + streams streamFrameSource - controlFrames []wire.Frame - stopWaiting *wire.StopWaitingFrame - ackFrame *wire.AckFrame - leastUnacked protocol.PacketNumber - omitConnectionID bool - hasSentPacket bool // has the packetPacker already sent a packet + controlFrameMutex sync.Mutex + controlFrames []wire.Frame + + stopWaiting *wire.StopWaitingFrame + ackFrame *wire.AckFrame + leastUnacked protocol.PacketNumber + omitConnectionID bool + hasSentPacket bool // has the packetPacker already sent a packet + makeNextPacketRetransmittable bool } func newPacketPacker(connectionID protocol.ConnectionID, initialPacketNumber protocol.PacketNumber, cryptoSetup handshake.CryptoSetup, - streamFramer *streamFramer, + streamFramer streamFrameSource, perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { @@ -47,7 +57,7 @@ func newPacketPacker(connectionID protocol.ConnectionID, connectionID: connectionID, perspective: perspective, version: version, - streamFramer: streamFramer, + streams: streamFramer, packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), } } @@ -73,7 +83,7 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) frames := []wire.Frame{p.ackFrame} - if p.stopWaiting != nil { + if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC p.stopWaiting.PacketNumber = header.PacketNumber p.stopWaiting.PacketNumberLen = header.PacketNumberLen frames = append(frames, p.stopWaiting) @@ -98,14 +108,20 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* if err != nil { return nil, err } - if p.stopWaiting == nil { - return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame") - } header := p.getHeader(packet.EncryptionLevel) - p.stopWaiting.PacketNumber = header.PacketNumber - p.stopWaiting.PacketNumberLen = header.PacketNumberLen - frames := append([]wire.Frame{p.stopWaiting}, packet.Frames...) - p.stopWaiting = nil + var frames []wire.Frame + if !p.version.UsesIETFFrameFormat() { // for gQUIC: pack a STOP_WAITING first + if p.stopWaiting == nil { + return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame") + } + swf := p.stopWaiting + swf.PacketNumber = header.PacketNumber + swf.PacketNumberLen = header.PacketNumberLen + p.stopWaiting = nil + frames = append([]wire.Frame{swf}, packet.Frames...) + } else { + frames = packet.Frames + } raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ header: header, @@ -118,7 +134,7 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame() + hasCryptoStreamFrame := p.streams.HasCryptoStreamData() // if this is the first packet to be send, make sure it contains stream data if !p.hasSentPacket && !hasCryptoStreamFrame { return nil, nil @@ -153,6 +169,15 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if len(payloadFrames) == 1 && p.stopWaiting != nil { return nil, nil } + // check if this packet only contains an ACK and / or STOP_WAITING + if !ackhandler.HasRetransmittableFrames(payloadFrames) { + if p.makeNextPacketRetransmittable { + payloadFrames = append(payloadFrames, &wire.PingFrame{}) + p.makeNextPacketRetransmittable = false + } + } else { // this packet already contains a retransmittable frame. No need to send a PING + p.makeNextPacketRetransmittable = false + } p.stopWaiting = nil p.ackFrame = nil @@ -176,7 +201,9 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { return nil, err } maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength - frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} + sf := p.streams.PopCryptoStreamFrame(maxLen) + sf.DataLenPresent = false + frames := []wire.Frame{sf} raw, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { return nil, err @@ -197,29 +224,20 @@ func (p *packetPacker) composeNextPacket( var payloadFrames []wire.Frame // STOP_WAITING and ACK will always fit - if p.stopWaiting != nil { - payloadFrames = append(payloadFrames, p.stopWaiting) - l, err := p.stopWaiting.MinLength(p.version) - if err != nil { - return nil, err - } + if p.ackFrame != nil { // ACKs need to go first, so that the sentPacketHandler will recognize them + payloadFrames = append(payloadFrames, p.ackFrame) + l := p.ackFrame.MinLength(p.version) payloadLength += l } - if p.ackFrame != nil { - payloadFrames = append(payloadFrames, p.ackFrame) - l, err := p.ackFrame.MinLength(p.version) - if err != nil { - return nil, err - } - payloadLength += l + if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC + payloadFrames = append(payloadFrames, p.stopWaiting) + payloadLength += p.stopWaiting.MinLength(p.version) } + p.controlFrameMutex.Lock() for len(p.controlFrames) > 0 { frame := p.controlFrames[len(p.controlFrames)-1] - minLength, err := frame.MinLength(p.version) - if err != nil { - return nil, err - } + minLength := frame.MinLength(p.version) if payloadLength+minLength > maxFrameSize { break } @@ -227,6 +245,7 @@ func (p *packetPacker) composeNextPacket( payloadLength += minLength p.controlFrames = p.controlFrames[:len(p.controlFrames)-1] } + p.controlFrameMutex.Unlock() if payloadLength > maxFrameSize { return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize) @@ -247,20 +266,14 @@ func (p *packetPacker) composeNextPacket( maxFrameSize += 2 } - fs := p.streamFramer.PopStreamFrames(maxFrameSize - payloadLength) + fs := p.streams.PopStreamFrames(maxFrameSize - payloadLength) if len(fs) != 0 { fs[len(fs)-1].DataLenPresent = false } - // TODO: Simplify for _, f := range fs { payloadFrames = append(payloadFrames, f) } - - for b := p.streamFramer.PopBlockedFrame(); b != nil; b = p.streamFramer.PopBlockedFrame() { - p.controlFrames = append(p.controlFrames, b) - } - return payloadFrames, nil } @@ -271,7 +284,9 @@ func (p *packetPacker) QueueControlFrame(frame wire.Frame) { case *wire.AckFrame: p.ackFrame = f default: + p.controlFrameMutex.Lock() p.controlFrames = append(p.controlFrames, f) + p.controlFrameMutex.Unlock() } } @@ -377,3 +392,7 @@ func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) { func (p *packetPacker) SetOmitConnectionID() { p.omitConnectionID = true } + +func (p *packetPacker) MakeNextPacketRetransmittable() { + p.makeNextPacketRetransmittable = true +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_packer_test.go b/vendor/github.com/lucas-clemente/quic-go/packet_packer_test.go index 50cedd9..936d255 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_packer_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_packer_test.go @@ -4,6 +4,7 @@ import ( "bytes" "math" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/handshake" @@ -49,29 +50,32 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) } func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce } func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce } +func (m *mockCryptoSetup) ConnectionState() ConnectionState { panic("not implemented") } var _ = Describe("Packet packer", func() { var ( - packer *packetPacker - publicHeaderLen protocol.ByteCount - maxFrameSize protocol.ByteCount - streamFramer *streamFramer - cryptoStream *stream + packer *packetPacker + publicHeaderLen protocol.ByteCount + maxFrameSize protocol.ByteCount + cryptoStream cryptoStreamI + mockStreamFramer *MockStreamFrameSource ) BeforeEach(func() { version := versionGQUICFrames - cryptoStream = &stream{streamID: version.CryptoStreamID(), flowController: flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)} - streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) - streamFramer = newStreamFramer(cryptoStream, streamsMap, nil, versionGQUICFrames) + mockSender := NewMockStreamSender(mockCtrl) + mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() + cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version) + mockStreamFramer = NewMockStreamFrameSource(mockCtrl) - packer = &packetPacker{ - cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, - connectionID: 0x1337, - packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength), - streamFramer: streamFramer, - perspective: protocol.PerspectiveServer, - } + packer = newPacketPacker( + 0x1337, + 1, + &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, + mockStreamFramer, + protocol.PerspectiveServer, + version, + ) publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen packer.hasSentPacket = true @@ -79,33 +83,36 @@ var _ = Describe("Packet packer", func() { }) It("returns nil when no packet is queued", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) It("packs single packets", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() f := &wire.StreamFrame{ StreamID: 5, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } - streamFramer.AddFrameForRetransmission(f) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f}) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b := &bytes.Buffer{} f.Write(b, packer.version) - Expect(p.frames).To(HaveLen(1)) + Expect(p.frames).To(Equal([]wire.Frame{f})) Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) }) It("stores the encryption level a packet was sealed with", func() { - packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure - f := &wire.StreamFrame{ + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{ StreamID: 5, Data: []byte("foobar"), - } - streamFramer.AddFrameForRetransmission(f) + }}) + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) @@ -213,7 +220,7 @@ var _ = Describe("Packet packer", func() { }) }) - It("packs a ConnectionClose", func() { + It("packs a CONNECTION_CLOSE", func() { ccf := wire.ConnectionCloseFrame{ ErrorCode: 0x1337, ReasonPhrase: "foobar", @@ -224,23 +231,21 @@ var _ = Describe("Packet packer", func() { Expect(p.frames[0]).To(Equal(&ccf)) }) - It("doesn't send any other frames when sending a ConnectionClose", func() { - ccf := wire.ConnectionCloseFrame{ + It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() { + // expect no mockStreamFramer.PopStreamFrames + ccf := &wire.ConnectionCloseFrame{ ErrorCode: 0x1337, ReasonPhrase: "foobar", } packer.controlFrames = []wire.Frame{&wire.MaxStreamDataFrame{StreamID: 37}} - streamFramer.AddFrameForRetransmission(&wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - }) - p, err := packer.PackConnectionClose(&ccf) + p, err := packer.PackConnectionClose(ccf) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&ccf)) + Expect(p.frames).To(Equal([]wire.Frame{ccf})) }) It("packs only control frames", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.MaxDataFrame{}) p, err := packer.PackPacket() @@ -251,6 +256,8 @@ var _ = Describe("Packet packer", func() { }) It("increases the packet number", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) packer.QueueControlFrame(&wire.RstStreamFrame{}) p1, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -263,6 +270,8 @@ var _ = Describe("Packet packer", func() { }) It("packs a STOP_WAITING frame first", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packer.packetNumberGenerator.next = 15 swf := &wire.StopWaitingFrame{LeastUnacked: 10} packer.QueueControlFrame(&wire.RstStreamFrame{}) @@ -275,6 +284,8 @@ var _ = Describe("Packet packer", func() { }) It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number packer.packetNumberGenerator.next = packetNumber swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} @@ -286,6 +297,8 @@ var _ = Describe("Packet packer", func() { }) It("does not pack a packet containing only a STOP_WAITING frame", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) swf := &wire.StopWaitingFrame{LeastUnacked: 10} packer.QueueControlFrame(swf) p, err := packer.PackPacket() @@ -294,6 +307,8 @@ var _ = Describe("Packet packer", func() { }) It("packs a packet if it has queued control frames, but no new control frames", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -301,6 +316,7 @@ var _ = Describe("Packet packer", func() { }) It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() packer.hasSentPacket = false packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} p, err := packer.PackPacket() @@ -329,8 +345,7 @@ var _ = Describe("Packet packer", func() { It("packs a lot of control frames into 2 packets if they don't fit into one", func() { blockedFrame := &wire.BlockedFrame{} - minLength, _ := blockedFrame.MinLength(packer.version) - maxFramesPerPacket := int(maxFrameSize) / int(minLength) + maxFramesPerPacket := int(maxFrameSize) / int(blockedFrame.MinLength(packer.version)) var controlFrames []wire.Frame for i := 0; i < maxFramesPerPacket+10; i++ { controlFrames = append(controlFrames, blockedFrame) @@ -345,16 +360,17 @@ var _ = Describe("Packet packer", func() { }) It("only increases the packet number when there is an actual packet to send", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) packer.packetNumberGenerator.nextToSkip = 1000 p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) - f := &wire.StreamFrame{ + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{ StreamID: 5, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, - } - streamFramer.AddFrameForRetransmission(f) + }}) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -362,320 +378,207 @@ var _ = Describe("Packet packer", func() { Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2))) }) - Context("STREAM Frame handling", func() { + It("adds a PING frame when it's supposed to send a retransmittable packet", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) + packer.QueueControlFrame(&wire.AckFrame{}) + packer.QueueControlFrame(&wire.StopWaitingFrame{}) + packer.MakeNextPacketRetransmittable() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(3)) + Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + // make sure the next packet doesn't contain another PING + packer.QueueControlFrame(&wire.AckFrame{}) + p, err = packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + }) + + It("waits until there's something to send before adding a PING frame", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) + packer.MakeNextPacketRetransmittable() + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + packer.QueueControlFrame(&wire.AckFrame{}) + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(2)) + Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + }) + + It("doesn't send a PING if it already sent another retransmittable frame", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) + packer.MakeNextPacketRetransmittable() + packer.QueueControlFrame(&wire.MaxDataFrame{}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + packer.QueueControlFrame(&wire.AckFrame{}) + p, err = packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + }) + + Context("STREAM frame handling", func() { It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() { - f := &wire.StreamFrame{ - Offset: 1, - StreamID: 5, - DataLenPresent: false, - } - minLength, _ := f.MinLength(packer.version) - maxStreamFrameDataLen := maxFrameSize - minLength - f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)) - streamFramer.AddFrameForRetransmission(f) - payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(BeEmpty()) - }) - - It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() { - packer.version = versionIETFFrames - streamFramer.version = versionIETFFrames - f := &wire.StreamFrame{ - Offset: 1, - StreamID: 5, - DataLenPresent: true, - } - minLength, _ := f.MinLength(packer.version) - // for IETF draft style STREAM frames, we don't know the size of the DataLen, because it is a variable length integer - // in the general case, we therefore use a STREAM frame that is 1 byte smaller than the maximum size - maxStreamFrameDataLen := maxFrameSize - minLength - 1 - f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)) - streamFramer.AddFrameForRetransmission(f) - payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(BeEmpty()) - }) - - It("correctly handles a STREAM frame with one byte less than maximum size", func() { - maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - 1 - f1 := &wire.StreamFrame{ - StreamID: 5, - Offset: 1, - Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)), - } - f2 := &wire.StreamFrame{ - StreamID: 5, - Offset: 1, - Data: []byte("foobar"), - } - streamFramer.AddFrameForRetransmission(f1) - streamFramer.AddFrameForRetransmission(f2) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1))) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - }) - - It("packs multiple small STREAM frames into single packet", func() { - f1 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, - } - f2 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xBE, 0xEF, 0x13, 0x37}, - } - f3 := &wire.StreamFrame{ - StreamID: 3, - Data: []byte{0xCA, 0xFE}, - } - streamFramer.AddFrameForRetransmission(f1) - streamFramer.AddFrameForRetransmission(f2) - streamFramer.AddFrameForRetransmission(f3) - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - b := &bytes.Buffer{} - f1.Write(b, 0) - f2.Write(b, 0) - f3.Write(b, 0) - Expect(p.frames).To(HaveLen(3)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) - Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) - Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(p.raw).To(ContainSubstring(string(f1.Data))) - Expect(p.raw).To(ContainSubstring(string(f2.Data))) - Expect(p.raw).To(ContainSubstring(string(f3.Data))) - }) - - It("splits one STREAM frame larger than maximum size", func() { - f := &wire.StreamFrame{ - StreamID: 7, - Offset: 1, - } - minLength, _ := f.MinLength(packer.version) - maxStreamFrameDataLen := maxFrameSize - minLength - f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200) - streamFramer.AddFrameForRetransmission(f) - payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(int(maxStreamFrameDataLen))) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(200)) - Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(BeEmpty()) - }) - - It("packs 2 STREAM frames that are too big for one packet correctly", func() { - maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - f1 := &wire.StreamFrame{ - StreamID: 5, - Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100), - Offset: 1, - } - f2 := &wire.StreamFrame{ - StreamID: 5, - Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100), - Offset: 1, - } - streamFramer.AddFrameForRetransmission(f1) - streamFramer.AddFrameForRetransmission(f2) + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame { + f := &wire.StreamFrame{ + Offset: 1, + StreamID: 5, + DataLenPresent: true, + } + f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version))) + return []*wire.StreamFrame{f} + }) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) - p, err = packer.PackPacket() - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) - Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) - p, err = packer.PackPacket() - Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) - It("packs a packet that has the maximum packet size when given a large enough STREAM frame", func() { - f := &wire.StreamFrame{ - StreamID: 5, - Offset: 1, - } - minLength, _ := f.MinLength(packer.version) - f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header - streamFramer.AddFrameForRetransmission(f) + It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() { + packer.version = versionIETFFrames + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame { + f := &wire.StreamFrame{ + Offset: 1, + StreamID: 5, + DataLenPresent: true, + } + f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version))) + return []*wire.StreamFrame{f} + }) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) + Expect(p.frames).To(HaveLen(1)) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) + Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) }) - It("splits a STREAM frame larger than the maximum size", func() { - f := &wire.StreamFrame{ - StreamID: 5, - Offset: 1, + It("packs multiple small STREAM frames into single packet", func() { + f1 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 1"), + DataLenPresent: true, } - minLength, _ := f.MinLength(packer.version) - f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header - - streamFramer.AddFrameForRetransmission(f) - payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) + f2 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 2"), + DataLenPresent: true, + } + f3 := &wire.StreamFrame{ + StreamID: 3, + Data: []byte("frame 3"), + DataLenPresent: true, + } + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f1, f2, f3}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(payloadFrames).To(HaveLen(1)) + Expect(p.frames).To(HaveLen(3)) + Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) + Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) + Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) + Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) }) It("refuses to send unencrypted stream data on a data stream", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted - f := &wire.StreamFrame{ - StreamID: 3, - Data: []byte("foobar"), - } - streamFramer.AddFrameForRetransmission(f) p, err := packer.PackPacket() Expect(err).NotTo(HaveOccurred()) Expect(p).To(BeNil()) }) It("sends non forward-secure data as the client", func() { - packer.perspective = protocol.PerspectiveClient - packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure f := &wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), } - streamFramer.AddFrameForRetransmission(f) + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f}) + packer.perspective = protocol.PerspectiveClient + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) - Expect(p.frames[0]).To(Equal(f)) + Expect(p.frames).To(Equal([]wire.Frame{f})) }) It("does not send non forward-secure data as the server", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure - f := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - } - streamFramer.AddFrameForRetransmission(f) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) It("sends unencrypted stream data on the crypto stream", func() { - packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted - cryptoStream.dataForWriting = []byte("foobar") - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&wire.StreamFrame{ + f := &wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), - })) + } + mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) + mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{f})) + Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) }) It("sends encrypted stream data on the crypto stream", func() { - packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure - cryptoStream.dataForWriting = []byte("foobar") - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&wire.StreamFrame{ + f := &wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), - })) - }) - - It("does not pack stream frames if not allowed", func() { - packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted - packer.QueueControlFrame(&wire.AckFrame{}) - streamFramer.AddFrameForRetransmission(&wire.StreamFrame{StreamID: 3, Data: []byte("foobar")}) + } + mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) + mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(func() { _ = p.frames[0].(*wire.AckFrame) }).NotTo(Panic()) + Expect(p.frames).To(Equal([]wire.Frame{f})) + Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) }) - }) - Context("BLOCKED frames", func() { - It("queues a BLOCKED frame", func() { - length := 100 - streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}} - f := &wire.StreamFrame{ - StreamID: 5, - Data: bytes.Repeat([]byte{'f'}, length), - } - streamFramer.AddFrameForRetransmission(f) - _, err := packer.composeNextPacket(maxFrameSize, true) + It("does not pack STREAM frames if not allowed", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + // don't expect a call to mockStreamFramer.PopStreamFrames + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted + ack := &wire.AckFrame{LargestAcked: 10} + packer.QueueControlFrame(ack) + p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(packer.controlFrames[0]).To(Equal(&wire.StreamBlockedFrame{StreamID: 5})) + Expect(p.frames).To(Equal([]wire.Frame{ack})) }) - - It("removes the dataLen attribute from the last StreamFrame, even if it queued a BLOCKED frame", func() { - length := 100 - streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}} - f := &wire.StreamFrame{ - StreamID: 5, - Data: bytes.Repeat([]byte{'f'}, length), - } - streamFramer.AddFrameForRetransmission(f) - p, err := packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(HaveLen(1)) - Expect(p[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - }) - - It("packs a connection-level BlockedFrame", func() { - streamFramer.blockedFrameQueue = []wire.Frame{&wire.BlockedFrame{}} - f := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - } - streamFramer.AddFrameForRetransmission(f) - _, err := packer.composeNextPacket(maxFrameSize, true) - Expect(err).ToNot(HaveOccurred()) - Expect(packer.controlFrames[0]).To(Equal(&wire.BlockedFrame{})) - }) - }) - - It("returns nil if we only have a single STOP_WAITING", func() { - packer.QueueControlFrame(&wire.StopWaitingFrame{}) - p, err := packer.PackPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p).To(BeNil()) }) It("packs a single ACK", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) ack := &wire.AckFrame{LargestAcked: 42} packer.QueueControlFrame(ack) p, err := packer.PackPacket() @@ -685,6 +588,8 @@ var _ = Describe("Packet packer", func() { }) It("does not return nil if we only have a single ACK but request it to be sent", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) ack := &wire.AckFrame{} packer.QueueControlFrame(ack) p, err := packer.PackPacket() @@ -692,15 +597,6 @@ var _ = Describe("Packet packer", func() { Expect(p).ToNot(BeNil()) }) - It("queues a control frame to be sent in the next packet", func() { - msd := &wire.MaxStreamDataFrame{StreamID: 5} - packer.QueueControlFrame(msd) - p, err := packer.PackPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(msd)) - }) - Context("retransmitting of handshake packets", func() { swf := &wire.StopWaitingFrame{LeastUnacked: 1} sf := &wire.StreamFrame{ @@ -719,8 +615,19 @@ var _ = Describe("Packet packer", func() { } p, err := packer.PackHandshakeRetransmission(packet) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(ContainElement(sf)) - Expect(p.frames).To(ContainElement(swf)) + Expect(p.frames).To(Equal([]wire.Frame{swf, sf})) + Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) + }) + + It("doesn't add a STOP_WAITING frame for IETF QUIC", func() { + packer.version = versionIETFFrames + packet := &ackhandler.Packet{ + EncryptionLevel: protocol.EncryptionUnencrypted, + Frames: []wire.Frame{sf}, + } + p, err := packer.PackHandshakeRetransmission(packet) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{sf})) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) }) @@ -733,8 +640,7 @@ var _ = Describe("Packet packer", func() { } p, err := packer.PackHandshakeRetransmission(packet) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(ContainElement(sf)) - Expect(p.frames).To(ContainElement(swf)) + Expect(p.frames).To(Equal([]wire.Frame{swf, sf})) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) // a packet sent by the server with initial encryption contains the SHLO // it needs to have a diversification nonce @@ -768,11 +674,16 @@ var _ = Describe("Packet packer", func() { }) It("pads Initial packets to the required minimum packet size", func() { + f := &wire.StreamFrame{ + StreamID: packer.version.CryptoStreamID(), + Data: []byte("foobar"), + } + mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) + mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.hasSentPacket = false packer.perspective = protocol.PerspectiveClient packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted - cryptoStream.dataForWriting = []byte("foobar") packet, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) @@ -795,7 +706,7 @@ var _ = Describe("Packet packer", func() { _, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{ EncryptionLevel: protocol.EncryptionSecure, }) - Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")) + Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame")) }) }) @@ -807,7 +718,7 @@ var _ = Describe("Packet packer", func() { Expect(p.frames).To(Equal([]wire.Frame{&wire.AckFrame{DelayTime: math.MaxInt64}})) }) - It("packs ACK packets with SWFs", func() { + It("packs ACK packets with STOP_WAITING frames", func() { packer.QueueControlFrame(&wire.AckFrame{}) packer.QueueControlFrame(&wire.StopWaitingFrame{}) p, err := packer.PackAckPacket() diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go index 7291dc2..45bdc0f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go @@ -107,21 +107,32 @@ func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wir err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) } case 0x6: - // TODO(#964): remove STOP_WAITING frames - // TODO(#878): implement the MAX_STREAM_ID frame - frame, err = wire.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version) + frame, err = wire.ParseMaxStreamIDFrame(r, u.version) if err != nil { - err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) + err = qerr.Error(qerr.InvalidFrameData, err.Error()) } case 0x7: frame, err = wire.ParsePingFrame(r, u.version) case 0x8: frame, err = wire.ParseBlockedFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } case 0x9: frame, err = wire.ParseStreamBlockedFrame(r, u.version) if err != nil { err = qerr.Error(qerr.InvalidBlockedData, err.Error()) } + case 0xa: + frame, err = wire.ParseStreamIDBlockedFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0xc: + frame, err = wire.ParseStopSendingFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } case 0xe: frame, err = wire.ParseAckFrame(r, u.version) if err != nil { diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker_test.go b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker_test.go index 91e0656..88a342d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker_test.go @@ -102,7 +102,7 @@ var _ = Describe("Packet unpacker", func() { f := &wire.RstStreamFrame{ StreamID: 0xdeadbeef, ByteOffset: 0xdecafbad11223344, - ErrorCode: 0x13371234, + ErrorCode: 0x1337, } err := f.Write(buf, versionGQUICFrames) Expect(err).ToNot(HaveOccurred()) @@ -342,8 +342,19 @@ var _ = Describe("Packet unpacker", func() { Expect(packet.frames).To(Equal([]wire.Frame{f})) }) + It("unpacks MAX_STREAM_ID frames", func() { + f := &wire.MaxStreamIDFrame{StreamID: 0x1337} + buf := &bytes.Buffer{} + err := f.Write(buf, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]wire.Frame{f})) + }) + It("unpacks connection-level BLOCKED frames", func() { - f := &wire.BlockedFrame{} + f := &wire.BlockedFrame{Offset: 0x1234} buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) @@ -354,7 +365,32 @@ var _ = Describe("Packet unpacker", func() { }) It("unpacks stream-level BLOCKED frames", func() { - f := &wire.StreamBlockedFrame{StreamID: 0xdeadbeef} + f := &wire.StreamBlockedFrame{ + StreamID: 0xdeadbeef, + Offset: 0xdead, + } + buf := &bytes.Buffer{} + err := f.Write(buf, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]wire.Frame{f})) + }) + + It("unpacks STREAM_ID_BLOCKED frames", func() { + f := &wire.StreamIDBlockedFrame{StreamID: 0x1234567} + buf := &bytes.Buffer{} + err := f.Write(buf, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]wire.Frame{f})) + }) + + It("unpacks STOP_SENDING frames", func() { + f := &wire.StopSendingFrame{StreamID: 0x42} buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) @@ -392,9 +428,13 @@ var _ = Describe("Packet unpacker", func() { 0x02: qerr.InvalidConnectionCloseData, 0x04: qerr.InvalidWindowUpdateData, 0x05: qerr.InvalidWindowUpdateData, + 0x06: qerr.InvalidFrameData, + 0x08: qerr.InvalidBlockedData, 0x09: qerr.InvalidBlockedData, + 0x0a: qerr.InvalidFrameData, + 0x0c: qerr.InvalidFrameData, + 0x0e: qerr.InvalidAckData, 0x10: qerr.InvalidStreamData, - 0xe: qerr.InvalidAckData, } { setData([]byte{b}) _, err := unpacker.Unpack(hdrBin, hdr, data) diff --git a/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go b/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go index 5a8e024..22d0c85 100644 --- a/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go +++ b/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go @@ -1,8 +1,8 @@ -// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT +// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT. package qerr -import "fmt" +import "strconv" const ( _ErrorCode_name_0 = "InternalErrorStreamDataAfterTerminationInvalidPacketHeaderInvalidFrameDataInvalidFecDataInvalidRstStreamDataInvalidConnectionCloseDataInvalidGoawayDataInvalidAckDataInvalidVersionNegotiationPacketInvalidPublicRstPacketDecryptionFailureEncryptionFailurePacketTooLarge" @@ -19,7 +19,6 @@ var ( _ErrorCode_index_2 = [...]uint16{0, 15, 37, 57, 75, 96, 112, 127, 147, 167, 191, 226, 250, 279, 309, 340, 366, 385, 410, 425, 445, 457, 475, 505, 530, 547} _ErrorCode_index_3 = [...]uint16{0, 14, 29, 50, 65, 90, 119, 158, 184, 208, 231, 249, 279, 301, 322, 340, 366, 390, 425} _ErrorCode_index_4 = [...]uint16{0, 16, 45, 78, 97, 114, 144, 169, 192, 215, 238, 256, 276, 292, 308, 346, 379, 410, 448, 459, 477, 498, 532} - _ErrorCode_index_5 = [...]uint8{0, 34} ) func (i ErrorCode) String() string { @@ -42,6 +41,6 @@ func (i ErrorCode) String() string { case i == 97: return _ErrorCode_name_5 default: - return fmt.Sprintf("ErrorCode(%d)", i) + return "ErrorCode(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/vendor/github.com/lucas-clemente/quic-go/receive_stream.go b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go new file mode 100644 index 0000000..f793981 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go @@ -0,0 +1,284 @@ +package quic + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type receiveStreamI interface { + ReceiveStream + + handleStreamFrame(*wire.StreamFrame) error + handleRstStreamFrame(*wire.RstStreamFrame) error + closeForShutdown(error) + getWindowUpdate() protocol.ByteCount +} + +type receiveStream struct { + mutex sync.Mutex + + streamID protocol.StreamID + + sender streamSender + + frameQueue *streamFrameSorter + readPosInFrame int + readOffset protocol.ByteCount + + closeForShutdownErr error + cancelReadErr error + resetRemotelyErr StreamError + + closedForShutdown bool // set when CloseForShutdown() is called + finRead bool // set once we read a frame with a FinBit + canceledRead bool // set when CancelRead() is called + resetRemotely bool // set when HandleRstStreamFrame() is called + + readChan chan struct{} + readDeadline time.Time + + flowController flowcontrol.StreamFlowController + version protocol.VersionNumber +} + +var _ ReceiveStream = &receiveStream{} +var _ receiveStreamI = &receiveStream{} + +func newReceiveStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, +) *receiveStream { + return &receiveStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + frameQueue: newStreamFrameSorter(), + readChan: make(chan struct{}, 1), + } +} + +func (s *receiveStream) StreamID() protocol.StreamID { + return s.streamID +} + +// Read implements io.Reader. It is not thread safe! +func (s *receiveStream) Read(p []byte) (int, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finRead { + return 0, io.EOF + } + if s.canceledRead { + return 0, s.cancelReadErr + } + if s.resetRemotely { + return 0, s.resetRemotelyErr + } + if s.closedForShutdown { + return 0, s.closeForShutdownErr + } + + bytesRead := 0 + for bytesRead < len(p) { + frame := s.frameQueue.Head() + if frame == nil && bytesRead > 0 { + return bytesRead, s.closeForShutdownErr + } + + for { + // Stop waiting on errors + if s.closedForShutdown { + return bytesRead, s.closeForShutdownErr + } + if s.canceledRead { + return bytesRead, s.cancelReadErr + } + if s.resetRemotely { + return bytesRead, s.resetRemotelyErr + } + + deadline := s.readDeadline + if !deadline.IsZero() && !time.Now().Before(deadline) { + return bytesRead, errDeadline + } + + if frame != nil { + s.readPosInFrame = int(s.readOffset - frame.Offset) + break + } + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.readChan + } else { + select { + case <-s.readChan: + case <-time.After(deadline.Sub(time.Now())): + } + } + s.mutex.Lock() + frame = s.frameQueue.Head() + } + + if bytesRead > len(p) { + return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + } + if s.readPosInFrame > int(frame.DataLen()) { + return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen()) + } + + s.mutex.Unlock() + + copy(p[bytesRead:], frame.Data[s.readPosInFrame:]) + m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame) + s.readPosInFrame += m + bytesRead += m + s.readOffset += protocol.ByteCount(m) + + s.mutex.Lock() + // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream + if !s.resetRemotely { + s.flowController.AddBytesRead(protocol.ByteCount(m)) + } + // this call triggers the flow controller to increase the flow control window, if necessary + if s.flowController.HasWindowUpdate() { + s.sender.onHasWindowUpdate(s.streamID) + } + + if s.readPosInFrame >= int(frame.DataLen()) { + s.frameQueue.Pop() + s.finRead = frame.FinBit + if frame.FinBit { + s.sender.onStreamCompleted(s.streamID) + return bytesRead, io.EOF + } + } + } + return bytesRead, nil +} + +func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finRead { + return nil + } + if s.canceledRead { + return nil + } + s.canceledRead = true + s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) + s.signalRead() + if s.version.UsesIETFFrameFormat() { + s.sender.queueControlFrame(&wire.StopSendingFrame{ + StreamID: s.streamID, + ErrorCode: errorCode, + }) + } + return nil +} + +func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { + maxOffset := frame.Offset + frame.DataLen() + if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil { + return err + } + + s.mutex.Lock() + defer s.mutex.Unlock() + if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData { + return err + } + s.signalRead() + return nil +} + +func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.closedForShutdown { + return nil + } + if err := s.flowController.UpdateHighestReceived(frame.ByteOffset, true); err != nil { + return err + } + // In gQUIC, error code 0 has a special meaning. + // The peer will reliably continue transmitting, but is not interested in reading from the stream. + // We should therefore just continue reading from the stream, until we encounter the FIN bit. + if !s.version.UsesIETFFrameFormat() && frame.ErrorCode == 0 { + return nil + } + + // ignore duplicate RST_STREAM frames for this stream (after checking their final offset) + if s.resetRemotely { + return nil + } + s.resetRemotely = true + s.resetRemotelyErr = streamCanceledError{ + errorCode: frame.ErrorCode, + error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), + } + s.signalRead() + s.sender.onStreamCompleted(s.streamID) + return nil +} + +func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { + s.handleStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset}) +} + +func (s *receiveStream) onClose(offset protocol.ByteCount) { + if s.canceledRead && !s.version.UsesIETFFrameFormat() { + s.sender.queueControlFrame(&wire.RstStreamFrame{ + StreamID: s.streamID, + ByteOffset: offset, + ErrorCode: 0, + }) + } +} + +func (s *receiveStream) SetReadDeadline(t time.Time) error { + s.mutex.Lock() + oldDeadline := s.readDeadline + s.readDeadline = t + s.mutex.Unlock() + // if the new deadline is before the currently set deadline, wake up Read() + if t.Before(oldDeadline) { + s.signalRead() + } + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Read unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *receiveStream) closeForShutdown(err error) { + s.mutex.Lock() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalRead() +} + +func (s *receiveStream) getWindowUpdate() protocol.ByteCount { + return s.flowController.GetWindowUpdate() +} + +// signalRead performs a non-blocking send on the readChan +func (s *receiveStream) signalRead() { + select { + case s.readChan <- struct{}{}: + default: + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/receive_stream_test.go b/vendor/github.com/lucas-clemente/quic-go/receive_stream_test.go new file mode 100644 index 0000000..a6ac9f8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/receive_stream_test.go @@ -0,0 +1,648 @@ +package quic + +import ( + "errors" + "io" + "runtime" + "time" + + "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/mocks" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +var _ = Describe("Receive Stream", func() { + const streamID protocol.StreamID = 1337 + + var ( + str *receiveStream + strWithTimeout io.Reader // str wrapped with gbytes.TimeoutReader + mockFC *mocks.MockStreamFlowController + mockSender *MockStreamSender + ) + + BeforeEach(func() { + mockSender = NewMockStreamSender(mockCtrl) + mockFC = mocks.NewMockStreamFlowController(mockCtrl) + str = newReceiveStream(streamID, mockSender, mockFC) + + timeout := scaleDuration(250 * time.Millisecond) + strWithTimeout = gbytes.TimeoutReader(str, timeout) + }) + + It("gets stream id", func() { + Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) + }) + + Context("reading", func() { + It("reads a single STREAM frame", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + mockFC.EXPECT().HasWindowUpdate() + frame := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("reads a single STREAM frame in multiple goes", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().HasWindowUpdate().Times(2) + frame := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 2) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(2)) + Expect(b).To(Equal([]byte{0xDE, 0xAD})) + n, err = strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(2)) + Expect(b).To(Equal([]byte{0xBE, 0xEF})) + }) + + It("reads all data available", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + frame2 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00})) + }) + + It("assembles multiple STREAM frames", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + frame2 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("waits until data is available", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().HasWindowUpdate() + go func() { + defer GinkgoRecover() + frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} + time.Sleep(10 * time.Millisecond) + err := str.handleStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + }() + b := make([]byte, 2) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(2)) + }) + + It("handles STREAM frames in wrong order", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) + frame1 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + frame2 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("ignores duplicate STREAM frames", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + frame2 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0x13, 0x37}, + } + frame3 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame3) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("doesn't rejects a STREAM frames with an overlapping data range", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + mockFC.EXPECT().HasWindowUpdate().Times(2) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte("foob"), + } + frame2 := wire.StreamFrame{ + Offset: 2, + Data: []byte("obar"), + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b).To(Equal([]byte("foobar"))) + }) + + It("passes on errors from the streamFrameSorter", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), false) + err := str.handleStreamFrame(&wire.StreamFrame{StreamID: streamID}) // STREAM frame without data + Expect(err).To(MatchError(errEmptyStreamData)) + }) + + It("calls the onHasWindowUpdate callback, when the a MAX_STREAM_DATA should be sent", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + mockFC.EXPECT().HasWindowUpdate().Return(true) + mockSender.EXPECT().onHasWindowUpdate(streamID) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte("foobar"), + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 6) + _, err = strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("deadlines", func() { + It("the deadline error has the right net.Error properties", func() { + Expect(errDeadline.Temporary()).To(BeTrue()) + Expect(errDeadline.Timeout()).To(BeTrue()) + Expect(errDeadline).To(MatchError("deadline exceeded")) + }) + + It("returns an error when Read is called after the deadline", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() + f := &wire.StreamFrame{Data: []byte("foobar")} + err := str.handleStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + str.SetReadDeadline(time.Now().Add(-time.Second)) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + + It("unblocks after the deadline", func() { + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetReadDeadline(deadline) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(10*time.Millisecond))) + }) + + It("doesn't unblock if the deadline is changed before the first one expires", func() { + deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) + str.SetReadDeadline(deadline1) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetReadDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline1)) + }() + runtime.Gosched() + b := make([]byte, 10) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) + }) + + It("unblocks earlier, when a new deadline is set", func() { + deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(10 * time.Millisecond)) + str.SetReadDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline2)) + }() + str.SetReadDeadline(deadline1) + runtime.Gosched() + b := make([]byte, 10) + _, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond))) + }) + }) + + Context("closing", func() { + Context("with FIN bit", func() { + It("returns EOFs", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + mockFC.EXPECT().HasWindowUpdate() + str.handleStreamFrame(&wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + FinBit: true, + }) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + n, err = strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + + It("handles out-of-order frames", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) + frame1 := wire.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + FinBit: true, + } + frame2 := wire.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.handleStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + n, err = strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + + It("returns EOFs with partial read", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().HasWindowUpdate() + err := str.handleStreamFrame(&wire.StreamFrame{ + Offset: 0, + Data: []byte{0xde, 0xad}, + FinBit: true, + }) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(2)) + Expect(b[:n]).To(Equal([]byte{0xde, 0xad})) + }) + + It("handles immediate FINs", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) + mockFC.EXPECT().HasWindowUpdate() + err := str.handleStreamFrame(&wire.StreamFrame{ + Offset: 0, + FinBit: true, + }) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 4) + n, err := strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + }) + + It("closes when CloseRemote is called", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) + mockFC.EXPECT().HasWindowUpdate() + str.CloseRemote(0) + mockSender.EXPECT().onStreamCompleted(streamID) + b := make([]byte, 8) + n, err := strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + }) + + Context("closing for shutdown", func() { + testErr := errors.New("test error") + + It("immediately returns all reads", func() { + done := make(chan struct{}) + b := make([]byte, 4) + go func() { + defer GinkgoRecover() + n, err := strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + str.closeForShutdown(testErr) + Eventually(done).Should(BeClosed()) + }) + + It("errors for all following reads", func() { + str.closeForShutdown(testErr) + b := make([]byte, 1) + n, err := strWithTimeout.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + }) + }) + }) + + Context("stream cancelations", func() { + Context("canceling read", func() { + It("unblocks Read", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + err := str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't allow further calls to Read", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + err := str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + _, err = strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) + }) + + It("does nothing when CancelRead is called twice", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + err := str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + err = str.CancelRead(2345) + Expect(err).ToNot(HaveOccurred()) + _, err = strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) + }) + + It("doesn't send a RST_STREAM frame, if the FIN was already read", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + mockFC.EXPECT().HasWindowUpdate() + // no calls to mockSender.queueControlFrame + err := str.handleStreamFrame(&wire.StreamFrame{ + StreamID: streamID, + Data: []byte("foobar"), + FinBit: true, + }) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) + _, err = strWithTimeout.Read(make([]byte, 100)) + Expect(err).To(MatchError(io.EOF)) + err = str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + }) + + It("queues a STOP_SENDING frame, for IETF QUIC", func() { + mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 1234, + }) + err := str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't queue a STOP_SENDING frame, for gQUIC", func() { + + }) + }) + + Context("receiving RST_STREAM frames", func() { + rst := &wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 42, + ErrorCode: 1234, + } + + It("unblocks Read", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError("Stream 1337 was reset with error code 1234")) + Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) + Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) + Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234))) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + mockSender.EXPECT().onStreamCompleted(streamID) + str.handleRstStreamFrame(rst) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't allow further calls to Read", func() { + mockSender.EXPECT().onStreamCompleted(streamID) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true) + err := str.handleRstStreamFrame(rst) + Expect(err).ToNot(HaveOccurred()) + _, err = strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError("Stream 1337 was reset with error code 1234")) + Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) + Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) + Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234))) + }) + + It("errors when receiving a RST_STREAM with an inconsistent offset", func() { + testErr := errors.New("already received a different final offset before") + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Return(testErr) + err := str.handleRstStreamFrame(rst) + Expect(err).To(MatchError(testErr)) + }) + + It("ignores duplicate RST_STREAM frames", func() { + mockSender.EXPECT().onStreamCompleted(streamID) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2) + err := str.handleRstStreamFrame(rst) + Expect(err).ToNot(HaveOccurred()) + err = str.handleRstStreamFrame(rst) + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't do anyting when it was closed for shutdown", func() { + str.closeForShutdown(nil) + err := str.handleRstStreamFrame(rst) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("for gQUIC", func() { + BeforeEach(func() { + str.version = versionGQUICFrames + }) + + It("unblocks Read when receiving a RST_STREAM frame with non-zero error code", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true) + readReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Read([]byte{0}) + Expect(err).To(MatchError("Stream 1337 was reset with error code 1234")) + Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) + Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) + Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234))) + close(readReturned) + }() + Consistently(readReturned).ShouldNot(BeClosed()) + mockSender.EXPECT().onStreamCompleted(streamID) + err := str.handleRstStreamFrame(rst) + Expect(err).ToNot(HaveOccurred()) + Eventually(readReturned).Should(BeClosed()) + }) + + It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2) + gomock.InOrder( + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)), + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)), + mockSender.EXPECT().onStreamCompleted(streamID), + ) + mockFC.EXPECT().HasWindowUpdate().Times(2) + readReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + n, err := strWithTimeout.Read(make([]byte, 4)) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + n, err = strWithTimeout.Read(make([]byte, 4)) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(2)) + close(readReturned) + }() + Consistently(readReturned).ShouldNot(BeClosed()) + err := str.handleStreamFrame(&wire.StreamFrame{ + StreamID: streamID, + Data: []byte("foobar"), + FinBit: true, + }) + Expect(err).ToNot(HaveOccurred()) + err = str.handleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 6, + ErrorCode: 0, + }) + Expect(err).ToNot(HaveOccurred()) + Eventually(readReturned).Should(BeClosed()) + }) + }) + }) + }) + + Context("flow control", func() { + It("errors when a STREAM frame causes a flow control violation", func() { + testErr := errors.New("flow control violation") + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr) + frame := wire.StreamFrame{ + Offset: 2, + Data: []byte("foobar"), + } + err := str.handleStreamFrame(&frame) + Expect(err).To(MatchError(testErr)) + }) + + It("gets a window update", func() { + mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x100)) + Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100))) + }) + }) +}) diff --git a/vendor/github.com/lucas-clemente/quic-go/send_stream.go b/vendor/github.com/lucas-clemente/quic-go/send_stream.go new file mode 100644 index 0000000..075d1bd --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/send_stream.go @@ -0,0 +1,313 @@ +package quic + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type sendStreamI interface { + SendStream + handleStopSendingFrame(*wire.StopSendingFrame) + popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) + closeForShutdown(error) + handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) +} + +type sendStream struct { + mutex sync.Mutex + + ctx context.Context + ctxCancel context.CancelFunc + + streamID protocol.StreamID + sender streamSender + + writeOffset protocol.ByteCount + + cancelWriteErr error + closeForShutdownErr error + + closedForShutdown bool // set when CloseForShutdown() is called + finishedWriting bool // set once Close() is called + canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received + finSent bool // set when a STREAM_FRAME with FIN bit has b + + dataForWriting []byte + writeChan chan struct{} + writeDeadline time.Time + + flowController flowcontrol.StreamFlowController + + version protocol.VersionNumber +} + +var _ SendStream = &sendStream{} +var _ sendStreamI = &sendStream{} + +func newSendStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *sendStream { + s := &sendStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + writeChan: make(chan struct{}, 1), + version: version, + } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + return s +} + +func (s *sendStream) StreamID() protocol.StreamID { + return s.streamID // same for receiveStream and sendStream +} + +func (s *sendStream) Write(p []byte) (int, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finishedWriting { + return 0, fmt.Errorf("write on closed stream %d", s.streamID) + } + if s.canceledWrite { + return 0, s.cancelWriteErr + } + if s.closeForShutdownErr != nil { + return 0, s.closeForShutdownErr + } + if !s.writeDeadline.IsZero() && !time.Now().Before(s.writeDeadline) { + return 0, errDeadline + } + if len(p) == 0 { + return 0, nil + } + + s.dataForWriting = make([]byte, len(p)) + copy(s.dataForWriting, p) + s.sender.onHasStreamData(s.streamID) + + var bytesWritten int + var err error + for { + bytesWritten = len(p) - len(s.dataForWriting) + deadline := s.writeDeadline + if !deadline.IsZero() && !time.Now().Before(deadline) { + s.dataForWriting = nil + err = errDeadline + break + } + if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { + break + } + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.writeChan + } else { + select { + case <-s.writeChan: + case <-time.After(deadline.Sub(time.Now())): + } + } + s.mutex.Lock() + } + + if s.closeForShutdownErr != nil { + err = s.closeForShutdownErr + } else if s.cancelWriteErr != nil { + err = s.cancelWriteErr + } + return bytesWritten, err +} + +// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream +// maxBytes is the maximum length this frame (including frame header) will have. +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.closeForShutdownErr != nil { + return nil, false + } + + frame := &wire.StreamFrame{ + StreamID: s.streamID, + Offset: s.writeOffset, + DataLenPresent: true, + } + frameLen := frame.MinLength(s.version) + if frameLen >= maxBytes { // a STREAM frame must have at least one byte of data + return nil, s.dataForWriting != nil + } + frame.Data, frame.FinBit = s.getDataForWriting(maxBytes - frameLen) + if len(frame.Data) == 0 && !frame.FinBit { + // this can happen if: + // - popStreamFrame is called but there's no data for writing + // - there's data for writing, but the stream is stream-level flow control blocked + // - there's data for writing, but the stream is connection-level flow control blocked + if s.dataForWriting == nil { + return nil, false + } + isBlocked, _ := s.flowController.IsBlocked() + return nil, !isBlocked + } + if frame.FinBit { + s.finSent = true + s.sender.onStreamCompleted(s.streamID) + } else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream + if isBlocked, offset := s.flowController.IsBlocked(); isBlocked { + s.sender.queueControlFrame(&wire.StreamBlockedFrame{ + StreamID: s.streamID, + Offset: offset, + }) + return frame, false + } + } + return frame, s.dataForWriting != nil +} + +func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { + if s.dataForWriting == nil { + return nil, s.finishedWriting && !s.finSent + } + + // TODO(#657): Flow control for the crypto stream + if s.streamID != s.version.CryptoStreamID() { + maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) + } + if maxBytes == 0 { + return nil, false + } + + var ret []byte + if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { + ret = s.dataForWriting[:maxBytes] + s.dataForWriting = s.dataForWriting[maxBytes:] + } else { + ret = s.dataForWriting + s.dataForWriting = nil + s.signalWrite() + } + s.writeOffset += protocol.ByteCount(len(ret)) + s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) + return ret, s.finishedWriting && s.dataForWriting == nil && !s.finSent +} + +func (s *sendStream) Close() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.canceledWrite { + return fmt.Errorf("Close called for canceled stream %d", s.streamID) + } + s.finishedWriting = true + s.sender.onHasStreamData(s.streamID) // need to send the FIN + s.ctxCancel() + return nil +} + +func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) +} + +// must be called after locking the mutex +func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) error { + if s.canceledWrite { + return nil + } + if s.finishedWriting { + return fmt.Errorf("CancelWrite for closed stream %d", s.streamID) + } + s.canceledWrite = true + s.cancelWriteErr = writeErr + s.signalWrite() + s.sender.queueControlFrame(&wire.RstStreamFrame{ + StreamID: s.streamID, + ByteOffset: s.writeOffset, + ErrorCode: errorCode, + }) + // TODO(#991): cancel retransmissions for this stream + s.ctxCancel() + s.sender.onStreamCompleted(s.streamID) + return nil +} + +func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { + s.mutex.Lock() + defer s.mutex.Unlock() + s.handleStopSendingFrameImpl(frame) +} + +func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { + s.flowController.UpdateSendWindow(frame.ByteOffset) + s.mutex.Lock() + if s.dataForWriting != nil { + s.sender.onHasStreamData(s.streamID) + } + s.mutex.Unlock() +} + +// must be called after locking the mutex +func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) { + writeErr := streamCanceledError{ + errorCode: frame.ErrorCode, + error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), + } + errorCode := errorCodeStopping + if !s.version.UsesIETFFrameFormat() { + errorCode = errorCodeStoppingGQUIC + } + s.cancelWriteImpl(errorCode, writeErr) +} + +func (s *sendStream) Context() context.Context { + return s.ctx +} + +func (s *sendStream) SetWriteDeadline(t time.Time) error { + s.mutex.Lock() + oldDeadline := s.writeDeadline + s.writeDeadline = t + s.mutex.Unlock() + if t.Before(oldDeadline) { + s.signalWrite() + } + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *sendStream) closeForShutdown(err error) { + s.mutex.Lock() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalWrite() + s.ctxCancel() +} + +func (s *sendStream) getWriteOffset() protocol.ByteCount { + return s.writeOffset +} + +// signalWrite performs a non-blocking send on the writeChan +func (s *sendStream) signalWrite() { + select { + case s.writeChan <- struct{}{}: + default: + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/send_stream_test.go b/vendor/github.com/lucas-clemente/quic-go/send_stream_test.go new file mode 100644 index 0000000..d2718e1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/send_stream_test.go @@ -0,0 +1,636 @@ +package quic + +import ( + "bytes" + "errors" + "io" + "runtime" + "time" + + "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/mocks" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +var _ = Describe("Send Stream", func() { + const streamID protocol.StreamID = 1337 + + var ( + str *sendStream + strWithTimeout io.Writer // str wrapped with gbytes.TimeoutWriter + mockFC *mocks.MockStreamFlowController + mockSender *MockStreamSender + ) + + BeforeEach(func() { + mockSender = NewMockStreamSender(mockCtrl) + mockFC = mocks.NewMockStreamFlowController(mockCtrl) + str = newSendStream(streamID, mockSender, mockFC, protocol.VersionWhatever) + + timeout := scaleDuration(250 * time.Millisecond) + strWithTimeout = gbytes.TimeoutWriter(str, timeout) + }) + + waitForWrite := func() { + EventuallyWithOffset(0, func() []byte { + str.mutex.Lock() + data := str.dataForWriting + str.mutex.Unlock() + return data + }).ShouldNot(BeEmpty()) + } + + It("gets stream id", func() { + Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) + }) + + Context("writing", func() { + It("writes and gets all data at once", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + mockFC.EXPECT().IsBlocked() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + close(done) + }() + waitForWrite() + f, _ := str.popStreamFrame(1000) + Expect(f.Data).To(Equal([]byte("foobar"))) + Expect(f.FinBit).To(BeFalse()) + Expect(f.Offset).To(BeZero()) + Expect(f.DataLenPresent).To(BeTrue()) + Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) + Expect(str.dataForWriting).To(BeNil()) + Eventually(done).Should(BeClosed()) + }) + + It("writes and gets data in two turns", func() { + mockSender.EXPECT().onHasStreamData(streamID) + frameHeaderLen := protocol.ByteCount(4) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any() /* protocol.ByteCount(3)*/).Times(2) + mockFC.EXPECT().IsBlocked().Times(2) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + close(done) + }() + waitForWrite() + f, _ := str.popStreamFrame(3 + frameHeaderLen) + Expect(f.Data).To(Equal([]byte("foo"))) + Expect(f.FinBit).To(BeFalse()) + Expect(f.Offset).To(BeZero()) + Expect(f.DataLenPresent).To(BeTrue()) + f, _ = str.popStreamFrame(100) + Expect(f.Data).To(Equal([]byte("bar"))) + Expect(f.FinBit).To(BeFalse()) + Expect(f.Offset).To(Equal(protocol.ByteCount(3))) + Expect(f.DataLenPresent).To(BeTrue()) + Expect(str.popStreamFrame(1000)).To(BeNil()) + Eventually(done).Should(BeClosed()) + }) + + It("popStreamFrame returns nil if no data is available", func() { + frame, hasMoreData := str.popStreamFrame(1000) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + }) + + It("says if it has more data for writing", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) + mockFC.EXPECT().IsBlocked().Times(2) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(100)) + close(done) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(hasMoreData).To(BeTrue()) + frame, hasMoreData = str.popStreamFrame(1000) + Expect(frame).ToNot(BeNil()) + Expect(hasMoreData).To(BeFalse()) + frame, _ = str.popStreamFrame(1000) + Expect(frame).To(BeNil()) + Eventually(done).Should(BeClosed()) + }) + + It("copies the slice while writing", func() { + mockSender.EXPECT().onHasStreamData(streamID) + frameHeaderSize := protocol.ByteCount(4) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) + mockFC.EXPECT().IsBlocked().Times(2) + s := []byte("foo") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + n, err := strWithTimeout.Write(s) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + close(done) + }() + waitForWrite() + frame, _ := str.popStreamFrame(frameHeaderSize + 1) + Expect(frame.Data).To(Equal([]byte("f"))) + s[1] = 'e' + f, _ := str.popStreamFrame(100) + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(Equal([]byte("oo"))) + Eventually(done).Should(BeClosed()) + }) + + It("returns when given a nil input", func() { + n, err := strWithTimeout.Write(nil) + Expect(n).To(BeZero()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("returns when given an empty slice", func() { + n, err := strWithTimeout.Write([]byte("")) + Expect(n).To(BeZero()) + Expect(err).ToNot(HaveOccurred()) + }) + + It("cancels the context when Close is called", func() { + mockSender.EXPECT().onHasStreamData(streamID) + Expect(str.Context().Done()).ToNot(BeClosed()) + str.Close() + Expect(str.Context().Done()).To(BeClosed()) + }) + + Context("flow control blocking", func() { + It("returns nil when it is blocked", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0)) + mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10)) + mockSender.EXPECT().onHasStreamData(streamID) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + f, hasMoreData := str.popStreamFrame(1000) + Expect(f).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + // make the Write go routine return + str.closeForShutdown(nil) + Eventually(done).Should(BeClosed()) + }) + + It("queues a BLOCKED frame if the stream is flow control blocked", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().queueControlFrame(&wire.StreamBlockedFrame{ + StreamID: streamID, + Offset: 10, + }) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + // don't use offset 6 here, to make sure the BLOCKED frame contains the number returned by the flow controller + mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + f, hasMoreData := str.popStreamFrame(1000) + Expect(f).ToNot(BeNil()) + Expect(hasMoreData).To(BeFalse()) + Eventually(done).Should(BeClosed()) + }) + + It("says that it doesn't have any more data, when it is flow control blocked", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write(bytes.Repeat([]byte{0}, 100)) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + f, hasMoreData := str.popStreamFrame(50) + Expect(f).ToNot(BeNil()) + Expect(hasMoreData).To(BeFalse()) + // make the Write go routine return + str.closeForShutdown(nil) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() { + mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close + mockSender.EXPECT().onStreamCompleted(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + // don't EXPECT a call to mockFC.IsBlocked + // don't EXPECT a call to mockSender.queueControlFrame + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + Expect(str.Close()).To(Succeed()) + f, hasMoreData := str.popStreamFrame(1000) + Expect(hasMoreData).To(BeFalse()) + Expect(f).ToNot(BeNil()) + Expect(f.FinBit).To(BeTrue()) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("deadlines", func() { + It("returns an error when Write is called after the deadline", func() { + str.SetWriteDeadline(time.Now().Add(-time.Second)) + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + + It("unblocks after the deadline", func() { + mockSender.EXPECT().onHasStreamData(streamID) + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetWriteDeadline(deadline) + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) + }) + + It("returns the number of bytes written, when the deadline expires", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()) + mockFC.EXPECT().IsBlocked() + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetWriteDeadline(deadline) + var n int + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + Expect(err).To(MatchError(errDeadline)) + Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) + close(writeReturned) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(frame).ToNot(BeNil()) + Expect(hasMoreData).To(BeTrue()) + Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) + Expect(n).To(BeEquivalentTo(frame.DataLen())) + }) + + It("doesn't pop any data after the deadline expired", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()) + mockFC.EXPECT().IsBlocked() + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetWriteDeadline(deadline) + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + Expect(err).To(MatchError(errDeadline)) + close(writeReturned) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(frame).ToNot(BeNil()) + Expect(hasMoreData).To(BeTrue()) + Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) + frame, hasMoreData = str.popStreamFrame(50) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + }) + + It("doesn't unblock if the deadline is changed before the first one expires", func() { + mockSender.EXPECT().onHasStreamData(streamID) + deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) + str.SetWriteDeadline(deadline1) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetWriteDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline1)) + close(done) + }() + runtime.Gosched() + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) + Eventually(done).Should(BeClosed()) + }) + + It("unblocks earlier, when a new deadline is set", func() { + mockSender.EXPECT().onHasStreamData(streamID) + deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(10 * time.Millisecond)) + str.SetWriteDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline2)) + close(done) + }() + str.SetWriteDeadline(deadline1) + runtime.Gosched() + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("closing", func() { + It("doesn't allow writes after it has been closed", func() { + mockSender.EXPECT().onHasStreamData(streamID) + str.Close() + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError("write on closed stream 1337")) + }) + + It("allows FIN", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().onStreamCompleted(streamID) + str.Close() + f, hasMoreData := str.popStreamFrame(1000) + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(BeEmpty()) + Expect(f.FinBit).To(BeTrue()) + Expect(hasMoreData).To(BeFalse()) + }) + + It("doesn't send a FIN when there's still data", func() { + mockSender.EXPECT().onHasStreamData(streamID) + frameHeaderLen := protocol.ByteCount(4) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) + mockFC.EXPECT().IsBlocked() + str.dataForWriting = []byte("foobar") + Expect(str.Close()).To(Succeed()) + f, _ := str.popStreamFrame(3 + frameHeaderLen) + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(Equal([]byte("foo"))) + Expect(f.FinBit).To(BeFalse()) + mockSender.EXPECT().onStreamCompleted(streamID) + f, _ = str.popStreamFrame(100) + Expect(f.Data).To(Equal([]byte("bar"))) + Expect(f.FinBit).To(BeTrue()) + }) + + It("doesn't allow FIN after it is closed for shutdown", func() { + str.closeForShutdown(errors.New("test")) + f, hasMoreData := str.popStreamFrame(1000) + Expect(f).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + }) + + It("doesn't allow FIN twice", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().onStreamCompleted(streamID) + str.Close() + f, _ := str.popStreamFrame(1000) + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(BeEmpty()) + Expect(f.FinBit).To(BeTrue()) + f, hasMoreData := str.popStreamFrame(1000) + Expect(f).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + }) + }) + + Context("closing for shutdown", func() { + testErr := errors.New("test") + + It("returns errors when the stream is cancelled", func() { + str.closeForShutdown(testErr) + n, err := strWithTimeout.Write([]byte("foo")) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + }) + + It("doesn't get data for writing if an error occurred", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + mockFC.EXPECT().IsBlocked() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 500)) + Expect(err).To(MatchError(testErr)) + close(done) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) // get a STREAM frame containing some data, but not all + Expect(frame).ToNot(BeNil()) + Expect(hasMoreData).To(BeTrue()) + str.closeForShutdown(testErr) + frame, hasMoreData = str.popStreamFrame(1000) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + Eventually(done).Should(BeClosed()) + }) + + It("cancels the context", func() { + Expect(str.Context().Done()).ToNot(BeClosed()) + str.closeForShutdown(testErr) + Expect(str.Context().Done()).To(BeClosed()) + }) + }) + }) + + Context("handling MAX_STREAM_DATA frames", func() { + It("informs the flow controller", func() { + mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337)) + str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: streamID, + ByteOffset: 0x1337, + }) + }) + + It("says when it has data for sending", func() { + mockFC.EXPECT().UpdateSendWindow(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for Write, once for the MAX_STREAM_DATA frame + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: streamID, + ByteOffset: 42, + }) + // make sure the Write go routine returns + str.closeForShutdown(nil) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("stream cancelations", func() { + Context("canceling writing", func() { + It("queues a RST_STREAM frame", func() { + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 1234, + ErrorCode: 9876, + }) + mockSender.EXPECT().onStreamCompleted(streamID) + str.writeOffset = 1234 + err := str.CancelWrite(9876) + Expect(err).ToNot(HaveOccurred()) + }) + + It("unblocks Write", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().onStreamCompleted(streamID) + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + mockFC.EXPECT().IsBlocked() + writeReturned := make(chan struct{}) + var n int + go func() { + defer GinkgoRecover() + var err error + n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) + close(writeReturned) + }() + waitForWrite() + frame, _ := str.popStreamFrame(50) + Expect(frame).ToNot(BeNil()) + err := str.CancelWrite(1234) + Expect(err).ToNot(HaveOccurred()) + Eventually(writeReturned).Should(BeClosed()) + Expect(n).To(BeEquivalentTo(frame.DataLen())) + }) + + It("cancels the context", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(streamID) + Expect(str.Context().Done()).ToNot(BeClosed()) + str.CancelWrite(1234) + Expect(str.Context().Done()).To(BeClosed()) + }) + + It("doesn't allow further calls to Write", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(streamID) + err := str.CancelWrite(1234) + Expect(err).ToNot(HaveOccurred()) + _, err = strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) + }) + + It("only cancels once", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(streamID) + err := str.CancelWrite(1234) + Expect(err).ToNot(HaveOccurred()) + err = str.CancelWrite(4321) + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't cancel when the stream was already closed", func() { + mockSender.EXPECT().onHasStreamData(streamID) + err := str.Close() + Expect(err).ToNot(HaveOccurred()) + err = str.CancelWrite(123) + Expect(err).To(MatchError("CancelWrite for closed stream 1337")) + }) + }) + + Context("receiving STOP_SENDING frames", func() { + It("queues a RST_STREAM frames with error code Stopping", func() { + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ErrorCode: errorCodeStopping, + }) + mockSender.EXPECT().onStreamCompleted(streamID) + str.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 101, + }) + }) + + It("unblocks Write", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().queueControlFrame(gomock.Any()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError("Stream 1337 was reset with error code 123")) + Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) + Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) + Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123))) + close(done) + }() + waitForWrite() + mockSender.EXPECT().onStreamCompleted(streamID) + str.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 123, + }) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't allow further calls to Write", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(streamID) + str.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 123, + }) + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError("Stream 1337 was reset with error code 123")) + Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) + Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) + Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123))) + }) + }) + }) +}) diff --git a/vendor/github.com/lucas-clemente/quic-go/server.go b/vendor/github.com/lucas-clemente/quic-go/server.go index 4ca25a2..33d5883 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/server.go @@ -19,8 +19,8 @@ import ( // packetHandler handles packets type packetHandler interface { Session - getCryptoStream() cryptoStream - handshakeStatus() <-chan handshakeEvent + getCryptoStream() cryptoStreamI + handshakeStatus() <-chan error handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber run() error @@ -40,15 +40,17 @@ type server struct { certChain crypto.CertChain scfg *handshake.ServerConfig - sessions map[protocol.ConnectionID]packetHandler - sessionsMutex sync.RWMutex - deleteClosedSessionsAfter time.Duration + sessionsMutex sync.RWMutex + sessions map[protocol.ConnectionID]packetHandler + closed bool serverError error sessionQueue chan Session errorChan chan struct{} - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error) + // set as members, so they can be set in the tests + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error) + deleteClosedSessionsAfter time.Duration } var _ Listener = &server{} @@ -240,6 +242,12 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { s.sessionsMutex.Lock() + if s.closed { + s.sessionsMutex.Unlock() + return nil + } + s.closed = true + var wg sync.WaitGroup for _, session := range s.sessions { if session != nil { @@ -254,10 +262,9 @@ func (s *server) Close() error { s.sessionsMutex.Unlock() wg.Wait() - if s.conn == nil { - return nil - } - return s.conn.Close() + err := s.conn.Close() + <-s.errorChan // wait for serve() to return + return err } // Addr returns the server's network address @@ -384,14 +391,8 @@ func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.C }() go func() { - for { - ev := <-session.handshakeStatus() - if ev.err != nil { - return - } - if ev.encLevel == protocol.EncryptionForwardSecure { - break - } + if err := <-session.handshakeStatus(); err != nil { + return } s.sessionQueue <- session }() diff --git a/vendor/github.com/lucas-clemente/quic-go/server_test.go b/vendor/github.com/lucas-clemente/quic-go/server_test.go index 7e95acc..b49c74a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/server_test.go @@ -22,14 +22,13 @@ import ( ) type mockSession struct { - connectionID protocol.ConnectionID - packetCount int - closed bool - closeReason error - closedRemote bool - stopRunLoop chan struct{} // run returns as soon as this channel receives a value - handshakeChan chan handshakeEvent - handshakeComplete chan error // for WaitUntilHandshakeComplete + connectionID protocol.ConnectionID + packetCount int + closed bool + closeReason error + closedRemote bool + stopRunLoop chan struct{} // run returns as soon as this channel receives a value + handshakeChan chan error } func (s *mockSession) handlePacket(*receivedPacket) { @@ -40,9 +39,6 @@ func (s *mockSession) run() error { <-s.stopRunLoop return s.closeReason } -func (s *mockSession) WaitUntilHandshakeComplete() error { - return <-s.handshakeComplete -} func (s *mockSession) Close(e error) error { if s.closed { return nil @@ -59,19 +55,19 @@ func (s *mockSession) closeRemote(e error) { close(s.stopRunLoop) } func (s *mockSession) OpenStream() (Stream, error) { - return &stream{streamID: 1337}, nil + return &stream{}, nil } -func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } -func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } -func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } -func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } -func (*mockSession) Context() context.Context { panic("not implemented") } -func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } -func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } -func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") } +func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } +func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } +func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } +func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } +func (*mockSession) Context() context.Context { panic("not implemented") } +func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") } +func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } +func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan } +func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") } var _ Session = &mockSession{} -var _ NonFWSession = &mockSession{} func newMockSession( _ connection, @@ -82,10 +78,9 @@ func newMockSession( _ *Config, ) (packetHandler, error) { s := mockSession{ - connectionID: connectionID, - handshakeChan: make(chan handshakeEvent), - handshakeComplete: make(chan error), - stopRunLoop: make(chan struct{}), + connectionID: connectionID, + handshakeChan: make(chan error), + stopRunLoop: make(chan struct{}), } return &s, nil } @@ -155,9 +150,8 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) sess := serv.sessions[connID].(*mockSession) - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} Consistently(func() Session { return acceptedSess }).Should(BeNil()) - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionForwardSecure} + close(sess.handshakeChan) Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) close(done) }, 0.5) @@ -173,7 +167,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) sess := serv.sessions[connID].(*mockSession) - sess.handshakeChan <- handshakeEvent{err: errors.New("handshake failed")} + sess.handshakeChan <- errors.New("handshake failed") Consistently(func() bool { return accepted }).Should(BeFalse()) close(done) }) @@ -222,6 +216,7 @@ var _ = Describe("Server", func() { }) It("closes sessions and the connection when Close is called", func() { + go serv.serve() session, _ := newMockSession(nil, 0, 0, nil, nil, nil) serv.sessions[1] = session err := serv.Close() diff --git a/vendor/github.com/lucas-clemente/quic-go/server_tls.go b/vendor/github.com/lucas-clemente/quic-go/server_tls.go index a40a8f5..c4a5fb1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server_tls.go +++ b/vendor/github.com/lucas-clemente/quic-go/server_tls.go @@ -12,6 +12,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" ) type nullAEAD struct { @@ -98,6 +99,26 @@ func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.V return tls, extHandler.GetPeerParams(), nil } +func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error { + ccf := &wire.ConnectionCloseFrame{ + ErrorCode: qerr.HandshakeFailed, + ReasonPhrase: closeErr.Error(), + } + replyHdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + ConnectionID: clientHdr.ConnectionID, // echo the client's connection ID + PacketNumber: 1, // random packet number + Version: clientHdr.Version, + } + data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer) + if err != nil { + return err + } + _, err = s.conn.WriteTo(data, remoteAddr) + return err +} + func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) { if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize { return nil, errors.New("dropping too small Initial packet") @@ -110,19 +131,30 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat } // unpack packet and check stream frame contents - version := hdr.Version - aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version) + aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, hdr.Version) if err != nil { return nil, err } - frame, err := unpackInitialPacket(aead, hdr, data, version) + frame, err := unpackInitialPacket(aead, hdr, data, hdr.Version) if err != nil { utils.Debugf("Error unpacking initial packet: %s", err) return nil, nil } + sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) + if err != nil { + if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil { + utils.Debugf("Error sending CONNECTION_CLOSE: ", ccerr) + } + return nil, err + } + return sess, nil +} + +func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) { + version := hdr.Version bc := handshake.NewCryptoStreamConn(remoteAddr) bc.AddDataForReading(frame.Data) - tls, paramsChan, err := s.newMintConn(bc, hdr.Version) + tls, paramsChan, err := s.newMintConn(bc, version) if err != nil { return nil, err } @@ -176,7 +208,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat return nil, err } cs := sess.getCryptoStream() - cs.SetReadOffset(frame.DataLen()) + cs.setReadOffset(frame.DataLen()) bc.SetStream(cs) return sess, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/server_tls_test.go b/vendor/github.com/lucas-clemente/quic-go/server_tls_test.go index 9c8d00e..7cead5c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server_tls_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/server_tls_test.go @@ -4,15 +4,16 @@ import ( "bytes" "io" - "github.com/lucas-clemente/quic-go/internal/mocks" - "github.com/lucas-clemente/quic-go/internal/mocks/handshake" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/mocks" + "github.com/lucas-clemente/quic-go/internal/mocks/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -65,6 +66,18 @@ var _ = Describe("Stateless TLS handling", func() { return hdr, data } + unpackPacket := func(data []byte) (*wire.Header, []byte) { + r := bytes.NewReader(conn.dataWritten.Bytes()) + hdr, err := wire.ParseHeaderSentByServer(r, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + hdr.Raw = data[:len(data)-r.Len()] + aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.ConnectionID, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + payload, err := aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw) + Expect(err).ToNot(HaveOccurred()) + return hdr, payload + } + It("sends a version negotiation packet if it doesn't support the version", func() { server.HandleInitial(nil, &wire.Header{Version: 0x1337}, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) Expect(conn.dataWritten.Len()).ToNot(BeZero()) @@ -124,4 +137,20 @@ var _ = Describe("Stateless TLS handling", func() { Eventually(sessionChan).Should(Receive()) Eventually(done).Should(BeClosed()) }) + + It("sends a CONNECTION_CLOSE, if mint returns an error", func() { + mintTLS.EXPECT().Handshake().Return(mint.AlertAccessDenied) + extHandler.EXPECT().GetPeerParams() + hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")}) + server.HandleInitial(nil, hdr, data) + // the Handshake packet is written by the session + Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty()) + // unpack the packet to check that it actually contains a CONNECTION_CLOSE + hdr, data = unpackPacket(conn.dataWritten.Bytes()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) + ccf, err := wire.ParseConnectionCloseFrame(bytes.NewReader(data), protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + Expect(ccf.ErrorCode).To(Equal(qerr.HandshakeFailed)) + Expect(ccf.ReasonPhrase).To(Equal(mint.AlertAccessDenied.String())) + }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/session.go b/vendor/github.com/lucas-clemente/quic-go/session.go index 467c992..992888d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/session.go +++ b/vendor/github.com/lucas-clemente/quic-go/session.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "errors" - "fmt" "net" "sync" "time" @@ -24,6 +23,23 @@ type unpacker interface { Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) } +type streamGetter interface { + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) +} + +type streamManager interface { + GetOrOpenStream(protocol.StreamID) (streamI, error) + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + OpenStream() (Stream, error) + OpenStreamSync() (Stream, error) + AcceptStream() (Stream, error) + DeleteStream(protocol.StreamID) error + UpdateLimits(*handshake.TransportParameters) + CloseWithError(error) +} + type receivedPacket struct { remoteAddr net.Addr header *wire.Header @@ -36,11 +52,6 @@ var ( newCryptoSetupClient = handshake.NewCryptoSetupClient ) -type handshakeEvent struct { - encLevel protocol.EncryptionLevel - err error -} - type closeError struct { err error remote bool @@ -55,16 +66,16 @@ type session struct { conn connection - streamsMap *streamsMap - cryptoStream cryptoStream + streamsMap streamManager + cryptoStream cryptoStreamI rttStats *congestion.RTTStats sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler streamFramer *streamFramer - - connFlowController flowcontrol.ConnectionFlowController + windowUpdateQueue *windowUpdateQueue + connFlowController flowcontrol.ConnectionFlowController unpacker unpacker packer *packetPacker @@ -87,17 +98,14 @@ type session struct { // this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them paramsChan <-chan handshake.TransportParameters - // this channel is passed to the CryptoSetup and receives the current encryption level - // it is closed as soon as the handshake is complete - aeadChanged <-chan protocol.EncryptionLevel + // the handshakeEvent channel is passed to the CryptoSetup. + // It receives when it makes sense to try decrypting undecryptable packets. + handshakeEvent <-chan struct{} + // handshakeChan is returned by handshakeStatus. + // It receives any error that might occur during the handshake. + // It is closed when the handshake is complete. + handshakeChan chan error handshakeComplete bool - // will be closed as soon as the handshake completes, and receive any error that might occur until then - // it is used to block WaitUntilHandshakeComplete() - handshakeCompleteChan chan error - // handshakeChan receives handshake events and is closed as soon the handshake completes - // the receiving end of this channel is passed to the creator of the session - // it receives at most 3 handshake events: 2 when the encryption level changes, and one error - handshakeChan chan handshakeEvent lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire @@ -116,6 +124,7 @@ type session struct { } var _ Session = &session{} +var _ streamSender = &session{} // newSession makes a new session func newSession( @@ -127,15 +136,15 @@ func newSession( config *Config, ) (packetHandler, error) { paramsChan := make(chan handshake.TransportParameters) - aeadChanged := make(chan protocol.EncryptionLevel, 2) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - connectionID: connectionID, - perspective: protocol.PerspectiveServer, - version: v, - config: config, - aeadChanged: aeadChanged, - paramsChan: paramsChan, + conn: conn, + connectionID: connectionID, + perspective: protocol.PerspectiveServer, + version: v, + config: config, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, } s.preSetup() transportParams := &handshake.TransportParameters{ @@ -154,7 +163,7 @@ func newSession( s.config.Versions, s.config.AcceptCookie, paramsChan, - aeadChanged, + handshakeEvent, ) if err != nil { return nil, err @@ -175,15 +184,15 @@ var newClientSession = func( negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton ) (packetHandler, error) { paramsChan := make(chan handshake.TransportParameters) - aeadChanged := make(chan protocol.EncryptionLevel, 2) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - connectionID: connectionID, - perspective: protocol.PerspectiveClient, - version: v, - config: config, - aeadChanged: aeadChanged, - paramsChan: paramsChan, + conn: conn, + connectionID: connectionID, + perspective: protocol.PerspectiveClient, + version: v, + config: config, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, } s.preSetup() transportParams := &handshake.TransportParameters{ @@ -201,7 +210,7 @@ var newClientSession = func( tlsConf, transportParams, paramsChan, - aeadChanged, + handshakeEvent, initialVersion, negotiatedVersions, ) @@ -223,21 +232,21 @@ func newTLSServerSession( peerParams *handshake.TransportParameters, v protocol.VersionNumber, ) (packetHandler, error) { - aeadChanged := make(chan protocol.EncryptionLevel, 2) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - config: config, - connectionID: connectionID, - perspective: protocol.PerspectiveServer, - version: v, - aeadChanged: aeadChanged, + conn: conn, + config: config, + connectionID: connectionID, + perspective: protocol.PerspectiveServer, + version: v, + handshakeEvent: handshakeEvent, } s.preSetup() s.cryptoSetup = handshake.NewCryptoSetupTLSServer( tls, cryptoStreamConn, nullAEAD, - aeadChanged, + handshakeEvent, v, ) if err := s.postSetup(initialPacketNumber); err != nil { @@ -260,15 +269,15 @@ var newTLSClientSession = func( paramsChan <-chan handshake.TransportParameters, initialPacketNumber protocol.PacketNumber, ) (packetHandler, error) { - aeadChanged := make(chan protocol.EncryptionLevel, 2) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - config: config, - connectionID: connectionID, - perspective: protocol.PerspectiveClient, - version: v, - aeadChanged: aeadChanged, - paramsChan: paramsChan, + conn: conn, + config: config, + connectionID: connectionID, + perspective: protocol.PerspectiveClient, + version: v, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, } s.preSetup() tls.SetCryptoStream(s.cryptoStream) @@ -276,7 +285,7 @@ var newTLSClientSession = func( s.cryptoStream, s.connectionID, hostname, - aeadChanged, + handshakeEvent, tls, v, ) @@ -294,12 +303,11 @@ func (s *session) preSetup() { protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), s.rttStats, ) - s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream) + s.cryptoStream = s.newCryptoStream() } func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { - s.handshakeChan = make(chan handshakeEvent, 3) - s.handshakeCompleteChan = make(chan error, 1) + s.handshakeChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) @@ -314,9 +322,12 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) - s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) - s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController, s.version) - + if s.version.UsesTLS() { + s.streamsMap = newStreamsMap(s.newStream, s.perspective) + } else { + s.streamsMap = newStreamsMapLegacy(s.newStream, s.perspective) + } + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker(s.connectionID, initialPacketNumber, s.cryptoSetup, @@ -324,6 +335,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.perspective, s.version, ) + s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame) s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} return nil } @@ -339,7 +351,7 @@ func (s *session) run() error { }() var closeErr closeError - aeadChanged := s.aeadChanged + handshakeEvent := s.handshakeEvent runLoop: for { @@ -377,16 +389,20 @@ runLoop: putPacketBuffer(p.header.Raw) case p := <-s.paramsChan: s.processTransportParameters(&p) - case l, ok := <-aeadChanged: + case _, ok := <-handshakeEvent: if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. s.handshakeComplete = true - aeadChanged = nil // prevent this case from ever being selected again + handshakeEvent = nil // prevent this case from ever being selected again s.sentPacketHandler.SetHandshakeComplete() + if !s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient { + // In gQUIC, there's no equivalent to the Finished message in TLS + // The server knows that the handshake is complete when it receives the first forward-secure packet sent by the client. + // We need to make sure that the client actually sends such a packet. + s.packer.QueueControlFrame(&wire.PingFrame{}) + } close(s.handshakeChan) - close(s.handshakeCompleteChan) } else { s.tryDecryptingQueuedPackets() - s.handshakeChan <- handshakeEvent{encLevel: l} } } @@ -403,9 +419,26 @@ runLoop: s.keepAlivePingSent = true } - if err := s.sendPacket(); err != nil { - s.closeLocal(err) + sendingAllowed := s.sentPacketHandler.SendingAllowed() + if !sendingAllowed { // if congestion limited, at least try sending an ACK frame + if err := s.maybeSendAckOnlyPacket(); err != nil { + s.closeLocal(err) + } + } else { + // repeatedly try sending until we don't have any more data, or run out of the congestion window + for sendingAllowed { + sentPacket, err := s.sendPacket() + if err != nil { + s.closeLocal(err) + break + } + if !sentPacket { + break + } + sendingAllowed = s.sentPacketHandler.SendingAllowed() + } } + if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 { s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) } @@ -415,17 +448,12 @@ runLoop: if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout { s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } - - if err := s.streamsMap.DeleteClosedStreams(); err != nil { - s.closeLocal(err) - } } // only send the error the handshakeChan when the handshake is not completed yet // otherwise this chan will already be closed if !s.handshakeComplete { - s.handshakeCompleteChan <- closeErr.err - s.handshakeChan <- handshakeEvent{err: closeErr.err} + s.handshakeChan <- closeErr.err } s.handleCloseError(closeErr) return closeErr.err @@ -435,6 +463,10 @@ func (s *session) Context() context.Context { return s.ctx } +func (s *session) ConnectionState() ConnectionState { + return s.cryptoSetup.ConnectionState() +} + func (s *session) maybeResetTimer() { var deadline time.Time if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent { @@ -504,7 +536,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames) - if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, isRetransmittable); err != nil { + if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, p.rcvTime, isRetransmittable); err != nil { return err } @@ -524,8 +556,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) case *wire.GoawayFrame: err = errors.New("unimplemented: handling GOAWAY frames") - case *wire.StopWaitingFrame: - s.receivedPacketHandler.IgnoreBelow(frame.LeastUnacked) + case *wire.StopWaitingFrame: // ignore STOP_WAITINGs case *wire.RstStreamFrame: err = s.handleRstStreamFrame(frame) case *wire.MaxDataFrame: @@ -534,6 +565,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve err = s.handleMaxStreamDataFrame(frame) case *wire.BlockedFrame: case *wire.StreamBlockedFrame: + case *wire.StopSendingFrame: + err = s.handleStopSendingFrame(frame) case *wire.PingFrame: default: return errors.New("Session BUG: unexpected frame type") @@ -563,9 +596,12 @@ func (s *session) handlePacket(p *receivedPacket) { func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { if frame.StreamID == s.version.CryptoStreamID() { - return s.cryptoStream.AddStreamFrame(frame) + if frame.FinBit { + return errors.New("Received STREAM frame with FIN bit for the crypto stream") + } + return s.cryptoStream.handleStreamFrame(frame) } - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err } @@ -574,7 +610,7 @@ func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { // ignore this StreamFrame return nil } - return str.AddStreamFrame(frame) + return str.handleStreamFrame(frame) } func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { @@ -582,7 +618,11 @@ func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { } func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) + if frame.StreamID == s.version.CryptoStreamID() { + s.cryptoStream.handleMaxStreamDataFrame(frame) + return nil + } + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) if err != nil { return err } @@ -590,12 +630,15 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error // stream is closed and already garbage collected return nil } - str.UpdateSendWindow(frame.ByteOffset) + str.handleMaxStreamDataFrame(frame) return nil } func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) + if frame.StreamID == s.version.CryptoStreamID() { + return errors.New("Received RST_STREAM frame for the crypto stream") + } + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err } @@ -603,11 +646,31 @@ func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { // stream is closed and already garbage collected return nil } - return str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode), frame.ByteOffset) + return str.handleRstStreamFrame(frame) +} + +func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + return errors.New("Received a STOP_SENDING frame for the crypto stream") + } + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + str.handleStopSendingFrame(frame) + return nil } func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { - return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime) + if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil { + return err + } + s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) + return nil } func (s *session) closeLocal(e error) { @@ -647,7 +710,7 @@ func (s *session) handleCloseError(closeErr closeError) error { utils.Errorf("Closing session with error: %s", closeErr.err.Error()) } - s.cryptoStream.Cancel(quicErr) + s.cryptoStream.closeForShutdown(quicErr) s.streamsMap.CloseWithError(quicErr) if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry { @@ -669,111 +732,105 @@ func (s *session) handleCloseError(closeErr closeError) error { func (s *session) processTransportParameters(params *handshake.TransportParameters) { s.peerParams = params - s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams) + s.streamsMap.UpdateLimits(params) if params.OmitConnectionID { s.packer.SetOmitConnectionID() } s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) - s.streamsMap.Range(func(str streamI) { - str.UpdateSendWindow(params.StreamFlowControlWindow) - }) + // the crypto stream is the only open stream at this moment + // so we don't need to update stream flow control windows } -func (s *session) sendPacket() error { +func (s *session) maybeSendAckOnlyPacket() error { + ack := s.receivedPacketHandler.GetAckFrame() + if ack == nil { + return nil + } + s.packer.QueueControlFrame(ack) + + if !s.version.UsesIETFFrameFormat() { // for gQUIC, maybe add a STOP_WAITING + if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil { + s.packer.QueueControlFrame(swf) + } + } + packet, err := s.packer.PackAckPacket() + if err != nil { + return err + } + return s.sendPackedPacket(packet) +} + +func (s *session) sendPacket() (bool, error) { s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) - // Get MAX_DATA and MAX_STREAM_DATA frames - // this call triggers the flow controller to increase the flow control windows, if necessary - windowUpdates := s.getWindowUpdates() - for _, f := range windowUpdates { - s.packer.QueueControlFrame(f) + if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { + s.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: offset}) } + if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { + s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset}) + } + s.windowUpdateQueue.QueueAll() ack := s.receivedPacketHandler.GetAckFrame() if ack != nil { s.packer.QueueControlFrame(ack) } - // Repeatedly try sending until we don't have any more data, or run out of the congestion window + // check for retransmissions first for { - if !s.sentPacketHandler.SendingAllowed() { - if ack == nil { - return nil - } - // If we aren't allowed to send, at least try sending an ACK frame - swf := s.sentPacketHandler.GetStopWaitingFrame(false) - if swf != nil { - s.packer.QueueControlFrame(swf) - } - packet, err := s.packer.PackAckPacket() - if err != nil { - return err - } - return s.sendPackedPacket(packet) + retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() + if retransmitPacket == nil { + break } - // check for retransmissions first - for { - retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() - if retransmitPacket == nil { - break - } - - if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { - if s.handshakeComplete { - // Don't retransmit handshake packets when the handshake is complete - continue - } - utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) + // retransmit handshake packets + if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { + utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) + if !s.version.UsesIETFFrameFormat() { s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) - packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) - if err != nil { - return err - } - if err = s.sendPackedPacket(packet); err != nil { - return err - } - } else { - utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) - // resend the frames that were in the packet - for _, frame := range retransmitPacket.GetFramesForRetransmission() { - // TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window - switch f := frame.(type) { - case *wire.StreamFrame: - s.streamFramer.AddFrameForRetransmission(f) - default: - s.packer.QueueControlFrame(frame) - } - } } + packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) + if err != nil { + return false, err + } + if err := s.sendPackedPacket(packet); err != nil { + return false, err + } + return true, nil } - hasRetransmission := s.streamFramer.HasFramesForRetransmission() - if ack != nil || hasRetransmission { - swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) - if swf != nil { - s.packer.QueueControlFrame(swf) + // queue all retransmittable frames sent in forward-secure packets + utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) + // resend the frames that were in the packet + for _, frame := range retransmitPacket.GetFramesForRetransmission() { + // TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window + switch f := frame.(type) { + case *wire.StreamFrame: + s.streamFramer.AddFrameForRetransmission(f) + default: + s.packer.QueueControlFrame(frame) } } - // add a retransmittable frame - if s.sentPacketHandler.ShouldSendRetransmittablePacket() { - s.packer.QueueControlFrame(&wire.PingFrame{}) - } - packet, err := s.packer.PackPacket() - if err != nil || packet == nil { - return err - } - if err = s.sendPackedPacket(packet); err != nil { - return err - } - - // send every window update twice - for _, f := range windowUpdates { - s.packer.QueueControlFrame(f) - } - windowUpdates = nil - ack = nil } + + hasRetransmission := s.streamFramer.HasFramesForRetransmission() + if !s.version.UsesIETFFrameFormat() && (ack != nil || hasRetransmission) { + if swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission); swf != nil { + s.packer.QueueControlFrame(swf) + } + } + // add a retransmittable frame + if s.sentPacketHandler.ShouldSendRetransmittablePacket() { + s.packer.MakeNextPacketRetransmittable() + } + packet, err := s.packer.PackPacket() + if err != nil || packet == nil { + return false, err + } + if err := s.sendPackedPacket(packet); err != nil { + return false, err + } + return true, nil } func (s *session) sendPackedPacket(packet *packedPacket) error { @@ -824,7 +881,7 @@ func (s *session) GetOrOpenStream(id protocol.StreamID) (Stream, error) { return str, err } // make sure to return an actual nil value here, not an Stream with value nil - return nil, err + return str, err } // AcceptStream returns the next stream openend by the peer @@ -841,18 +898,6 @@ func (s *session) OpenStreamSync() (Stream, error) { return s.streamsMap.OpenStreamSync() } -func (s *session) WaitUntilHandshakeComplete() error { - return <-s.handshakeCompleteChan -} - -func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { - s.packer.QueueControlFrame(&wire.RstStreamFrame{ - StreamID: id, - ByteOffset: offset, - }) - s.scheduleSending() -} - func (s *session) newStream(id protocol.StreamID) streamI { var initialSendWindow protocol.ByteCount if s.peerParams != nil { @@ -867,7 +912,21 @@ func (s *session) newStream(id protocol.StreamID) streamI { initialSendWindow, s.rttStats, ) - return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version) + return newStream(id, s, flowController, s.version) +} + +func (s *session) newCryptoStream() cryptoStreamI { + id := s.version.CryptoStreamID() + flowController := flowcontrol.NewStreamFlowController( + id, + s.version.StreamContributesToConnectionFlowControl(id), + s.connFlowController, + protocol.ReceiveStreamFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), + 0, + s.rttStats, + ) + return newCryptoStream(s, flowController, s.version) } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { @@ -908,22 +967,25 @@ func (s *session) tryDecryptingQueuedPackets() { s.undecryptablePackets = s.undecryptablePackets[:0] } -func (s *session) getWindowUpdates() []wire.Frame { - var res []wire.Frame - s.streamsMap.Range(func(str streamI) { - if offset := str.GetWindowUpdate(); offset != 0 { - res = append(res, &wire.MaxStreamDataFrame{ - StreamID: str.StreamID(), - ByteOffset: offset, - }) - } - }) - if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { - res = append(res, &wire.MaxDataFrame{ - ByteOffset: offset, - }) +func (s *session) queueControlFrame(f wire.Frame) { + s.packer.QueueControlFrame(f) + s.scheduleSending() +} + +func (s *session) onHasWindowUpdate(id protocol.StreamID) { + s.windowUpdateQueue.Add(id) + s.scheduleSending() +} + +func (s *session) onHasStreamData(id protocol.StreamID) { + s.streamFramer.AddActiveStream(id) + s.scheduleSending() +} + +func (s *session) onStreamCompleted(id protocol.StreamID) { + if err := s.streamsMap.DeleteStream(id); err != nil { + s.Close(err) } - return res } func (s *session) LocalAddr() net.Addr { @@ -935,11 +997,11 @@ func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *session) handshakeStatus() <-chan handshakeEvent { +func (s *session) handshakeStatus() <-chan error { return s.handshakeChan } -func (s *session) getCryptoStream() cryptoStream { +func (s *session) getCryptoStream() cryptoStreamI { return s.cryptoStream } diff --git a/vendor/github.com/lucas-clemente/quic-go/session_test.go b/vendor/github.com/lucas-clemente/quic-go/session_test.go index 52c9ad9..47eff4d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/session_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/session_test.go @@ -19,6 +19,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" + "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/wire" @@ -70,73 +71,6 @@ func (m *mockUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte }, nil } -type mockSentPacketHandler struct { - retransmissionQueue []*ackhandler.Packet - sentPackets []*ackhandler.Packet - congestionLimited bool - requestedStopWaiting bool - shouldSendRetransmittablePacket bool -} - -func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error { - h.sentPackets = append(h.sentPackets, packet) - return nil -} - -func (h *mockSentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error { - return nil -} -func (h *mockSentPacketHandler) SetHandshakeComplete() {} -func (h *mockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { return 1 } -func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") } -func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") } -func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited } -func (h *mockSentPacketHandler) ShouldSendRetransmittablePacket() bool { - b := h.shouldSendRetransmittablePacket - h.shouldSendRetransmittablePacket = false - return b -} - -func (h *mockSentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { - h.requestedStopWaiting = true - return &wire.StopWaitingFrame{LeastUnacked: 0x1337} -} - -func (h *mockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet { - if len(h.retransmissionQueue) > 0 { - packet := h.retransmissionQueue[0] - h.retransmissionQueue = h.retransmissionQueue[1:] - return packet - } - return nil -} - -func newMockSentPacketHandler() ackhandler.SentPacketHandler { - return &mockSentPacketHandler{} -} - -var _ ackhandler.SentPacketHandler = &mockSentPacketHandler{} - -type mockReceivedPacketHandler struct { - nextAckFrame *wire.AckFrame - ackAlarm time.Time -} - -func (m *mockReceivedPacketHandler) GetAckFrame() *wire.AckFrame { - f := m.nextAckFrame - m.nextAckFrame = nil - return f -} -func (m *mockReceivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { - panic("not implemented") -} -func (m *mockReceivedPacketHandler) IgnoreBelow(protocol.PacketNumber) { - panic("not implemented") -} -func (m *mockReceivedPacketHandler) GetAlarmTimeout() time.Time { return m.ackAlarm } - -var _ ackhandler.ReceivedPacketHandler = &mockReceivedPacketHandler{} - func areSessionsRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1) @@ -145,11 +79,12 @@ func areSessionsRunning() bool { var _ = Describe("Session", func() { var ( - sess *session - scfg *handshake.ServerConfig - mconn *mockConnection - cryptoSetup *mockCryptoSetup - aeadChanged chan<- protocol.EncryptionLevel + sess *session + scfg *handshake.ServerConfig + mconn *mockConnection + cryptoSetup *mockCryptoSetup + streamManager *MockStreamManager + handshakeChan chan<- struct{} ) BeforeEach(func() { @@ -166,9 +101,9 @@ var _ = Describe("Session", func() { _ []protocol.VersionNumber, _ func(net.Addr, *Cookie) bool, _ chan<- handshake.TransportParameters, - aeadChangedP chan<- protocol.EncryptionLevel, + handshakeChanP chan<- struct{}, ) (handshake.CryptoSetup, error) { - aeadChanged = aeadChangedP + handshakeChan = handshakeChanP return cryptoSetup, nil } @@ -189,7 +124,8 @@ var _ = Describe("Session", func() { ) Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) - Expect(sess.streamsMap.openStreams).To(BeEmpty()) + streamManager = NewMockStreamManager(mockCtrl) + sess.streamsMap = streamManager }) AfterEach(func() { @@ -216,7 +152,7 @@ var _ = Describe("Session", func() { _ []protocol.VersionNumber, cookieFunc func(net.Addr, *Cookie) bool, _ chan<- handshake.TransportParameters, - _ chan<- protocol.EncryptionLevel, + _ chan<- struct{}, ) (handshake.CryptoSetup, error) { cookieVerify = cookieFunc return cryptoSetup, nil @@ -258,131 +194,119 @@ var _ = Describe("Session", func() { }) Context("frame handling", func() { - BeforeEach(func() { - sess.streamsMap.newStream = func(id protocol.StreamID) streamI { - str := mocks.NewMockStreamI(mockCtrl) - str.EXPECT().StreamID().Return(id).AnyTimes() - if id == 1 { - str.EXPECT().Finished().AnyTimes() - } - return str - } - }) - - Context("when handling STREAM frames", func() { - BeforeEach(func() { - sess.streamsMap.UpdateMaxStreamLimit(100) - }) - - It("makes new streams", func() { + Context("handling STREAM frames", func() { + It("passes STREAM frames to the stream", func() { f := &wire.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, } - newStreamLambda := sess.streamsMap.newStream - sess.streamsMap.newStream = func(id protocol.StreamID) streamI { - str := newStreamLambda(id) - if id == 5 { - str.(*mocks.MockStreamI).EXPECT().AddStreamFrame(f) - } - return str - } + str := NewMockReceiveStreamI(mockCtrl) + str.EXPECT().handleStreamFrame(f) + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) err := sess.handleStreamFrame(f) Expect(err).ToNot(HaveOccurred()) - str, err := sess.streamsMap.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) }) - It("handles existing streams", func() { - f1 := &wire.StreamFrame{ + It("returns errors", func() { + testErr := errors.New("test err") + f := &wire.StreamFrame{ StreamID: 5, - Data: []byte{0xde, 0xca}, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, } - f2 := &wire.StreamFrame{ - StreamID: 5, - Offset: 2, - Data: []byte{0xfb, 0xad}, - } - newStreamLambda := sess.streamsMap.newStream - sess.streamsMap.newStream = func(id protocol.StreamID) streamI { - str := newStreamLambda(id) - if id == 5 { - str.(*mocks.MockStreamI).EXPECT().AddStreamFrame(f1) - str.(*mocks.MockStreamI).EXPECT().AddStreamFrame(f2) - } - return str - } - sess.handleStreamFrame(f1) - numOpenStreams := len(sess.streamsMap.openStreams) - sess.handleStreamFrame(f2) - Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams)) + str := NewMockReceiveStreamI(mockCtrl) + str.EXPECT().handleStreamFrame(f).Return(testErr) + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) + err := sess.handleStreamFrame(f) + Expect(err).To(MatchError(testErr)) }) It("ignores STREAM frames for closed streams", func() { - sess.streamsMap.streams[5] = nil - str, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeNil()) // make sure the stream is gone - err = sess.handleStreamFrame(&wire.StreamFrame{ + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(nil, nil) // for closed streams, the streamManager returns nil + err := sess.handleStreamFrame(&wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), }) Expect(err).ToNot(HaveOccurred()) }) + + It("errors on a STREAM frame that would close the crypto stream", func() { + err := sess.handleStreamFrame(&wire.StreamFrame{ + StreamID: sess.version.CryptoStreamID(), + Offset: 0x1337, + FinBit: true, + }) + Expect(err).To(MatchError("Received STREAM frame with FIN bit for the crypto stream")) + }) + }) + + Context("handling ACK frames", func() { + It("informs the SentPacketHandler about ACKs", func() { + f := &wire.AckFrame{LargestAcked: 3, LowestAcked: 2} + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().ReceivedAck(f, protocol.PacketNumber(42), protocol.EncryptionSecure, gomock.Any()) + sph.EXPECT().GetLowestPacketNotConfirmedAcked() + sess.sentPacketHandler = sph + sess.lastRcvdPacketNumber = 42 + err := sess.handleAckFrame(f, protocol.EncryptionSecure) + Expect(err).ToNot(HaveOccurred()) + }) + + It("tells the ReceivedPacketHandler to ignore low ranges", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().ReceivedAck(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().GetLowestPacketNotConfirmedAcked().Return(protocol.PacketNumber(0x42)) + sess.sentPacketHandler = sph + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + rph.EXPECT().IgnoreBelow(protocol.PacketNumber(0x42)) + sess.receivedPacketHandler = rph + err := sess.handleAckFrame(&wire.AckFrame{LargestAcked: 3, LowestAcked: 2}, protocol.EncryptionUnencrypted) + Expect(err).ToNot(HaveOccurred()) + }) }) Context("handling RST_STREAM frames", func() { It("closes the streams for writing", func() { - str, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - str.(*mocks.MockStreamI).EXPECT().RegisterRemoteError( - errors.New("RST_STREAM received with code 42"), - protocol.ByteCount(0x1337), - ) - err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, + f := &wire.RstStreamFrame{ + StreamID: 555, ErrorCode: 42, ByteOffset: 0x1337, - }) + } + str := NewMockReceiveStreamI(mockCtrl) + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(555)).Return(str, nil) + str.EXPECT().handleRstStreamFrame(f) + err := sess.handleRstStreamFrame(f) Expect(err).ToNot(HaveOccurred()) }) - It("queues a RST_STERAM frame", func() { - sess.queueResetStreamFrame(5, 0x1337) - Expect(sess.packer.controlFrames).To(HaveLen(1)) - Expect(sess.packer.controlFrames[0].(*wire.RstStreamFrame)).To(Equal(&wire.RstStreamFrame{ - StreamID: 5, - ByteOffset: 0x1337, - })) - }) - It("returns errors", func() { - testErr := errors.New("flow control violation") - str, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - str.(*mocks.MockStreamI).EXPECT().RegisterRemoteError(gomock.Any(), gomock.Any()).Return(testErr) - err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, + f := &wire.RstStreamFrame{ + StreamID: 7, ByteOffset: 0x1337, - }) + } + testErr := errors.New("flow control violation") + str := NewMockReceiveStreamI(mockCtrl) + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(7)).Return(str, nil) + str.EXPECT().handleRstStreamFrame(f).Return(testErr) + err := sess.handleRstStreamFrame(f) Expect(err).To(MatchError(testErr)) }) - It("ignores the error when the stream is not known", func() { - str, err := sess.GetOrOpenStream(3) - Expect(err).ToNot(HaveOccurred()) - str.(*mocks.MockStreamI).EXPECT().Finished().Return(true) - sess.streamsMap.DeleteClosedStreams() - str, err = sess.GetOrOpenStream(3) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeNil()) - err = sess.handleFrames([]wire.Frame{&wire.RstStreamFrame{ + It("ignores RST_STREAM frames for closed streams", func() { + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(nil, nil) + err := sess.handleFrames([]wire.Frame{&wire.RstStreamFrame{ StreamID: 3, ErrorCode: 42, }}, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) }) + + It("erros when a RST_STREAM frame would reset the crypto stream", func() { + err := sess.handleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: sess.version.CryptoStreamID(), + ErrorCode: 123, + }) + Expect(err).To(MatchError("Received RST_STREAM frame for the crypto stream")) + }) }) Context("handling MAX_DATA and MAX_STREAM_DATA frames", func() { @@ -393,55 +317,72 @@ var _ = Describe("Session", func() { sess.connFlowController = connFC }) - It("updates the flow control window of a stream", func() { - offset := protocol.ByteCount(0x1234) - str, err := sess.GetOrOpenStream(5) - str.(*mocks.MockStreamI).EXPECT().UpdateSendWindow(offset) - Expect(err).ToNot(HaveOccurred()) - err = sess.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ - StreamID: 5, + It("updates the flow control window of the crypto stream", func() { + fc := mocks.NewMockStreamFlowController(mockCtrl) + offset := protocol.ByteCount(0x4321) + fc.EXPECT().UpdateSendWindow(offset) + sess.cryptoStream.(*cryptoStream).sendStream.flowController = fc + err := sess.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: sess.version.CryptoStreamID(), ByteOffset: offset, }) Expect(err).ToNot(HaveOccurred()) }) + It("updates the flow control window of a stream", func() { + f := &wire.MaxStreamDataFrame{ + StreamID: 12345, + ByteOffset: 0x1337, + } + str := NewMockSendStreamI(mockCtrl) + streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(12345)).Return(str, nil) + str.EXPECT().handleMaxStreamDataFrame(f) + err := sess.handleMaxStreamDataFrame(f) + Expect(err).ToNot(HaveOccurred()) + }) + It("updates the flow control window of the connection", func() { offset := protocol.ByteCount(0x800000) connFC.EXPECT().UpdateSendWindow(offset) sess.handleMaxDataFrame(&wire.MaxDataFrame{ByteOffset: offset}) }) - It("opens a new stream when receiving a MAX_STREAM_DATA frame for an unknown stream", func() { - newStreamLambda := sess.streamsMap.newStream - sess.streamsMap.newStream = func(id protocol.StreamID) streamI { - str := newStreamLambda(id) - if id == 5 { - str.(*mocks.MockStreamI).EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337)) - } - return str - } - err := sess.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ - StreamID: 5, - ByteOffset: 0x1337, - }) - Expect(err).ToNot(HaveOccurred()) - str, err := sess.streamsMap.GetOrOpenStream(5) + It("ignores MAX_STREAM_DATA frames for a closed stream", func() { + streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(10)).Return(nil, nil) + err := sess.handleFrames([]wire.Frame{&wire.MaxStreamDataFrame{ + StreamID: 10, + ByteOffset: 1337, + }}, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) - Expect(str).ToNot(BeNil()) + }) + }) + + Context("handling STOP_SENDING frames", func() { + It("passes the frame to the stream", func() { + f := &wire.StopSendingFrame{ + StreamID: 5, + ErrorCode: 10, + } + str := NewMockSendStreamI(mockCtrl) + streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(5)).Return(str, nil) + str.EXPECT().handleStopSendingFrame(f) + err := sess.handleStopSendingFrame(f) + Expect(err).ToNot(HaveOccurred()) }) - It("ignores MAX_STREAM_DATA frames for a closed stream", func() { - str, err := sess.GetOrOpenStream(3) - Expect(err).ToNot(HaveOccurred()) - str.(*mocks.MockStreamI).EXPECT().Finished().Return(true) - err = sess.streamsMap.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - str, err = sess.GetOrOpenStream(3) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeNil()) - err = sess.handleFrames([]wire.Frame{&wire.MaxStreamDataFrame{ - StreamID: 3, - ByteOffset: 1337, + It("errors when receiving a STOP_SENDING for the crypto stream", func() { + err := sess.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: sess.version.CryptoStreamID(), + ErrorCode: 10, + }) + Expect(err).To(MatchError("Received a STOP_SENDING frame for the crypto stream")) + }) + + It("ignores STOP_SENDING frames for a closed stream", func() { + streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(3)).Return(nil, nil) + err := sess.handleFrames([]wire.Frame{&wire.StopSendingFrame{ + StreamID: 3, + ErrorCode: 1337, }}, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) }) @@ -468,19 +409,16 @@ var _ = Describe("Session", func() { }) It("handles CONNECTION_CLOSE frames", func() { + testErr := qerr.Error(qerr.ProofInvalid, "foobar") + streamManager.EXPECT().CloseWithError(testErr) done := make(chan struct{}) go func() { defer GinkgoRecover() err := sess.run() - Expect(err).To(MatchError("ProofInvalid: foobar")) + Expect(err).To(MatchError(testErr)) close(done) }() - _, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - sess.streamsMap.Range(func(s streamI) { - s.(*mocks.MockStreamI).EXPECT().Cancel(gomock.Any()) - }) - err = sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, protocol.EncryptionUnspecified) + err := sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(done).Should(BeClosed()) @@ -492,129 +430,26 @@ var _ = Describe("Session", func() { Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242))) }) - Context("waiting until the handshake completes", func() { - It("waits until the handshake is complete", func() { - go func() { - defer GinkgoRecover() - sess.run() - }() - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - err := sess.WaitUntilHandshakeComplete() - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - aeadChanged <- protocol.EncryptionForwardSecure - Consistently(done).ShouldNot(BeClosed()) - close(aeadChanged) - Eventually(done).Should(BeClosed()) - Expect(sess.Close(nil)).To(Succeed()) - }) - - It("errors if the handshake fails", func(done Done) { - testErr := errors.New("crypto error") - sess.cryptoSetup = &mockCryptoSetup{handleErr: testErr} - go sess.run() - err := sess.WaitUntilHandshakeComplete() - Expect(err).To(MatchError(testErr)) - close(done) - }, 0.5) - - It("returns when Close is called", func(done Done) { - testErr := errors.New("close error") - go sess.run() - var waitReturned bool - go func() { - defer GinkgoRecover() - err := sess.WaitUntilHandshakeComplete() - Expect(err).To(MatchError(testErr)) - waitReturned = true - }() - sess.Close(testErr) - Eventually(func() bool { return waitReturned }).Should(BeTrue()) - close(done) - }) - - It("doesn't wait if the handshake is already completed", func(done Done) { - go sess.run() - close(aeadChanged) - err := sess.WaitUntilHandshakeComplete() - Expect(err).ToNot(HaveOccurred()) - Expect(sess.Close(nil)).To(Succeed()) - close(done) - }) - }) - - Context("accepting streams", func() { - BeforeEach(func() { - // don't use the mock here - sess.streamsMap.newStream = sess.newStream - }) - - It("waits for new streams", func() { - strChan := make(chan Stream) - // accept two streams - go func() { - defer GinkgoRecover() - for i := 0; i < 2; i++ { - str, err := sess.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - strChan <- str - } - }() - Consistently(strChan).ShouldNot(Receive()) - // this could happen e.g. by receiving a STREAM frame - _, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - var str Stream - Eventually(strChan).Should(Receive(&str)) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) - Eventually(strChan).Should(Receive(&str)) - Expect(str.StreamID()).To(Equal(protocol.StreamID(5))) - }) - - It("stops accepting when the session is closed", func() { - testErr := errors.New("testErr") - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := sess.AcceptStream() - Expect(err).To(MatchError(qerr.ToQuicError(testErr))) - close(done) - }() - go sess.run() - Consistently(done).ShouldNot(BeClosed()) - sess.Close(testErr) - Eventually(done).Should(BeClosed()) - }) - - It("stops accepting when the session is closed after version negotiation", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := sess.AcceptStream() - Expect(err).To(MatchError(qerr.Error(qerr.InternalError, errCloseSessionForNewVersion.Error()))) - close(done) - }() - go sess.run() - Consistently(done).ShouldNot(BeClosed()) - Expect(sess.Context().Done()).ToNot(BeClosed()) - sess.Close(errCloseSessionForNewVersion) - Eventually(done).Should(BeClosed()) - Eventually(sess.Context().Done()).Should(BeClosed()) - }) + It("accepts new streams", func() { + mstr := NewMockStreamI(mockCtrl) + streamManager.EXPECT().AcceptStream().Return(mstr, nil) + str, err := sess.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(mstr)) }) Context("closing", func() { BeforeEach(func() { Eventually(areSessionsRunning).Should(BeFalse()) - go sess.run() + go func() { + defer GinkgoRecover() + sess.run() + }() Eventually(areSessionsRunning).Should(BeTrue()) }) It("shuts down without error", func() { + streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) sess.Close(nil) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) @@ -626,6 +461,7 @@ var _ = Describe("Session", func() { }) It("only closes once", func() { + streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) sess.Close(nil) sess.Close(nil) Eventually(areSessionsRunning).Should(BeFalse()) @@ -635,26 +471,21 @@ var _ = Describe("Session", func() { It("closes streams with proper error", func() { testErr := errors.New("test error") - s, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) + streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) sess.Close(testErr) Eventually(areSessionsRunning).Should(BeFalse()) - n, err := s.Read([]byte{0}) - Expect(n).To(BeZero()) - Expect(err.Error()).To(ContainSubstring(testErr.Error())) - n, err = s.Write([]byte{0}) - Expect(n).To(BeZero()) - Expect(err.Error()).To(ContainSubstring(testErr.Error())) Expect(sess.Context().Done()).To(BeClosed()) }) It("closes the session in order to replace it with another QUIC version", func() { + streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(errCloseSessionForNewVersion) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent }) It("sends a Public Reset if the client is initiating the head-of-line blocking experiment", func() { + streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(handshake.ErrHOLExperiment) Expect(mconn.written).To(HaveLen(1)) Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset @@ -662,6 +493,7 @@ var _ = Describe("Session", func() { }) It("sends a Public Reset if the client is initiating the no STOP_WAITING experiment", func() { + streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(handshake.ErrHOLExperiment) Expect(mconn.written).To(HaveLen(1)) Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset @@ -669,6 +501,7 @@ var _ = Describe("Session", func() { }) It("cancels the context when the run loop exists", func() { + streamManager.EXPECT().CloseWithError(gomock.Any()) returned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -699,11 +532,23 @@ var _ = Describe("Session", func() { Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) }) + It("informs the ReceivedPacketHandler", func() { + now := time.Now().Add(time.Hour) + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(5), now, false) + sess.receivedPacketHandler = rph + hdr.PacketNumber = 5 + err := sess.handlePacketImpl(&receivedPacket{header: hdr, rcvTime: now}) + Expect(err).ToNot(HaveOccurred()) + }) + It("closes when handling a packet fails", func(done Done) { + streamManager.EXPECT().CloseWithError(gomock.Any()) testErr := errors.New("unpack error") hdr.PacketNumber = 5 var runErr error go func() { + defer GinkgoRecover() runErr = sess.run() }() sess.unpacker.(*mockUnpacker).unpackErr = testErr @@ -774,54 +619,117 @@ var _ = Describe("Session", func() { It("sends ACK frames", func() { packetNumber := protocol.PacketNumber(0x035e) - err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) + err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, time.Now(), true) Expect(err).ToNot(HaveOccurred()) - err = sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(mconn.written).To(HaveLen(1)) - Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) - }) - - It("sends ACK frames when congestion limited", func() { - sess.sentPacketHandler = &mockSentPacketHandler{congestionLimited: true} - sess.packer.packetNumberGenerator.next = 0x1338 - packetNumber := protocol.PacketNumber(0x035e) - sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) - err := sess.sendPacket() + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) }) It("sends a retransmittable packet when required by the SentPacketHandler", func() { - sess.sentPacketHandler = &mockSentPacketHandler{shouldSendRetransmittablePacket: true} - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + ack := &wire.AckFrame{LargestAcked: 1000} + sess.packer.QueueControlFrame(ack) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket().Return(true) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(HaveLen(2)) + Expect(p.Frames).To(ContainElement(ack)) + }) + sess.sentPacketHandler = sph + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) - Expect(sess.sentPacketHandler.(*mockSentPacketHandler).sentPackets[0].Frames).To(ContainElement(&wire.PingFrame{})) }) - It("sends two MAX_STREAM_DATA frames", func() { - mockFC := mocks.NewMockStreamFlowController(mockCtrl) - mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1000)) - mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).Times(2) - str, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - str.(*stream).flowController = mockFC - err = sess.sendPacket() + It("adds a MAX_DATA frames", func() { + fc := mocks.NewMockConnectionFlowController(mockCtrl) + fc.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337)) + fc.EXPECT().IsNewlyBlocked() + sess.connFlowController = fc + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(Equal([]wire.Frame{ + &wire.MaxDataFrame{ByteOffset: 0x1337}, + })) + }) + sess.sentPacketHandler = sph + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) - err = sess.sendPacket() + Expect(sent).To(BeTrue()) + }) + + It("adds MAX_STREAM_DATA frames", func() { + sess.windowUpdateQueue.callback(&wire.MaxStreamDataFrame{ + StreamID: 2, + ByteOffset: 20, + }) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) + }) + sess.sentPacketHandler = sph + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) - err = sess.sendPacket() + Expect(sent).To(BeTrue()) + }) + + It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { + fc := mocks.NewMockConnectionFlowController(mockCtrl) + fc.EXPECT().GetWindowUpdate() + fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) + sess.connFlowController = fc + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(Equal([]wire.Frame{ + &wire.BlockedFrame{Offset: 1337}, + })) + }) + sess.sentPacketHandler = sph + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) - buf := &bytes.Buffer{} - (&wire.MaxStreamDataFrame{ - StreamID: 5, - ByteOffset: 0x1000, - }).Write(buf, sess.version) - Expect(mconn.written).To(HaveLen(2)) - Expect(mconn.written).To(Receive(ContainSubstring(string(buf.Bytes())))) - Expect(mconn.written).To(Receive(ContainSubstring(string(buf.Bytes())))) + Expect(sent).To(BeTrue()) + }) + + It("sends multiple packets", func() { + sess.queueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().DequeuePacketForRetransmission().Times(2) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().ShouldSendRetransmittablePacket().Times(2) + sph.EXPECT().SentPacket(gomock.Any()).Times(2) + sph.EXPECT().SendingAllowed().Do(func() { // after sending the first packet + // make sure there's something to send + sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 2}) + }).Return(true).Times(2) // allow 2 packets... + sph.EXPECT().SendingAllowed() // ...then report that we're congestion limited + sess.sentPacketHandler = sph + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(2)) + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close(nil) + Eventually(done).Should(BeClosed()) }) It("sends public reset", func() { @@ -832,109 +740,200 @@ var _ = Describe("Session", func() { }) It("informs the SentPacketHandler about sent packets", func() { - sess.sentPacketHandler = newMockSentPacketHandler() - sess.packer.packetNumberGenerator.next = 0x1337 + 9 - sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} - f := &wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), } + var sentPacket *ackhandler.Packet + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().GetStopWaitingFrame(gomock.Any()) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + sentPacket = p + }) + sess.sentPacketHandler = sph + sess.packer.packetNumberGenerator.next = 0x1337 + 9 + sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} + sess.streamFramer.AddFrameForRetransmission(f) - _, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - err = sess.sendPacket() + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) - sentPackets := sess.sentPacketHandler.(*mockSentPacketHandler).sentPackets - Expect(sentPackets).To(HaveLen(1)) - Expect(sentPackets[0].Frames).To(ContainElement(f)) - Expect(sentPackets[0].EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - Expect(mconn.written).To(HaveLen(1)) - Expect(sentPackets[0].Length).To(BeEquivalentTo(len(<-mconn.written))) + Expect(sentPacket.PacketNumber).To(Equal(protocol.PacketNumber(0x1337 + 9))) + Expect(sentPacket.Frames).To(ContainElement(f)) + Expect(sentPacket.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) + Expect(sentPacket.Length).To(BeEquivalentTo(len(<-mconn.written))) + }) + }) + + Context("sending ACK only packets", func() { + It("doesn't do anything if there's no ACK to be sent", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sess.sentPacketHandler = sph + err := sess.maybeSendAckOnlyPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(mconn.written).To(BeEmpty()) + }) + + It("sends ACK only packets", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 10} + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked() + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().SendingAllowed() + sph.EXPECT().GetStopWaitingFrame(false).Return(swf) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(HaveLen(2)) + Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) + Expect(p.Frames[1]).To(Equal(swf)) + }) + sess.sentPacketHandler = sph + sess.packer.packetNumberGenerator.next = 0x1338 + sess.receivedPacketHandler.ReceivedPacket(1, time.Now(), true) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(1)) + // make sure that the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close(nil) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't include a STOP_WAITING for an ACK-only packet for IETF QUIC", func() { + sess.version = versionIETFFrames + done := make(chan struct{}) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked() + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().SendingAllowed() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(HaveLen(1)) + Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) + }) + sess.sentPacketHandler = sph + sess.packer.packetNumberGenerator.next = 0x1338 + sess.receivedPacketHandler.ReceivedPacket(1, time.Now(), true) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(1)) + // make sure that the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close(nil) + Eventually(done).Should(BeClosed()) }) }) Context("retransmissions", func() { - var sph *mockSentPacketHandler + var sph *mockackhandler.MockSentPacketHandler BeforeEach(func() { - // a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet + // a STOP_WAITING frame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet sess.packer.packetNumberGenerator.next = 0x1337 + 10 sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - sph = newMockSentPacketHandler().(*mockSentPacketHandler) + sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() sess.sentPacketHandler = sph sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} }) Context("for handshake packets", func() { - It("retransmits an unencrypted packet", func() { + It("retransmits an unencrypted packet, and adds a STOP_WAITING frame (for gQUIC)", func() { sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - sph.retransmissionQueue = []*ackhandler.Packet{{ + swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337} + sph.EXPECT().GetStopWaitingFrame(true).Return(swf) + sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ Frames: []wire.Frame{sf}, EncryptionLevel: protocol.EncryptionUnencrypted, - }} - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + }) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) + Expect(p.Frames).To(Equal([]wire.Frame{swf, sf})) + }) + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) - sentPackets := sph.sentPackets - Expect(sentPackets).To(HaveLen(1)) - Expect(sentPackets[0].EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(sentPackets[0].Frames).To(HaveLen(2)) - Expect(sentPackets[0].Frames[1]).To(Equal(sf)) - swf := sentPackets[0].Frames[0].(*wire.StopWaitingFrame) - Expect(swf.LeastUnacked).To(Equal(protocol.PacketNumber(0x1337))) }) - It("retransmit a packet encrypted with the initial encryption", func() { + It("retransmits an unencrypted packet, and doesn't add a STOP_WAITING frame (for IETF QUIC)", func() { + sess.version = versionIETFFrames + sess.packer.version = versionIETFFrames sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - sph.retransmissionQueue = []*ackhandler.Packet{{ + sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionSecure, - }} - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + EncryptionLevel: protocol.EncryptionUnencrypted, + }) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) + Expect(p.Frames).To(Equal([]wire.Frame{sf})) + }) + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) - sentPackets := sph.sentPackets - Expect(sentPackets).To(HaveLen(1)) - Expect(sentPackets[0].EncryptionLevel).To(Equal(protocol.EncryptionSecure)) - Expect(sentPackets[0].Frames).To(HaveLen(2)) - Expect(sentPackets[0].Frames).To(ContainElement(sf)) - }) - - It("doesn't retransmit handshake packets when the handshake is complete", func() { - sess.handshakeComplete = true - sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - sph.retransmissionQueue = []*ackhandler.Packet{{ - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionSecure, - }} - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(mconn.written).To(BeEmpty()) }) }) Context("for packets after the handshake", func() { - It("sends a StreamFrame from a packet queued for retransmission", func() { - f := wire.StreamFrame{ + It("sends a STREAM frame from a packet queued for retransmission, and adds a STOP_WAITING (for gQUIC)", func() { + f := &wire.StreamFrame{ StreamID: 0x5, - Data: []byte("foobar1234567"), + Data: []byte("foobar"), } - p := ackhandler.Packet{ + swf := &wire.StopWaitingFrame{LeastUnacked: 10} + sph.EXPECT().GetStopWaitingFrame(true).Return(swf) + sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ PacketNumber: 0x1337, - Frames: []wire.Frame{&f}, + Frames: []wire.Frame{f}, EncryptionLevel: protocol.EncryptionForwardSecure, - } - sph.retransmissionQueue = []*ackhandler.Packet{&p} - - err := sess.sendPacket() + }) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(Equal([]wire.Frame{swf, f})) + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) + }) + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) - Expect(sph.requestedStopWaiting).To(BeTrue()) - Expect(mconn.written).To(Receive(ContainSubstring("foobar1234567"))) }) - It("sends a StreamFrame from a packet queued for retransmission", func() { + It("sends a STREAM frame from a packet queued for retransmission, and doesn't add a STOP_WAITING (for IETF QUIC)", func() { + sess.version = versionIETFFrames + sess.packer.version = versionIETFFrames + f := &wire.StreamFrame{ + StreamID: 0x5, + Data: []byte("foobar"), + } + sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ + Frames: []wire.Frame{f}, + EncryptionLevel: protocol.EncryptionForwardSecure, + }) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(Equal([]wire.Frame{f})) + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) + }) + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) + Expect(mconn.written).To(HaveLen(1)) + }) + + It("sends a STREAM frame from a packet queued for retransmission", func() { f1 := wire.StreamFrame{ StreamID: 0x5, Data: []byte("foobar"), @@ -943,43 +942,32 @@ var _ = Describe("Session", func() { StreamID: 0x7, Data: []byte("loremipsum"), } - p1 := ackhandler.Packet{ + p1 := &ackhandler.Packet{ PacketNumber: 0x1337, Frames: []wire.Frame{&f1}, EncryptionLevel: protocol.EncryptionForwardSecure, } - p2 := ackhandler.Packet{ + p2 := &ackhandler.Packet{ PacketNumber: 0x1338, Frames: []wire.Frame{&f2}, EncryptionLevel: protocol.EncryptionForwardSecure, } - sph.retransmissionQueue = []*ackhandler.Packet{&p1, &p2} - - err := sess.sendPacket() + sph.EXPECT().DequeuePacketForRetransmission().Return(p1) + sph.EXPECT().DequeuePacketForRetransmission().Return(p2) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{}) + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(HaveLen(3)) + }) + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) packet := <-mconn.written Expect(packet).To(ContainSubstring("foobar")) Expect(packet).To(ContainSubstring("loremipsum")) }) - - It("always attaches a StopWaiting to a packet that contains a retransmission", func() { - f := &wire.StreamFrame{ - StreamID: 0x5, - Data: bytes.Repeat([]byte{'f'}, int(1.5*float32(protocol.MaxPacketSize))), - } - sess.streamFramer.AddFrameForRetransmission(f) - - err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(mconn.written).To(HaveLen(2)) - sentPackets := sph.sentPackets - Expect(sentPackets).To(HaveLen(2)) - _, ok := sentPackets[0].Frames[0].(*wire.StopWaitingFrame) - Expect(ok).To(BeTrue()) - _, ok = sentPackets[1].Frames[0].(*wire.StopWaitingFrame) - Expect(ok).To(BeTrue()) - }) }) }) @@ -1008,122 +996,67 @@ var _ = Describe("Session", func() { sess.scheduleSending() Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero()) Expect(mconn.written).To(Receive(ContainSubstring("foobar"))) + streamManager.EXPECT().CloseWithError(gomock.Any()) }) Context("scheduling sending", func() { BeforeEach(func() { sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - sess.processTransportParameters(&handshake.TransportParameters{ - StreamFlowControlWindow: protocol.MaxByteCount, - ConnectionFlowControlWindow: protocol.MaxByteCount, - MaxStreams: 1000, - }) sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} }) - It("sends after writing to a stream", func(done Done) { - Expect(sess.sendingScheduled).NotTo(Receive()) - s, err := sess.GetOrOpenStream(3) - Expect(err).NotTo(HaveOccurred()) + It("sends when scheduleSending is called", func() { + done := make(chan struct{}) go func() { - s.Write([]byte("foobar")) + defer GinkgoRecover() + sess.run() close(done) }() - Eventually(sess.sendingScheduled).Should(Receive()) - s.(*stream).GetDataForWriting(1000) // unblock + sess.streamFramer.AddFrameForRetransmission(&wire.StreamFrame{ + StreamID: 5, + Data: []byte("foobar"), + }) + Consistently(mconn.written).ShouldNot(Receive()) + sess.scheduleSending() + Eventually(mconn.written).Should(Receive()) + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close(nil) + Eventually(done).Should(BeClosed()) }) It("sets the timer to the ack timer", func() { - rph := &mockReceivedPacketHandler{ackAlarm: time.Now().Add(10 * time.Millisecond)} - rph.nextAckFrame = &wire.AckFrame{LargestAcked: 0x1337} + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + rph.EXPECT().GetAckFrame().Return(&wire.AckFrame{LargestAcked: 0x1337}) + rph.EXPECT().GetAckFrame() + rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)).MinTimes(1) sess.receivedPacketHandler = rph - go sess.run() - defer sess.Close(nil) - time.Sleep(10 * time.Millisecond) - Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero()) - Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x13, 0x37})))) - }) - - Context("bundling of small packets", func() { - It("bundles two small frames of different streams into one packet", func() { - s1, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - s2, err := sess.GetOrOpenStream(7) - Expect(err).NotTo(HaveOccurred()) - - // Put data directly into the streams - s1.(*stream).dataForWriting = []byte("foobar1") - s2.(*stream).dataForWriting = []byte("foobar2") - - sess.scheduleSending() - go sess.run() - defer sess.Close(nil) - - Eventually(mconn.written).Should(HaveLen(1)) - packet := <-mconn.written - Expect(packet).To(ContainSubstring("foobar1")) - Expect(packet).To(ContainSubstring("foobar2")) - }) - - It("sends out two big frames in two packets", func() { - s1, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - s2, err := sess.GetOrOpenStream(7) - Expect(err).NotTo(HaveOccurred()) - go sess.run() - defer sess.Close(nil) - go func() { - defer GinkgoRecover() - s1.Write(bytes.Repeat([]byte{'e'}, 1000)) - }() - _, err = s2.Write(bytes.Repeat([]byte{'e'}, 1000)) - Expect(err).ToNot(HaveOccurred()) - Eventually(mconn.written).Should(HaveLen(2)) - }) - - It("sends out two small frames that are written to long after one another into two packets", func() { - s, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - go sess.run() - defer sess.Close(nil) - _, err = s.Write([]byte("foobar1")) - Expect(err).NotTo(HaveOccurred()) - Eventually(mconn.written).Should(HaveLen(1)) - _, err = s.Write([]byte("foobar2")) - Expect(err).NotTo(HaveOccurred()) - Eventually(mconn.written).Should(HaveLen(2)) - }) - - It("sends a queued ACK frame only once", func() { - packetNumber := protocol.PacketNumber(0x1337) - sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) - - s, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - go sess.run() - defer sess.Close(nil) - _, err = s.Write([]byte("foobar1")) - Expect(err).NotTo(HaveOccurred()) - Eventually(mconn.written).Should(HaveLen(1)) - _, err = s.Write([]byte("foobar2")) - Expect(err).NotTo(HaveOccurred()) - - Eventually(mconn.written).Should(HaveLen(2)) - Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x13, 0x37})))) - Expect(mconn.written).ToNot(Receive(ContainSubstring(string([]byte{0x13, 0x37})))) - }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + Eventually(mconn.written).Should(Receive(ContainSubstring(string([]byte{0x13, 0x37})))) + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close(nil) + Eventually(done).Should(BeClosed()) }) }) It("closes when crypto stream errors", func() { testErr := errors.New("crypto setup error") + streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) cryptoSetup.handleErr = testErr - var runErr error + done := make(chan struct{}) go func() { - runErr = sess.run() + defer GinkgoRecover() + err := sess.run() + Expect(err).To(MatchError(testErr)) + close(done) }() - Eventually(func() error { return runErr }).Should(HaveOccurred()) - Expect(runErr).To(MatchError(testErr)) + Eventually(done).Should(BeClosed()) }) Context("sending a Public Reset when receiving undecryptable packets during the handshake", func() { @@ -1145,10 +1078,14 @@ var _ = Describe("Session", func() { BeforeEach(func() { sess.unpacker = &mockUnpacker{unpackErr: qerr.Error(qerr.DecryptionFailure, "")} sess.cryptoSetup = &mockCryptoSetup{} + streamManager.EXPECT().CloseWithError(gomock.Any()).MaxTimes(1) }) It("doesn't immediately send a Public Reset after receiving too many undecryptable packets", func() { - go sess.run() + go func() { + defer GinkgoRecover() + sess.run() + }() sendUndecryptablePackets() sess.scheduleSending() Consistently(mconn.written).Should(HaveLen(0)) @@ -1157,7 +1094,10 @@ var _ = Describe("Session", func() { }) It("sets a deadline to send a Public Reset after receiving too many undecryptable packets", func() { - go sess.run() + go func() { + defer GinkgoRecover() + sess.run() + }() sendUndecryptablePackets() Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), 20*time.Millisecond)) sess.Close(nil) @@ -1165,7 +1105,10 @@ var _ = Describe("Session", func() { }) It("drops undecryptable packets when the undecrytable packet queue is full", func() { - go sess.run() + go func() { + defer GinkgoRecover() + sess.run() + }() sendUndecryptablePackets() Eventually(func() []*receivedPacket { return sess.undecryptablePackets }).Should(HaveLen(protocol.MaxUndecryptablePackets)) // check that old packets are kept, and the new packets are dropped @@ -1176,7 +1119,10 @@ var _ = Describe("Session", func() { It("sends a Public Reset after a timeout", func() { Expect(sess.receivedTooManyUndecrytablePacketsTime).To(BeZero()) - go sess.run() + go func() { + defer GinkgoRecover() + sess.run() + }() sendUndecryptablePackets() Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), time.Second)) // speed up this test by manually setting back the time when too many packets were received @@ -1189,7 +1135,10 @@ var _ = Describe("Session", func() { }) It("doesn't send a Public Reset if decrypting them suceeded during the timeout", func() { - go sess.run() + go func() { + defer GinkgoRecover() + sess.run() + }() sess.receivedTooManyUndecrytablePacketsTime = time.Now().Add(-protocol.PublicResetTimeout).Add(-time.Millisecond) sess.scheduleSending() // wake up the run loop // there are no packets in the undecryptable packet queue @@ -1202,7 +1151,10 @@ var _ = Describe("Session", func() { It("ignores undecryptable packets after the handshake is complete", func() { sess.handshakeComplete = true - go sess.run() + go func() { + defer GinkgoRecover() + sess.run() + }() sendUndecryptablePackets() Consistently(sess.undecryptablePackets).Should(BeEmpty()) Expect(sess.Close(nil)).To(Succeed()) @@ -1220,54 +1172,62 @@ var _ = Describe("Session", func() { }) }) - It("send a handshake event on the handshakeChan when the AEAD changes to secure", func(done Done) { - go sess.run() - aeadChanged <- protocol.EncryptionSecure - Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionSecure})) + It("doesn't do anything when the crypto setup says to decrypt undecryptable packets", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + handshakeChan <- struct{}{} + Consistently(sess.handshakeStatus()).ShouldNot(Receive()) + // make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) Expect(sess.Close(nil)).To(Succeed()) - close(done) + Eventually(done).Should(BeClosed()) }) - It("send a handshake event on the handshakeChan when the AEAD changes to forward-secure", func(done Done) { - go sess.run() - aeadChanged <- protocol.EncryptionForwardSecure - Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionForwardSecure})) - Expect(sess.Close(nil)).To(Succeed()) - close(done) - }) - - It("closes the handshakeChan when the handshake completes", func(done Done) { - go sess.run() - close(aeadChanged) + It("closes the handshakeChan when the handshake completes", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + close(handshakeChan) Eventually(sess.handshakeStatus()).Should(BeClosed()) + // make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) Expect(sess.Close(nil)).To(Succeed()) - close(done) + Eventually(done).Should(BeClosed()) }) - It("passes errors to the handshakeChan", func(done Done) { + It("passes errors to the handshakeChan", func() { testErr := errors.New("handshake error") - go sess.run() - Expect(sess.Close(nil)).To(Succeed()) - Expect(sess.handshakeStatus()).To(Receive(&handshakeEvent{err: testErr})) - close(done) - }) - - It("does not block if an error occurs", func(done Done) { - // this test basically tests that the handshakeChan has a capacity of 3 - // The session needs to run (and close) properly, even if no one is receiving from the handshakeChan - go sess.run() - aeadChanged <- protocol.EncryptionSecure - aeadChanged <- protocol.EncryptionForwardSecure - Expect(sess.Close(nil)).To(Succeed()) - close(done) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err).To(MatchError(testErr)) + close(done) + }() + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close(testErr) + Expect(sess.handshakeStatus()).To(Receive(Equal(testErr))) + Eventually(done).Should(BeClosed()) }) It("process transport parameters received from the peer", func() { paramsChan := make(chan handshake.TransportParameters) sess.paramsChan = paramsChan - _, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - go sess.run() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() params := handshake.TransportParameters{ MaxStreams: 123, IdleTimeout: 90 * time.Second, @@ -1275,12 +1235,14 @@ var _ = Describe("Session", func() { ConnectionFlowControlWindow: 0x5000, OmitConnectionID: true, } + streamManager.EXPECT().UpdateLimits(¶ms) paramsChan <- params Eventually(func() *handshake.TransportParameters { return sess.peerParams }).Should(Equal(¶ms)) - Eventually(func() uint32 { return sess.streamsMap.maxOutgoingStreams }).Should(Equal(uint32(123))) - // Eventually(func() (protocol.ByteCount, error) { return sess.flowControlManager.SendWindowSize(5) }).Should(Equal(protocol.ByteCount(0x5000))) Eventually(func() bool { return sess.packer.omitConnectionID }).Should(BeTrue()) + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) Expect(sess.Close(nil)).To(Succeed()) + Eventually(done).Should(BeClosed()) }) Context("keep-alives", func() { @@ -1297,34 +1259,62 @@ var _ = Describe("Session", func() { sess.config.KeepAlive = true sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - go sess.run() - defer sess.Close(nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() var data []byte Eventually(mconn.written).Should(Receive(&data)) // -12 because of the crypto tag. This should be 7 (the frame id for a ping frame). Expect(data[len(data)-12-1 : len(data)-12]).To(Equal([]byte{0x07})) + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close(nil) + Eventually(done).Should(BeClosed()) }) It("doesn't send a PING packet if keep-alive is disabled", func() { sess.handshakeComplete = true sess.config.KeepAlive = false sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) - go sess.run() - defer sess.Close(nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() Consistently(mconn.written).ShouldNot(Receive()) + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close(nil) + Eventually(done).Should(BeClosed()) }) It("doesn't send a PING if the handshake isn't completed yet", func() { sess.handshakeComplete = false sess.config.KeepAlive = true sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) - go sess.run() - defer sess.Close(nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() Consistently(mconn.written).ShouldNot(Receive()) + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close(nil) + Eventually(done).Should(BeClosed()) }) }) Context("timeouts", func() { + BeforeEach(func() { + streamManager.EXPECT().CloseWithError(gomock.Any()) + }) + It("times out due to no network activity", func(done Done) { sess.handshakeComplete = true sess.lastNetworkActivityTime = time.Now().Add(-time.Hour) @@ -1361,7 +1351,7 @@ var _ = Describe("Session", func() { It("closes the session due to the idle timeout after handshake", func() { sess.config.IdleTimeout = 0 - close(aeadChanged) + close(handshakeChan) errChan := make(chan error) go func() { defer GinkgoRecover() @@ -1384,29 +1374,19 @@ var _ = Describe("Session", func() { }, 0.5) Context("getting streams", func() { - BeforeEach(func() { - sess.processTransportParameters(&handshake.TransportParameters{MaxStreams: 1000}) - }) - It("returns a new stream", func() { + mstr := NewMockStreamI(mockCtrl) + streamManager.EXPECT().GetOrOpenStream(protocol.StreamID(11)).Return(mstr, nil) str, err := sess.GetOrOpenStream(11) Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(11))) + Expect(str).To(Equal(mstr)) }) It("returns a nil-value (not an interface with value nil) for closed streams", func() { - str, err := sess.GetOrOpenStream(9) + strI := Stream(nil) + streamManager.EXPECT().GetOrOpenStream(protocol.StreamID(1337)).Return(strI, nil) + str, err := sess.GetOrOpenStream(1337) Expect(err).ToNot(HaveOccurred()) - str.Close() - str.(*stream).Cancel(nil) - Expect(str.(*stream).Finished()).To(BeTrue()) - err = sess.streamsMap.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil()) - str, err = sess.GetOrOpenStream(9) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeNil()) // make sure that the returned value is a plain nil, not an Stream with value nil _, ok := str.(Stream) Expect(ok).To(BeFalse()) @@ -1414,34 +1394,11 @@ var _ = Describe("Session", func() { // all relevant tests for this are in the streamsMap It("opens streams synchronously", func() { + mstr := NewMockStreamI(mockCtrl) + streamManager.EXPECT().OpenStreamSync().Return(mstr, nil) str, err := sess.OpenStreamSync() Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - }) - }) - - Context("counting streams", func() { - It("errors when too many streams are opened", func() { - for i := 0; i < protocol.MaxIncomingStreams; i++ { - _, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) - Expect(err).NotTo(HaveOccurred()) - } - _, err := sess.GetOrOpenStream(protocol.StreamID(301)) - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) - }) - - It("does not error when many streams are opened and closed", func() { - for i := 2; i <= 1000; i++ { - s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) - Expect(err).NotTo(HaveOccurred()) - Expect(s.Close()).To(Succeed()) - _, sentFin := s.(*stream).GetDataForWriting(1000) // trigger "sending" of the FIN bit - Expect(sentFin).To(BeTrue()) - s.(*stream).CloseRemote(0) - _, err = s.Read([]byte("a")) - Expect(err).To(MatchError(io.EOF)) - sess.streamsMap.DeleteClosedStreams() - } + Expect(str).To(Equal(mstr)) }) }) @@ -1462,30 +1419,6 @@ var _ = Describe("Session", func() { }) }) - // Context("window updates", func() { - // It("gets stream level window updates", func() { - // _, err := sess.GetOrOpenStream(3) - // Expect(err).ToNot(HaveOccurred()) - // err = sess.flowControlManager.AddBytesRead(3, protocol.ReceiveStreamFlowControlWindow) - // Expect(err).NotTo(HaveOccurred()) - // frames := sess.getWindowUpdateFrames() - // Expect(frames).To(HaveLen(1)) - // Expect(frames[0].StreamID).To(Equal(protocol.StreamID(3))) - // Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow * 2)) - // }) - - // It("gets connection level window updates", func() { - // _, err := sess.GetOrOpenStream(5) - // Expect(err).NotTo(HaveOccurred()) - // err = sess.flowControlManager.AddBytesRead(5, protocol.ReceiveConnectionFlowControlWindow) - // Expect(err).NotTo(HaveOccurred()) - // frames := sess.getWindowUpdateFrames() - // Expect(frames).To(HaveLen(1)) - // Expect(frames[0].StreamID).To(Equal(protocol.StreamID(0))) - // Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow * 2)) - // }) - // }) - It("returns the local address", func() { addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} mconn.localAddr = addr @@ -1501,9 +1434,9 @@ var _ = Describe("Session", func() { var _ = Describe("Client Session", func() { var ( - sess *session - mconn *mockConnection - aeadChanged chan<- protocol.EncryptionLevel + sess *session + mconn *mockConnection + handshakeChan chan<- struct{} cryptoSetup *mockCryptoSetup ) @@ -1520,11 +1453,11 @@ var _ = Describe("Client Session", func() { _ *tls.Config, _ *handshake.TransportParameters, _ chan<- handshake.TransportParameters, - aeadChangedP chan<- protocol.EncryptionLevel, + handshakeChanP chan<- struct{}, _ protocol.VersionNumber, _ []protocol.VersionNumber, ) (handshake.CryptoSetup, error) { - aeadChanged = aeadChangedP + handshakeChan = handshakeChanP return cryptoSetup, nil } @@ -1541,13 +1474,28 @@ var _ = Describe("Client Session", func() { ) sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) - Expect(sess.streamsMap.openStreams).To(BeEmpty()) }) AfterEach(func() { newCryptoSetupClient = handshake.NewCryptoSetupClient }) + It("sends a forward-secure packet when the handshake completes", func() { + sess.packer.hasSentPacket = true + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + close(handshakeChan) + Eventually(mconn.written).Should(Receive()) + //make sure the go routine returns + Expect(sess.Close(nil)).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + Context("receiving packets", func() { var hdr *wire.Header @@ -1556,10 +1504,13 @@ var _ = Describe("Client Session", func() { sess.unpacker = &mockUnpacker{} }) - It("passes the diversification nonce to the cryptoSetup", func() { + It("passes the diversification nonce to the crypto setup", func() { + done := make(chan struct{}) go func() { defer GinkgoRecover() - sess.run() + err := sess.run() + Expect(err).ToNot(HaveOccurred()) + close(done) }() hdr.PacketNumber = 5 hdr.DiversificationNonce = []byte("foobar") @@ -1567,16 +1518,7 @@ var _ = Describe("Client Session", func() { Expect(err).ToNot(HaveOccurred()) Eventually(func() []byte { return cryptoSetup.divNonce }).Should(Equal(hdr.DiversificationNonce)) Expect(sess.Close(nil)).To(Succeed()) + Eventually(done).Should(BeClosed()) }) }) - - It("does not block if an error occurs", func(done Done) { - // this test basically tests that the handshakeChan has a capacity of 3 - // The session needs to run (and close) properly, even if no one is receiving from the handshakeChan - go sess.run() - aeadChanged <- protocol.EncryptionSecure - aeadChanged <- protocol.EncryptionForwardSecure - Expect(sess.Close(nil)).To(Succeed()) - close(done) - }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/stream.go b/vendor/github.com/lucas-clemente/quic-go/stream.go index 0e4f34e..6a4c6ce 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream.go @@ -1,88 +1,85 @@ package quic import ( - "context" - "fmt" - "io" "net" "sync" "time" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) +const ( + errorCodeStopping protocol.ApplicationErrorCode = 0 + errorCodeStoppingGQUIC protocol.ApplicationErrorCode = 7 +) + +// The streamSender is notified by the stream about various events. +type streamSender interface { + queueControlFrame(wire.Frame) + onHasWindowUpdate(protocol.StreamID) + onHasStreamData(protocol.StreamID) + onStreamCompleted(protocol.StreamID) +} + +// Each of the both stream halves gets its own uniStreamSender. +// This is necessary in order to keep track when both halves have been completed. +type uniStreamSender struct { + streamSender + onStreamCompletedImpl func() +} + +func (s *uniStreamSender) queueControlFrame(f wire.Frame) { + s.streamSender.queueControlFrame(f) +} + +func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) { + s.streamSender.onHasWindowUpdate(id) +} + +func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { + s.streamSender.onHasStreamData(id) +} + +func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { + s.onStreamCompletedImpl() +} + +var _ streamSender = &uniStreamSender{} + type streamI interface { Stream - - AddStreamFrame(*wire.StreamFrame) error - RegisterRemoteError(error, protocol.ByteCount) error - HasDataForWriting() bool - GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool) - GetWriteOffset() protocol.ByteCount - Finished() bool - Cancel(error) - // methods needed for flow control - GetWindowUpdate() protocol.ByteCount - UpdateSendWindow(protocol.ByteCount) - IsFlowControlBlocked() bool + closeForShutdown(error) + // for receiving + handleStreamFrame(*wire.StreamFrame) error + handleRstStreamFrame(*wire.RstStreamFrame) error + getWindowUpdate() protocol.ByteCount + // for sending + handleStopSendingFrame(*wire.StopSendingFrame) + popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) + handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) } -type cryptoStream interface { - streamI - SetReadOffset(protocol.ByteCount) -} +var _ receiveStreamI = (streamI)(nil) +var _ sendStreamI = (streamI)(nil) // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. type stream struct { - mutex sync.Mutex + receiveStream + sendStream - ctx context.Context - ctxCancel context.CancelFunc + completedMutex sync.Mutex + sender streamSender + receiveStreamCompleted bool + sendStreamCompleted bool - streamID protocol.StreamID - onData func() - // onReset is a callback that should send a RST_STREAM - onReset func(protocol.StreamID, protocol.ByteCount) - - readPosInFrame int - writeOffset protocol.ByteCount - readOffset protocol.ByteCount - - // Once set, the errors must not be changed! - err error - - // cancelled is set when Cancel() is called - cancelled utils.AtomicBool - // finishedReading is set once we read a frame with a FinBit - finishedReading utils.AtomicBool - // finisedWriting is set once Close() is called - finishedWriting utils.AtomicBool - // resetLocally is set if Reset() is called - resetLocally utils.AtomicBool - // resetRemotely is set if RegisterRemoteError() is called - resetRemotely utils.AtomicBool - - frameQueue *streamFrameSorter - readChan chan struct{} - readDeadline time.Time - - dataForWriting []byte - finSent utils.AtomicBool - rstSent utils.AtomicBool - writeChan chan struct{} - writeDeadline time.Time - - flowController flowcontrol.StreamFlowController - version protocol.VersionNumber + version protocol.VersionNumber } var _ Stream = &stream{} -var _ streamI = &stream{} type deadlineError struct{} @@ -92,290 +89,58 @@ func (deadlineError) Timeout() bool { return true } var errDeadline net.Error = &deadlineError{} +type streamCanceledError struct { + error + errorCode protocol.ApplicationErrorCode +} + +func (streamCanceledError) Canceled() bool { return true } +func (e streamCanceledError) ErrorCode() protocol.ApplicationErrorCode { return e.errorCode } + +var _ StreamError = &streamCanceledError{} + // newStream creates a new Stream -func newStream(StreamID protocol.StreamID, - onData func(), - onReset func(protocol.StreamID, protocol.ByteCount), +func newStream(streamID protocol.StreamID, + sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber, ) *stream { - s := &stream{ - onData: onData, - onReset: onReset, - streamID: StreamID, - flowController: flowController, - frameQueue: newStreamFrameSorter(), - readChan: make(chan struct{}, 1), - writeChan: make(chan struct{}, 1), - version: version, + s := &stream{sender: sender} + senderForSendStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.sendStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, } - s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version) + senderForReceiveStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.receiveStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, + } + s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController) return s } -// Read implements io.Reader. It is not thread safe! -func (s *stream) Read(p []byte) (int, error) { - s.mutex.Lock() - err := s.err - s.mutex.Unlock() - if s.cancelled.Get() || s.resetLocally.Get() { - return 0, err - } - if s.finishedReading.Get() { - return 0, io.EOF - } - - bytesRead := 0 - for bytesRead < len(p) { - s.mutex.Lock() - frame := s.frameQueue.Head() - if frame == nil && bytesRead > 0 { - err = s.err - s.mutex.Unlock() - return bytesRead, err - } - - var err error - for { - // Stop waiting on errors - if s.resetLocally.Get() || s.cancelled.Get() { - err = s.err - break - } - - deadline := s.readDeadline - if !deadline.IsZero() && !time.Now().Before(deadline) { - err = errDeadline - break - } - - if frame != nil { - s.readPosInFrame = int(s.readOffset - frame.Offset) - break - } - - s.mutex.Unlock() - if deadline.IsZero() { - <-s.readChan - } else { - select { - case <-s.readChan: - case <-time.After(deadline.Sub(time.Now())): - } - } - s.mutex.Lock() - frame = s.frameQueue.Head() - } - s.mutex.Unlock() - - if err != nil { - return bytesRead, err - } - - m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame) - - if bytesRead > len(p) { - return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) - } - if s.readPosInFrame > int(frame.DataLen()) { - return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen()) - } - copy(p[bytesRead:], frame.Data[s.readPosInFrame:]) - - s.readPosInFrame += m - bytesRead += m - s.readOffset += protocol.ByteCount(m) - - // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream - if !s.resetRemotely.Get() { - s.flowController.AddBytesRead(protocol.ByteCount(m)) - } - s.onData() // so that a possible WINDOW_UPDATE is sent - - if s.readPosInFrame >= int(frame.DataLen()) { - fin := frame.FinBit - s.mutex.Lock() - s.frameQueue.Pop() - s.mutex.Unlock() - if fin { - s.finishedReading.Set(true) - return bytesRead, io.EOF - } - } - } - - return bytesRead, nil +// need to define StreamID() here, since both receiveStream and readStream have a StreamID() +func (s *stream) StreamID() protocol.StreamID { + // the result is same for receiveStream and sendStream + return s.sendStream.StreamID() } -func (s *stream) Write(p []byte) (int, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.resetLocally.Get() || s.err != nil { - return 0, s.err - } - if s.finishedWriting.Get() { - return 0, fmt.Errorf("write on closed stream %d", s.streamID) - } - if len(p) == 0 { - return 0, nil - } - - s.dataForWriting = make([]byte, len(p)) - copy(s.dataForWriting, p) - s.onData() - - var err error - for { - deadline := s.writeDeadline - if !deadline.IsZero() && !time.Now().Before(deadline) { - err = errDeadline - break - } - if s.dataForWriting == nil || s.err != nil { - break - } - - s.mutex.Unlock() - if deadline.IsZero() { - <-s.writeChan - } else { - select { - case <-s.writeChan: - case <-time.After(deadline.Sub(time.Now())): - } - } - s.mutex.Lock() - } - - if err != nil { - return 0, err - } - if s.err != nil { - return len(p) - len(s.dataForWriting), s.err - } - return len(p), nil -} - -func (s *stream) GetWriteOffset() protocol.ByteCount { - return s.writeOffset -} - -// HasDataForWriting says if there's stream available to be dequeued for writing -func (s *stream) HasDataForWriting() bool { - s.mutex.Lock() - hasData := s.err == nil && // nothing should be sent if an error occurred - (len(s.dataForWriting) > 0 || // there is data queued for sending - s.finishedWriting.Get() && !s.finSent.Get()) // if there is no data, but writing finished and the FIN hasn't been sent yet - s.mutex.Unlock() - return hasData -} - -func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { - data, shouldSendFin := s.getDataForWritingImpl(maxBytes) - if shouldSendFin { - s.finSent.Set(true) - } - return data, shouldSendFin -} - -func (s *stream) getDataForWritingImpl(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.err != nil || s.dataForWriting == nil { - return nil, s.finishedWriting.Get() && !s.finSent.Get() - } - - // TODO(#657): Flow control for the crypto stream - if s.streamID != s.version.CryptoStreamID() { - maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) - } - if maxBytes == 0 { - return nil, false - } - - var ret []byte - if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { - ret = s.dataForWriting[:maxBytes] - s.dataForWriting = s.dataForWriting[maxBytes:] - } else { - ret = s.dataForWriting - s.dataForWriting = nil - s.signalWrite() - } - s.writeOffset += protocol.ByteCount(len(ret)) - s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) - return ret, s.finishedWriting.Get() && s.dataForWriting == nil && !s.finSent.Get() -} - -// Close implements io.Closer func (s *stream) Close() error { - s.finishedWriting.Set(true) - s.ctxCancel() - s.onData() - return nil -} - -func (s *stream) shouldSendReset() bool { - if s.rstSent.Get() { - return false - } - return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() -} - -// AddStreamFrame adds a new stream frame -func (s *stream) AddStreamFrame(frame *wire.StreamFrame) error { - maxOffset := frame.Offset + frame.DataLen() - if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil { + if err := s.sendStream.Close(); err != nil { return err } - - s.mutex.Lock() - defer s.mutex.Unlock() - if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData { - return err - } - s.signalRead() - return nil -} - -// signalRead performs a non-blocking send on the readChan -func (s *stream) signalRead() { - select { - case s.readChan <- struct{}{}: - default: - } -} - -// signalRead performs a non-blocking send on the writeChan -func (s *stream) signalWrite() { - select { - case s.writeChan <- struct{}{}: - default: - } -} - -func (s *stream) SetReadDeadline(t time.Time) error { - s.mutex.Lock() - oldDeadline := s.readDeadline - s.readDeadline = t - s.mutex.Unlock() - // if the new deadline is before the currently set deadline, wake up Read() - if t.Before(oldDeadline) { - s.signalRead() - } - return nil -} - -func (s *stream) SetWriteDeadline(t time.Time) error { - s.mutex.Lock() - oldDeadline := s.writeDeadline - s.writeDeadline = t - s.mutex.Unlock() - if t.Before(oldDeadline) { - s.signalWrite() - } + // in gQUIC, we need to send a RST_STREAM with the final offset if CancelRead() was called + s.receiveStream.onClose(s.sendStream.getWriteOffset()) return nil } @@ -385,107 +150,31 @@ func (s *stream) SetDeadline(t time.Time) error { return nil } -// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset -func (s *stream) CloseRemote(offset protocol.ByteCount) { - s.AddStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset}) +// CloseForShutdown closes a stream abruptly. +// It makes Read and Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *stream) closeForShutdown(err error) { + s.sendStream.closeForShutdown(err) + s.receiveStream.closeForShutdown(err) } -// Cancel is called by session to indicate that an error occurred -// The stream should will be closed immediately -func (s *stream) Cancel(err error) { - s.mutex.Lock() - s.cancelled.Set(true) - s.ctxCancel() - // errors must not be changed! - if s.err == nil { - s.err = err - s.signalRead() - s.signalWrite() - } - s.mutex.Unlock() -} - -// resets the stream locally -func (s *stream) Reset(err error) { - if s.resetLocally.Get() { - return - } - s.mutex.Lock() - s.resetLocally.Set(true) - s.ctxCancel() - // errors must not be changed! - if s.err == nil { - s.err = err - s.signalRead() - s.signalWrite() - } - if s.shouldSendReset() { - s.onReset(s.streamID, s.writeOffset) - s.rstSent.Set(true) - } - s.mutex.Unlock() -} - -// resets the stream remotely -func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error { - if s.resetRemotely.Get() { - return nil - } - s.mutex.Lock() - s.resetRemotely.Set(true) - s.ctxCancel() - // errors must not be changed! - if s.err == nil { - s.err = err - s.signalWrite() - } - if err := s.flowController.UpdateHighestReceived(offset, true); err != nil { +func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { + if err := s.receiveStream.handleRstStreamFrame(frame); err != nil { return err } - if s.shouldSendReset() { - s.onReset(s.streamID, s.writeOffset) - s.rstSent.Set(true) + if !s.version.UsesIETFFrameFormat() { + s.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: s.StreamID(), + ErrorCode: frame.ErrorCode, + }) } - s.mutex.Unlock() return nil } -func (s *stream) finishedWriteAndSentFin() bool { - return s.finishedWriting.Get() && s.finSent.Get() -} - -func (s *stream) Finished() bool { - return s.cancelled.Get() || - (s.finishedReading.Get() && s.finishedWriteAndSentFin()) || - (s.resetRemotely.Get() && s.rstSent.Get()) || - (s.finishedReading.Get() && s.rstSent.Get()) || - (s.finishedWriteAndSentFin() && s.resetRemotely.Get()) -} - -func (s *stream) Context() context.Context { - return s.ctx -} - -func (s *stream) StreamID() protocol.StreamID { - return s.streamID -} - -func (s *stream) UpdateSendWindow(n protocol.ByteCount) { - s.flowController.UpdateSendWindow(n) -} - -func (s *stream) IsFlowControlBlocked() bool { - return s.flowController.IsBlocked() -} - -func (s *stream) GetWindowUpdate() protocol.ByteCount { - return s.flowController.GetWindowUpdate() -} - -// SetReadOffset sets the read offset. -// It is only needed for the crypto stream. -// It must not be called concurrently with any other stream methods, especially Read and Write. -func (s *stream) SetReadOffset(offset protocol.ByteCount) { - s.readOffset = offset - s.frameQueue.readPosition = offset +// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed. +// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed. +func (s *stream) checkIfCompleted() { + if s.sendStreamCompleted && s.receiveStreamCompleted { + s.sender.onStreamCompleted(s.StreamID()) + } } diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go index e275fcc..c66815c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go @@ -1,33 +1,35 @@ package quic import ( - "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "sync" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" ) type streamFramer struct { - streamsMap *streamsMap - cryptoStream streamI + streamGetter streamGetter + cryptoStream cryptoStreamI version protocol.VersionNumber - connFlowController flowcontrol.ConnectionFlowController - retransmissionQueue []*wire.StreamFrame - blockedFrameQueue []wire.Frame + + streamQueueMutex sync.Mutex + activeStreams map[protocol.StreamID]struct{} + streamQueue []protocol.StreamID + hasCryptoStreamData bool } func newStreamFramer( - cryptoStream streamI, - streamsMap *streamsMap, - cfc flowcontrol.ConnectionFlowController, + cryptoStream cryptoStreamI, + streamGetter streamGetter, v protocol.VersionNumber, ) *streamFramer { return &streamFramer{ - streamsMap: streamsMap, - cryptoStream: cryptoStream, - connFlowController: cfc, - version: v, + streamGetter: streamGetter, + cryptoStream: cryptoStream, + activeStreams: make(map[protocol.StreamID]struct{}), + version: v, } } @@ -35,114 +37,101 @@ func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) { f.retransmissionQueue = append(f.retransmissionQueue, frame) } +func (f *streamFramer) AddActiveStream(id protocol.StreamID) { + if id == f.version.CryptoStreamID() { // the crypto stream is handled separately + f.streamQueueMutex.Lock() + f.hasCryptoStreamData = true + f.streamQueueMutex.Unlock() + return + } + f.streamQueueMutex.Lock() + if _, ok := f.activeStreams[id]; !ok { + f.streamQueue = append(f.streamQueue, id) + f.activeStreams[id] = struct{}{} + } + f.streamQueueMutex.Unlock() +} + func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame { fs, currentLen := f.maybePopFramesForRetransmission(maxLen) return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...) } -func (f *streamFramer) PopBlockedFrame() wire.Frame { - if len(f.blockedFrameQueue) == 0 { - return nil - } - frame := f.blockedFrameQueue[0] - f.blockedFrameQueue = f.blockedFrameQueue[1:] - return frame -} - func (f *streamFramer) HasFramesForRetransmission() bool { return len(f.retransmissionQueue) > 0 } -func (f *streamFramer) HasCryptoStreamFrame() bool { - return f.cryptoStream.HasDataForWriting() +func (f *streamFramer) HasCryptoStreamData() bool { + f.streamQueueMutex.Lock() + hasCryptoStreamData := f.hasCryptoStreamData + f.streamQueueMutex.Unlock() + return hasCryptoStreamData } -// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { - if !f.HasCryptoStreamFrame() { - return nil - } - frame := &wire.StreamFrame{ - StreamID: f.cryptoStream.StreamID(), - Offset: f.cryptoStream.GetWriteOffset(), - } - frameHeaderBytes, _ := frame.MinLength(f.version) // can never error - frame.Data, frame.FinBit = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes) + f.streamQueueMutex.Lock() + frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen) + f.hasCryptoStreamData = hasMoreData + f.streamQueueMutex.Unlock() return frame } -func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { +func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { for len(f.retransmissionQueue) > 0 { frame := f.retransmissionQueue[0] frame.DataLenPresent = true - frameHeaderLen, _ := frame.MinLength(f.version) // can never error - if currentLen+frameHeaderLen >= maxLen { + frameHeaderLen := frame.MinLength(f.version) // can never error + maxLen := maxTotalLen - currentLen + if frameHeaderLen+frame.DataLen() > maxLen && maxLen < protocol.MinStreamFrameSize { break } - currentLen += frameHeaderLen - - splitFrame := maybeSplitOffFrame(frame, maxLen-currentLen) + splitFrame := maybeSplitOffFrame(frame, maxLen-frameHeaderLen) if splitFrame != nil { // StreamFrame was split res = append(res, splitFrame) - currentLen += splitFrame.DataLen() + currentLen += frameHeaderLen + splitFrame.DataLen() break } f.retransmissionQueue = f.retransmissionQueue[1:] res = append(res, frame) - currentLen += frame.DataLen() + currentLen += frameHeaderLen + frame.DataLen() } return } -func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*wire.StreamFrame) { - frame := &wire.StreamFrame{DataLenPresent: true} +func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame { var currentLen protocol.ByteCount - - fn := func(s streamI) (bool, error) { - if s == nil { - return true, nil + var frames []*wire.StreamFrame + f.streamQueueMutex.Lock() + // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet + numActiveStreams := len(f.streamQueue) + for i := 0; i < numActiveStreams; i++ { + if maxTotalLen-currentLen < protocol.MinStreamFrameSize { + break } - - frame.StreamID = s.StreamID() - frame.Offset = s.GetWriteOffset() - // not perfect, but thread-safe since writeOffset is only written when getting data - frameHeaderBytes, _ := frame.MinLength(f.version) // can never error - if currentLen+frameHeaderBytes > maxBytes { - return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here + id := f.streamQueue[0] + f.streamQueue = f.streamQueue[1:] + str, err := f.streamGetter.GetOrOpenSendStream(id) + if err != nil { // can happen if the stream completed after it said it had data + delete(f.activeStreams, id) + continue } - maxLen := maxBytes - currentLen - frameHeaderBytes - - if s.HasDataForWriting() { - frame.Data, frame.FinBit = s.GetDataForWriting(maxLen) + frame, hasMoreData := str.popStreamFrame(maxTotalLen - currentLen) + if hasMoreData { // put the stream back in the queue (at the end) + f.streamQueue = append(f.streamQueue, id) + } else { // no more data to send. Stream is not active any more + delete(f.activeStreams, id) } - if len(frame.Data) == 0 && !frame.FinBit { - return true, nil + if frame == nil { // can happen if the receiveStream was canceled after it said it had data + continue } - - // Finally, check if we are now FC blocked and should queue a BLOCKED frame - if !frame.FinBit && s.IsFlowControlBlocked() { - f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.StreamBlockedFrame{StreamID: s.StreamID()}) - } - if f.connFlowController.IsBlocked() { - f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{}) - } - - res = append(res, frame) - currentLen += frameHeaderBytes + frame.DataLen() - - if currentLen == maxBytes { - return false, nil - } - - frame = &wire.StreamFrame{DataLenPresent: true} - return true, nil + frames = append(frames, frame) + currentLen += frame.MinLength(f.version) + frame.DataLen() } - - f.streamsMap.RoundRobinIterate(fn) - return + f.streamQueueMutex.Unlock() + return frames } // maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified. diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_framer_test.go b/vendor/github.com/lucas-clemente/quic-go/stream_framer_test.go index 0dabce5..1a4002f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_framer_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream_framer_test.go @@ -2,9 +2,9 @@ package quic import ( "bytes" + "errors" "github.com/golang/mock/gomock" - "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -21,12 +21,13 @@ var _ = Describe("Stream Framer", func() { var ( retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame framer *streamFramer - streamsMap *streamsMap - stream1, stream2 *mocks.MockStreamI - connFC *mocks.MockConnectionFlowController + cryptoStream *MockCryptoStream + stream1, stream2 *MockSendStreamI + streamGetter *MockStreamGetter ) BeforeEach(func() { + streamGetter = NewMockStreamGetter(mockCtrl) retransmittedFrame1 = &wire.StreamFrame{ StreamID: 5, Data: []byte{0x13, 0x37}, @@ -36,25 +37,14 @@ var _ = Describe("Stream Framer", func() { Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } - stream1 = mocks.NewMockStreamI(mockCtrl) + stream1 = NewMockSendStreamI(mockCtrl) stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes() - stream2 = mocks.NewMockStreamI(mockCtrl) + stream2 = NewMockSendStreamI(mockCtrl) stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() - - streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) - streamsMap.putStream(stream1) - streamsMap.putStream(stream2) - - connFC = mocks.NewMockConnectionFlowController(mockCtrl) - framer = newStreamFramer(nil, streamsMap, connFC, versionGQUICFrames) + cryptoStream = NewMockCryptoStream(mockCtrl) + framer = newStreamFramer(cryptoStream, streamGetter, versionGQUICFrames) }) - setNoData := func(str *mocks.MockStreamI) { - str.EXPECT().HasDataForWriting().Return(false).AnyTimes() - str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, false).AnyTimes() - str.EXPECT().GetWriteOffset().AnyTimes() - } - It("says if it has retransmissions", func() { Expect(framer.HasFramesForRetransmission()).To(BeFalse()) framer.AddFrameForRetransmission(retransmittedFrame1) @@ -62,119 +52,220 @@ var _ = Describe("Stream Framer", func() { }) It("sets the DataLenPresent for dequeued retransmitted frames", func() { - setNoData(stream1) - setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) fs := framer.PopStreamFrames(protocol.MaxByteCount) Expect(fs).To(HaveLen(1)) Expect(fs[0].DataLenPresent).To(BeTrue()) }) - It("sets the DataLenPresent for dequeued normal frames", func() { - connFC.EXPECT().IsBlocked() - setNoData(stream2) - stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().IsFlowControlBlocked() - fs := framer.PopStreamFrames(protocol.MaxByteCount) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].DataLenPresent).To(BeTrue()) + Context("handling the crypto stream", func() { + It("says if it has crypto stream data", func() { + Expect(framer.HasCryptoStreamData()).To(BeFalse()) + framer.AddActiveStream(framer.version.CryptoStreamID()) + Expect(framer.HasCryptoStreamData()).To(BeTrue()) + }) + + It("says that it doesn't have crypto stream data after popping all data", func() { + streamID := framer.version.CryptoStreamID() + f := &wire.StreamFrame{ + StreamID: streamID, + Data: []byte("foobar"), + } + cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, false) + framer.AddActiveStream(streamID) + Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f)) + Expect(framer.HasCryptoStreamData()).To(BeFalse()) + }) + + It("says that it has more crypto stream data if not all data was popped", func() { + streamID := framer.version.CryptoStreamID() + f := &wire.StreamFrame{ + StreamID: streamID, + Data: []byte("foobar"), + } + cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, true) + framer.AddActiveStream(streamID) + Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f)) + Expect(framer.HasCryptoStreamData()).To(BeTrue()) + }) }) Context("Popping", func() { - BeforeEach(func() { - // nothing is blocked here - connFC.EXPECT().IsBlocked().AnyTimes() - stream1.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes() - stream2.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes() - }) - It("returns nil when popping an empty framer", func() { - setNoData(stream1) - setNoData(stream2) Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) }) It("pops frames for retransmission", func() { - setNoData(stream1) - setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame2) fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(2)) - Expect(fs[0]).To(Equal(retransmittedFrame1)) - Expect(fs[1]).To(Equal(retransmittedFrame2)) + Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, retransmittedFrame2})) + // make sure the frames are actually removed, and not returned a second time Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) }) - It("returns normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - setNoData(stream2) - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) - Expect(fs[0].Data).To(Equal([]byte("foobar"))) - Expect(fs[0].FinBit).To(BeFalse()) - }) - - It("returns multiple normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz"), false) - stream2.EXPECT().HasDataForWriting().Return(true) - stream2.EXPECT().GetWriteOffset() - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(2)) - // Swap if we dequeued in other order - if fs[0].StreamID != stream1.StreamID() { - fs[0], fs[1] = fs[1], fs[0] - } - Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) - Expect(fs[0].Data).To(Equal([]byte("foobar"))) - Expect(fs[1].StreamID).To(Equal(stream2.StreamID())) - Expect(fs[1].Data).To(Equal([]byte("foobaz"))) - }) - - It("returns retransmission frames before normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - setNoData(stream2) - framer.AddFrameForRetransmission(retransmittedFrame1) - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(2)) - Expect(fs[0]).To(Equal(retransmittedFrame1)) - Expect(fs[1].StreamID).To(Equal(stream1.StreamID())) - }) - - It("does not pop empty frames", func() { - stream1.EXPECT().HasDataForWriting().Return(false) - stream1.EXPECT().GetWriteOffset() - setNoData(stream2) - fs := framer.PopStreamFrames(5) + It("doesn't pop frames for retransmission, if the size would be smaller than the minimum STREAM frame size", func() { + framer.AddFrameForRetransmission(&wire.StreamFrame{ + StreamID: id1, + Data: bytes.Repeat([]byte{'a'}, int(protocol.MinStreamFrameSize)), + }) + fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) Expect(fs).To(BeEmpty()) }) - It("uses the round-robin scheduling", func() { - streamFrameHeaderLen := protocol.ByteCount(4) - stream1.EXPECT().GetDataForWriting(10-streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen)), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10-streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen)), false) - stream2.EXPECT().HasDataForWriting().Return(true) - stream2.EXPECT().GetWriteOffset() - fs := framer.PopStreamFrames(10) - Expect(fs).To(HaveLen(1)) - // it doesn't matter here if this data is from stream1 or from stream2... - firstStreamID := fs[0].StreamID - fs = framer.PopStreamFrames(10) - Expect(fs).To(HaveLen(1)) - // ... but the data popped this time has to be from the other stream - Expect(fs[0].StreamID).ToNot(Equal(firstStreamID)) + It("pops frames for retransmission, even if the remaining space in the packet is too small, if the frame doesn't need to be split", func() { + framer.AddFrameForRetransmission(retransmittedFrame1) + fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) + Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1})) + }) + + It("pops frames for retransmission, if the remaining size is the miniumum STREAM frame size", func() { + framer.AddFrameForRetransmission(retransmittedFrame1) + fs := framer.PopStreamFrames(protocol.MinStreamFrameSize) + Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1})) + }) + + It("returns normal frames", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + f := &wire.StreamFrame{ + StreamID: id1, + Data: []byte("foobar"), + Offset: 42, + } + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) + framer.AddActiveStream(id1) + fs := framer.PopStreamFrames(1000) + Expect(fs).To(Equal([]*wire.StreamFrame{f})) + }) + + It("skips a stream that was reported active, but was completed shortly after", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(nil, errors.New("stream was already deleted")) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f := &wire.StreamFrame{ + StreamID: id2, + Data: []byte("foobar"), + } + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f})) + }) + + It("skips a stream that was reported active, but doesn't have any data", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f := &wire.StreamFrame{ + StreamID: id2, + Data: []byte("foobar"), + } + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(nil, false) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f})) + }) + + It("pops from a stream multiple times, if it has enough data", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) + f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} + f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) + framer.AddActiveStream(id1) // only add it once + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f1})) + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) + // no further calls to popStreamFrame, after popStreamFrame said there's no more data + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(BeNil()) + }) + + It("re-queues a stream at the end, if it has enough data", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} + f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} + f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f11, true) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f12, false) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) + framer.AddActiveStream(id1) // only add it once + framer.AddActiveStream(id2) + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f11})) // first a frame from stream 1 + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) // then a frame from stream 2 + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f12})) // then another frame from stream 1 + }) + + It("only dequeues data from each stream once per packet", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} + f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} + // both streams have more data, and will be re-queued + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, true) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f1, f2})) + }) + + It("returns multiple normal frames in the order they were reported active", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) + f1 := &wire.StreamFrame{Data: []byte("foobar")} + f2 := &wire.StreamFrame{Data: []byte("foobaz")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) + framer.AddActiveStream(id2) + framer.AddActiveStream(id1) + Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f2, f1})) + }) + + It("only asks a stream for data once, even if it was reported active multiple times", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + f := &wire.StreamFrame{Data: []byte("foobar")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function + framer.AddActiveStream(id1) + framer.AddActiveStream(id1) + Expect(framer.PopStreamFrames(1000)).To(HaveLen(1)) + }) + + It("returns retransmission frames before normal frames", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + framer.AddActiveStream(id1) + f1 := &wire.StreamFrame{Data: []byte("foobar")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false) + framer.AddFrameForRetransmission(retransmittedFrame1) + fs := framer.PopStreamFrames(1000) + Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, f1})) + }) + + It("does not pop empty frames", func() { + fs := framer.PopStreamFrames(500) + Expect(fs).To(BeEmpty()) + }) + + It("pops frames that have the minimum size", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}, false) + framer.AddActiveStream(id1) + framer.PopStreamFrames(protocol.MinStreamFrameSize) + }) + + It("does not pop frames smaller than the mimimum size", func() { + // don't expect a call to PopStreamFrame() + framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) + }) + + It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + // pop a frame such that the remaining size is one byte less than the minimum STREAM frame size + f := &wire.StreamFrame{ + StreamID: id1, + Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)), + } + stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f, false) + framer.AddActiveStream(id1) + fs := framer.PopStreamFrames(500) + Expect(fs).To(Equal([]*wire.StreamFrame{f})) }) Context("splitting of frames", func() { @@ -212,139 +303,28 @@ var _ = Describe("Stream Framer", func() { }) It("splits a frame", func() { - setNoData(stream1) - setNoData(stream2) - framer.AddFrameForRetransmission(retransmittedFrame2) - origlen := retransmittedFrame2.DataLen() - fs := framer.PopStreamFrames(6) + frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, 600)} + framer.AddFrameForRetransmission(frame) + fs := framer.PopStreamFrames(500) Expect(fs).To(HaveLen(1)) - minLength, _ := fs[0].MinLength(framer.version) - Expect(minLength + fs[0].DataLen()).To(Equal(protocol.ByteCount(6))) - Expect(framer.retransmissionQueue[0].Data).To(HaveLen(int(origlen - fs[0].DataLen()))) + minLength := fs[0].MinLength(framer.version) + Expect(minLength + fs[0].DataLen()).To(Equal(protocol.ByteCount(500))) + Expect(framer.retransmissionQueue[0].Data).To(HaveLen(int(600 - fs[0].DataLen()))) Expect(framer.retransmissionQueue[0].Offset).To(Equal(fs[0].DataLen())) }) - It("never returns an empty stream frame", func() { - // this one frame will be split off from again and again in this test. Therefore, it has to be large enough (checked again at the end) - origFrame := &wire.StreamFrame{ - StreamID: 5, - Offset: 1, - FinBit: false, - Data: bytes.Repeat([]byte{'f'}, 30*30), - } - framer.AddFrameForRetransmission(origFrame) - - minFrameDataLen := protocol.MaxPacketSize - - for i := 0; i < 30; i++ { - frames, currentLen := framer.maybePopFramesForRetransmission(protocol.ByteCount(i)) - if len(frames) == 0 { - Expect(currentLen).To(BeZero()) - } else { - Expect(frames).To(HaveLen(1)) - Expect(currentLen).ToNot(BeZero()) - dataLen := frames[0].DataLen() - Expect(dataLen).ToNot(BeZero()) - if dataLen < minFrameDataLen { - minFrameDataLen = dataLen - } - } - } - Expect(framer.retransmissionQueue).To(HaveLen(1)) // check that origFrame was large enough for this test and didn't get used up completely - Expect(minFrameDataLen).To(Equal(protocol.ByteCount(1))) - }) - It("only removes a frame from the framer after returning all split parts", func() { - setNoData(stream1) - setNoData(stream2) - framer.AddFrameForRetransmission(retransmittedFrame2) - fs := framer.PopStreamFrames(6) + frameHeaderLen := protocol.ByteCount(4) + frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, int(501-frameHeaderLen))} + framer.AddFrameForRetransmission(frame) + fs := framer.PopStreamFrames(500) Expect(fs).To(HaveLen(1)) Expect(framer.retransmissionQueue).ToNot(BeEmpty()) - fs = framer.PopStreamFrames(1000) + fs = framer.PopStreamFrames(500) Expect(fs).To(HaveLen(1)) + Expect(fs[0].DataLen()).To(BeEquivalentTo(1)) Expect(framer.retransmissionQueue).To(BeEmpty()) }) }) - - Context("sending FINs", func() { - It("sends FINs when streams are closed", func() { - offset := protocol.ByteCount(42) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, true) - stream1.EXPECT().GetWriteOffset().Return(offset) - setNoData(stream2) - - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) - Expect(fs[0].Offset).To(Equal(offset)) - Expect(fs[0].FinBit).To(BeTrue()) - Expect(fs[0].Data).To(BeEmpty()) - }) - - It("bundles FINs with data", func() { - offset := protocol.ByteCount(42) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), true) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset().Return(offset) - setNoData(stream2) - - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) - Expect(fs[0].Data).To(Equal([]byte("foobar"))) - Expect(fs[0].FinBit).To(BeTrue()) - }) - }) - }) - - Context("BLOCKED frames", func() { - It("Pop returns nil if no frame is queued", func() { - Expect(framer.PopBlockedFrame()).To(BeNil()) - }) - - It("queues and pops BLOCKED frames for individually blocked streams", func() { - connFC.EXPECT().IsBlocked() - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().IsFlowControlBlocked().Return(true) - setNoData(stream2) - frames := framer.PopStreamFrames(1000) - Expect(frames).To(HaveLen(1)) - f := framer.PopBlockedFrame() - Expect(f).To(BeAssignableToTypeOf(&wire.StreamBlockedFrame{})) - bf := f.(*wire.StreamBlockedFrame) - Expect(bf.StreamID).To(Equal(stream1.StreamID())) - Expect(framer.PopBlockedFrame()).To(BeNil()) - }) - - It("does not queue a stream-level BLOCKED frame after sending the FinBit frame", func() { - connFC.EXPECT().IsBlocked() - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), true) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - setNoData(stream2) - frames := framer.PopStreamFrames(1000) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].FinBit).To(BeTrue()) - Expect(frames[0].DataLen()).To(Equal(protocol.ByteCount(3))) - blockedFrame := framer.PopBlockedFrame() - Expect(blockedFrame).To(BeNil()) - }) - - It("queues and pops BLOCKED frames for connection blocked streams", func() { - connFC.EXPECT().IsBlocked().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().IsFlowControlBlocked().Return(false) - setNoData(stream2) - framer.PopStreamFrames(1000) - f := framer.PopBlockedFrame() - Expect(f).To(BeAssignableToTypeOf(&wire.BlockedFrame{})) - Expect(framer.PopBlockedFrame()).To(BeNil()) - }) }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_test.go b/vendor/github.com/lucas-clemente/quic-go/stream_test.go index ce2ab48..e35af10 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream_test.go @@ -1,66 +1,45 @@ package quic import ( - "errors" "io" - "runtime" + "os" "strconv" "time" - "os" - - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/onsi/gomega/gbytes" ) +// in the tests for the stream deadlines we set a deadline +// and wait to make an assertion when Read / Write was unblocked +// on the CIs, the timing is a lot less precise, so scale every duration by this factor +func scaleDuration(t time.Duration) time.Duration { + scaleFactor := 1 + if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set + scaleFactor = f + } + Expect(scaleFactor).ToNot(BeZero()) + return time.Duration(scaleFactor) * t +} + var _ = Describe("Stream", func() { const streamID protocol.StreamID = 1337 var ( str *stream strWithTimeout io.ReadWriter // str wrapped with gbytes.Timeout{Reader,Writer} - onDataCalled bool - - resetCalled bool - resetCalledForStream protocol.StreamID - resetCalledAtOffset protocol.ByteCount - - mockFC *mocks.MockStreamFlowController + mockFC *mocks.MockStreamFlowController + mockSender *MockStreamSender ) - // in the tests for the stream deadlines we set a deadline - // and wait to make an assertion when Read / Write was unblocked - // on the CIs, the timing is a lot less precise, so scale every duration by this factor - scaleDuration := func(t time.Duration) time.Duration { - scaleFactor := 1 - if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set - scaleFactor = f - } - Expect(scaleFactor).ToNot(BeZero()) - return time.Duration(scaleFactor) * t - } - - onData := func() { - onDataCalled = true - } - - onReset := func(id protocol.StreamID, offset protocol.ByteCount) { - resetCalled = true - resetCalledForStream = id - resetCalledAtOffset = offset - } - BeforeEach(func() { - onDataCalled = false - resetCalled = false + mockSender = NewMockStreamSender(mockCtrl) mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newStream(streamID, onData, onReset, mockFC, protocol.VersionWhatever) + str = newStream(streamID, mockSender, mockFC, protocol.VersionWhatever) timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = struct { @@ -76,1027 +55,155 @@ var _ = Describe("Stream", func() { Expect(str.StreamID()).To(Equal(protocol.StreamID(1337))) }) - Context("reading", func() { - It("reads a single StreamFrame", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) - - It("reads a single StreamFrame in multiple goes", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 2) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - Expect(b).To(Equal([]byte{0xDE, 0xAD})) - n, err = strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - Expect(b).To(Equal([]byte{0xBE, 0xEF})) - }) - - It("reads all data available", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - frame2 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - err := str.AddStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.AddStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00})) - }) - - It("assembles multiple StreamFrames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - frame2 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - err := str.AddStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.AddStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) - - It("waits until data is available", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - go func() { - defer GinkgoRecover() - frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} - time.Sleep(10 * time.Millisecond) - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - }() - b := make([]byte, 2) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - }) - - It("handles StreamFrames in wrong order", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - frame2 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - err := str.AddStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.AddStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) - - It("ignores duplicate StreamFrames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - frame2 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0x13, 0x37}, - } - frame3 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - } - err := str.AddStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.AddStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - err = str.AddStreamFrame(&frame3) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - }) - - It("doesn't rejects a StreamFrames with an overlapping data range", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte("foob"), - } - frame2 := wire.StreamFrame{ - Offset: 2, - Data: []byte("obar"), - } - err := str.AddStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.AddStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) - - It("calls onData", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - str.AddStreamFrame(&frame) - b := make([]byte, 4) - _, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(onDataCalled).To(BeTrue()) - }) - - It("sets the read offset", func() { - str.SetReadOffset(0x42) - Expect(str.readOffset).To(Equal(protocol.ByteCount(0x42))) - Expect(str.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42))) - }) - - Context("deadlines", func() { - It("the deadline error has the right net.Error properties", func() { - Expect(errDeadline.Temporary()).To(BeTrue()) - Expect(errDeadline.Timeout()).To(BeTrue()) + // need some stream cancelation tests here, since gQUIC doesn't cleanly separate the two stream halves + Context("stream cancelations", func() { + Context("for gQUIC", func() { + BeforeEach(func() { + str.version = versionGQUICFrames + str.receiveStream.version = versionGQUICFrames + str.sendStream.version = versionGQUICFrames }) - It("returns an error when Read is called after the deadline", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() - f := &wire.StreamFrame{Data: []byte("foobar")} - err := str.AddStreamFrame(f) - Expect(err).ToNot(HaveOccurred()) - str.SetReadDeadline(time.Now().Add(-time.Second)) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - - It("unblocks after the deadline", func() { - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetReadDeadline(deadline) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(10*time.Millisecond))) - }) - - It("doesn't unblock if the deadline is changed before the first one expires", func() { - deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) - str.SetReadDeadline(deadline1) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetReadDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline1)) - }() - runtime.Gosched() - b := make([]byte, 10) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) - }) - - It("unblocks earlier, when a new deadline is set", func() { - deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(10 * time.Millisecond)) - str.SetReadDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline2)) - }() - str.SetReadDeadline(deadline1) - runtime.Gosched() - b := make([]byte, 10) - _, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond))) - }) - - It("sets a read deadline, when SetDeadline is called", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() - f := &wire.StreamFrame{Data: []byte("foobar")} - err := str.AddStreamFrame(f) - Expect(err).ToNot(HaveOccurred()) - str.SetDeadline(time.Now().Add(-time.Second)) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - }) - - Context("closing", func() { - Context("with FIN bit", func() { - It("returns EOFs", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - FinBit: true, - } - str.AddStreamFrame(&frame) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) + It("unblocks Write when receiving a RST_STREAM frame with non-zero error code", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 1000, + ErrorCode: errorCodeStoppingGQUIC, }) - - It("handles out-of-order frames", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - frame1 := wire.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - FinBit: true, - } - frame2 := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - err := str.AddStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.AddStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - - It("returns EOFs with partial read", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - FinBit: true, - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(2)) - Expect(b[:n]).To(Equal([]byte{0xDE, 0xAD})) - }) - - It("handles immediate FINs", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{}, - FinBit: true, - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - }) - - Context("when CloseRemote is called", func() { - It("closes", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - str.CloseRemote(0) - b := make([]byte, 8) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - - It("doesn't cancel the context", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - str.CloseRemote(0) - Expect(str.Context().Done()).ToNot(BeClosed()) - }) - }) - }) - - Context("cancelling the stream", func() { - testErr := errors.New("test error") - - It("immediately returns all reads", func() { - done := make(chan struct{}) - b := make([]byte, 4) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.Cancel(testErr) - Eventually(done).Should(BeClosed()) - }) - - It("errors for all following reads", func() { - str.Cancel(testErr) - b := make([]byte, 1) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - }) - - It("cancels the context", func() { - Expect(str.Context().Done()).ToNot(BeClosed()) - str.Cancel(testErr) - Expect(str.Context().Done()).To(BeClosed()) - }) - }) - }) - - Context("resetting", func() { - testErr := errors.New("testErr") - - Context("reset by the peer", func() { - It("continues reading after receiving a remote error", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + mockSender.EXPECT().onStreamCompleted(streamID) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) + str.writeOffset = 1000 + f := &wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 6, + ErrorCode: 123, } - str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr, 10) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - }) - - It("reads a delayed StreamFrame that arrives after receiving a remote error", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - str.RegisterRemoteError(testErr, 4) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - }) - - It("returns the error if reading past the offset of the frame received", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr, 8) - b := make([]byte, 10) - n, err := strWithTimeout.Read(b) - Expect(b[0:4]).To(Equal(frame.Data)) - Expect(err).To(MatchError(testErr)) - Expect(n).To(Equal(4)) - }) - - It("returns an EOF when reading past the offset, if the stream received a finbit", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - FinBit: true, - } - str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr, 8) - b := make([]byte, 10) - n, err := strWithTimeout.Read(b) - Expect(b[:4]).To(Equal(frame.Data)) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - }) - - It("continues reading in small chunks after receiving a remote error", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - FinBit: true, - } - str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr, 4) - b := make([]byte, 3) - _, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(b).To(Equal([]byte{0xde, 0xad, 0xbe})) - b = make([]byte, 3) - n, err := strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(b[:1]).To(Equal([]byte{0xef})) - Expect(n).To(Equal(1)) - }) - - It("doesn't inform the flow controller about bytes read after receiving the remote error", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) - // No AddBytesRead() - frame := wire.StreamFrame{ - Offset: 0, - StreamID: 5, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr, 10) - b := make([]byte, 3) - _, err := strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - }) - - It("stops writing after receiving a remote error", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - str.RegisterRemoteError(testErr, 10) - Eventually(done).Should(BeClosed()) - }) - - It("returns how much was written when recieving a remote error", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(testErr)) - Expect(n).To(Equal(4)) - close(done) - }() - - Eventually(func() []byte { data, _ := str.GetDataForWriting(4); return data }).ShouldNot(BeEmpty()) - str.RegisterRemoteError(testErr, 10) - Eventually(done).Should(BeClosed()) - }) - - It("calls onReset when receiving a remote error", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - done := make(chan struct{}) - str.writeOffset = 0x1000 - go func() { - _, _ = strWithTimeout.Write([]byte("foobar")) - close(done) - }() - str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeTrue()) - Expect(resetCalledForStream).To(Equal(protocol.StreamID(1337))) - Expect(resetCalledAtOffset).To(Equal(protocol.ByteCount(0x1000))) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't call onReset if it already sent a FIN", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - str.Close() - _, sentFin := str.GetDataForWriting(1000) - Expect(sentFin).To(BeTrue()) - str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeFalse()) - }) - - It("doesn't call onReset if the stream was reset locally before", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - str.Reset(testErr) - Expect(resetCalled).To(BeTrue()) - resetCalled = false - str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeFalse()) - }) - - It("doesn't call onReset twice, when it gets two remote errors", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeTrue()) - resetCalled = false - str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeFalse()) - }) - }) - - Context("reset locally", func() { - It("stops writing", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.Reset(testErr) - Expect(str.GetDataForWriting(6)).To(BeNil()) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't allow further writes", func() { - str.Reset(testErr) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - Expect(str.GetDataForWriting(6)).To(BeNil()) - }) - - It("stops reading", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - b := make([]byte, 4) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - str.Reset(testErr) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't allow further reads", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) - str.AddStreamFrame(&wire.StreamFrame{ - Data: []byte("foobar"), - }) - str.Reset(testErr) - b := make([]byte, 6) - n, err := strWithTimeout.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - }) - - It("calls onReset", func() { - str.writeOffset = 0x1000 - str.Reset(testErr) - Expect(resetCalled).To(BeTrue()) - Expect(resetCalledForStream).To(Equal(protocol.StreamID(1337))) - Expect(resetCalledAtOffset).To(Equal(protocol.ByteCount(0x1000))) - }) - - It("doesn't call onReset if it already sent a FIN", func() { - str.Close() - _, sentFin := str.GetDataForWriting(1000) - Expect(sentFin).To(BeTrue()) - str.Reset(testErr) - Expect(resetCalled).To(BeFalse()) - }) - - It("doesn't call onReset if the stream was reset remotely before", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - str.RegisterRemoteError(testErr, 0) - Expect(resetCalled).To(BeTrue()) - resetCalled = false - str.Reset(testErr) - Expect(resetCalled).To(BeFalse()) - }) - - It("doesn't call onReset twice", func() { - str.Reset(testErr) - Expect(resetCalled).To(BeTrue()) - resetCalled = false - str.Reset(testErr) - Expect(resetCalled).To(BeFalse()) - }) - - It("cancels the context", func() { - Expect(str.Context().Done()).ToNot(BeClosed()) - str.Reset(testErr) - Expect(str.Context().Done()).To(BeClosed()) - }) - }) - }) - - Context("writing", func() { - It("writes and gets all data at once", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - close(done) - }() - Eventually(func() []byte { - str.mutex.Lock() - defer str.mutex.Unlock() - return str.dataForWriting - }).Should(Equal([]byte("foobar"))) - Consistently(done).ShouldNot(BeClosed()) - Expect(onDataCalled).To(BeTrue()) - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(1000) - Expect(data).To(Equal([]byte("foobar"))) - Expect(sendFin).To(BeFalse()) - Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) - Expect(str.dataForWriting).To(BeNil()) - Eventually(done).Should(BeClosed()) - }) - - It("writes and gets data in two turns", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - close(done) - }() - Eventually(func() []byte { - str.mutex.Lock() - defer str.mutex.Unlock() - return str.dataForWriting - }).Should(Equal([]byte("foobar"))) - Consistently(done).ShouldNot(BeClosed()) - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(3) - Expect(data).To(Equal([]byte("foo"))) - Expect(sendFin).To(BeFalse()) - Expect(str.writeOffset).To(Equal(protocol.ByteCount(3))) - Expect(str.dataForWriting).ToNot(BeNil()) - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin = str.GetDataForWriting(3) - Expect(data).To(Equal([]byte("bar"))) - Expect(sendFin).To(BeFalse()) - Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) - Expect(str.dataForWriting).To(BeNil()) - Expect(str.HasDataForWriting()).To(BeFalse()) - Eventually(done).Should(BeClosed()) - }) - - It("getDataForWriting returns nil if no data is available", func() { - Expect(str.GetDataForWriting(1000)).To(BeNil()) - }) - - It("copies the slice while writing", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) - s := []byte("foo") - go func() { - defer GinkgoRecover() - n, err := strWithTimeout.Write(s) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - }() - Eventually(func() bool { return str.HasDataForWriting() }).Should(BeTrue()) - s[0] = 'v' - Expect(str.GetDataForWriting(3)).To(Equal([]byte("foo"))) - }) - - It("returns when given a nil input", func() { - n, err := strWithTimeout.Write(nil) - Expect(n).To(BeZero()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("returns when given an empty slice", func() { - n, err := strWithTimeout.Write([]byte("")) - Expect(n).To(BeZero()) - Expect(err).ToNot(HaveOccurred()) - }) - - Context("deadlines", func() { - It("returns an error when Write is called after the deadline", func() { - str.SetWriteDeadline(time.Now().Add(-time.Second)) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - - It("unblocks after the deadline", func() { - deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) - str.SetWriteDeadline(deadline) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) - }) - - It("doesn't unblock if the deadline is changed before the first one expires", func() { - deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) - str.SetWriteDeadline(deadline1) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(20 * time.Millisecond)) - str.SetWriteDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline1)) - }() - runtime.Gosched() - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) - }) - - It("unblocks earlier, when a new deadline is set", func() { - deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) - deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) - go func() { - defer GinkgoRecover() - time.Sleep(scaleDuration(10 * time.Millisecond)) - str.SetWriteDeadline(deadline2) - // make sure that this was actually execute before the deadline expires - Expect(time.Now()).To(BeTemporally("<", deadline2)) - }() - str.SetWriteDeadline(deadline1) - runtime.Gosched() - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(errDeadline)) - Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) - }) - - It("sets a read deadline, when SetDeadline is called", func() { - str.SetDeadline(time.Now().Add(-time.Second)) - n, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(errDeadline)) - Expect(n).To(BeZero()) - }) - }) - - Context("closing", func() { - It("doesn't allow writes after it has been closed", func() { - str.Close() - _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError("write on closed stream 1337")) - }) - - It("allows FIN", func() { - str.Close() - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(1000) - Expect(data).To(BeEmpty()) - Expect(sendFin).To(BeTrue()) - }) - - It("does not allow FIN when there's still data", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) - mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) - str.dataForWriting = []byte("foobar") - str.Close() - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(3) - Expect(data).To(Equal([]byte("foo"))) - Expect(sendFin).To(BeFalse()) - data, sendFin = str.GetDataForWriting(3) - Expect(data).To(Equal([]byte("bar"))) - Expect(sendFin).To(BeTrue()) - }) - - It("does not allow FIN when the stream is not closed", func() { - Expect(str.HasDataForWriting()).To(BeFalse()) - _, sendFin := str.GetDataForWriting(3) - Expect(sendFin).To(BeFalse()) - }) - - It("does not allow FIN after an error", func() { - str.Cancel(errors.New("test")) - Expect(str.HasDataForWriting()).To(BeFalse()) - data, sendFin := str.GetDataForWriting(1000) - Expect(data).To(BeEmpty()) - Expect(sendFin).To(BeFalse()) - }) - - It("does not allow FIN twice", func() { - str.Close() - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(1000) - Expect(data).To(BeEmpty()) - Expect(sendFin).To(BeTrue()) - Expect(str.HasDataForWriting()).To(BeFalse()) - data, sendFin = str.GetDataForWriting(1000) - Expect(data).To(BeEmpty()) - Expect(sendFin).To(BeFalse()) - }) - }) - - Context("cancelling", func() { - testErr := errors.New("test") - - It("returns errors when the stream is cancelled", func() { - str.Cancel(testErr) - n, err := strWithTimeout.Write([]byte("foo")) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - }) - - It("doesn't get data for writing if an error occurred", func() { + writeReturned := make(chan struct{}) go func() { defer GinkgoRecover() _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError("Stream 1337 was reset with error code 123")) + Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) + Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) + Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123))) + close(writeReturned) }() - Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil()) - Expect(str.HasDataForWriting()).To(BeTrue()) - str.Cancel(testErr) - data, sendFin := str.GetDataForWriting(6) - Expect(data).To(BeNil()) - Expect(sendFin).To(BeFalse()) - Expect(str.HasDataForWriting()).To(BeFalse()) + Consistently(writeReturned).ShouldNot(BeClosed()) + err := str.handleRstStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + Eventually(writeReturned).Should(BeClosed()) + }) + + It("unblocks Write when receiving a RST_STREAM frame with error code 0", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 1000, + ErrorCode: errorCodeStoppingGQUIC, + }) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) + str.writeOffset = 1000 + f := &wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 6, + ErrorCode: 0, + } + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError("Stream 1337 was reset with error code 0")) + Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) + Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) + Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(0))) + close(writeReturned) + }() + Consistently(writeReturned).ShouldNot(BeClosed()) + err := str.handleRstStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + Eventually(writeReturned).Should(BeClosed()) + }) + + It("sends a RST_STREAM with error code 0, after the stream is closed", func() { + str.version = versionGQUICFrames + mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + mockFC.EXPECT().IsBlocked() + err := str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(writeReturned) + }() + Eventually(func() *wire.StreamFrame { + frame, _ := str.popStreamFrame(1000) + return frame + }).ShouldNot(BeNil()) + Eventually(writeReturned).Should(BeClosed()) + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 6, + ErrorCode: 0, + }) + Expect(str.Close()).To(Succeed()) + }) + }) + + Context("for IETF QUIC", func() { + It("doesn't queue a RST_STREAM after closing the stream", func() { // this is what it does for gQUIC + mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 1234, + }) + mockSender.EXPECT().onHasStreamData(streamID) + err := str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) }) }) }) - It("errors when a StreamFrames causes a flow control violation", func() { - testErr := errors.New("flow control violation") - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr) - frame := wire.StreamFrame{ - Offset: 2, - Data: []byte("foobar"), - } - err := str.AddStreamFrame(&frame) - Expect(err).To(MatchError(testErr)) - }) + Context("deadlines", func() { + It("sets a write deadline, when SetDeadline is called", func() { + str.SetDeadline(time.Now().Add(-time.Second)) + n, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) - Context("closing", func() { - testErr := errors.New("testErr") - - finishReading := func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - err := str.AddStreamFrame(&wire.StreamFrame{FinBit: true}) + It("sets a read deadline, when SetDeadline is called", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() + f := &wire.StreamFrame{Data: []byte("foobar")} + err := str.handleStreamFrame(f) Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 100) - _, err = strWithTimeout.Read(b) - Expect(err).To(MatchError(io.EOF)) - } - - It("is finished after it is canceled", func() { - str.Cancel(testErr) - Expect(str.Finished()).To(BeTrue()) - }) - - It("is not finished if it is only closed for writing", func() { - str.Close() - _, sentFin := str.GetDataForWriting(1000) - Expect(sentFin).To(BeTrue()) - Expect(str.Finished()).To(BeFalse()) - }) - - It("cancels the context after it is closed", func() { - Expect(str.Context().Done()).ToNot(BeClosed()) - str.Close() - Expect(str.Context().Done()).To(BeClosed()) - }) - - It("is not finished if it is only closed for reading", func() { - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - finishReading() - Expect(str.Finished()).To(BeFalse()) - }) - - It("is finished after receiving a RST and sending one", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - // this directly sends a rst - str.RegisterRemoteError(testErr, 0) - Expect(str.rstSent.Get()).To(BeTrue()) - Expect(str.Finished()).To(BeTrue()) - }) - - It("cancels the context after receiving a RST", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - Expect(str.Context().Done()).ToNot(BeClosed()) - str.RegisterRemoteError(testErr, 0) - Expect(str.Context().Done()).To(BeClosed()) - }) - - It("is finished after being locally reset and receiving a RST in response", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(13), true) - str.Reset(testErr) - Expect(str.Finished()).To(BeFalse()) - str.RegisterRemoteError(testErr, 13) - Expect(str.Finished()).To(BeTrue()) - }) - - It("is finished after finishing writing and receiving a RST", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(13), true) - str.Close() - _, sentFin := str.GetDataForWriting(1000) - Expect(sentFin).To(BeTrue()) - str.RegisterRemoteError(testErr, 13) - Expect(str.Finished()).To(BeTrue()) - }) - - It("is finished after finishing reading and being locally reset", func() { - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - finishReading() - Expect(str.Finished()).To(BeFalse()) - str.Reset(testErr) - Expect(str.Finished()).To(BeTrue()) + str.SetDeadline(time.Now().Add(-time.Second)) + b := make([]byte, 6) + n, err := strWithTimeout.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) }) }) - Context("flow control", func() { - It("says when it's flow control blocked", func() { - mockFC.EXPECT().IsBlocked().Return(false) - Expect(str.IsFlowControlBlocked()).To(BeFalse()) - mockFC.EXPECT().IsBlocked().Return(true) - Expect(str.IsFlowControlBlocked()).To(BeTrue()) + Context("completing", func() { + It("is not completed when only the receive side is completed", func() { + // don't EXPECT a call to mockSender.onStreamCompleted() + str.receiveStream.sender.onStreamCompleted(streamID) }) - It("updates the flow control window", func() { - mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x42)) - str.UpdateSendWindow(0x42) + It("is not completed when only the send side is completed", func() { + // don't EXPECT a call to mockSender.onStreamCompleted() + str.sendStream.sender.onStreamCompleted(streamID) }) - It("gets a window update", func() { - mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x100)) - Expect(str.GetWindowUpdate()).To(Equal(protocol.ByteCount(0x100))) + It("is completed when both sides are completed", func() { + mockSender.EXPECT().onStreamCompleted(streamID) + str.sendStream.sender.onStreamCompleted(streamID) + str.receiveStream.sender.onStreamCompleted(streamID) }) }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map.go b/vendor/github.com/lucas-clemente/quic-go/streams_map.go index e162205..bf9374b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map.go +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map.go @@ -5,8 +5,9 @@ import ( "fmt" "sync" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -16,11 +17,8 @@ type streamsMap struct { perspective protocol.Perspective streams map[protocol.StreamID]streamI - // needed for round-robin scheduling - openStreams []protocol.StreamID - roundRobinIndex int - nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() + nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID nextStreamOrErrCond sync.Cond openStreamOrErrCond sync.Cond @@ -29,47 +27,32 @@ type streamsMap struct { nextStreamToAccept protocol.StreamID newStream newStreamLambda - - numOutgoingStreams uint32 - numIncomingStreams uint32 - maxIncomingStreams uint32 - maxOutgoingStreams uint32 } -type streamLambda func(streamI) (bool, error) +var _ streamManager = &streamsMap{} + type newStreamLambda func(protocol.StreamID) streamI var errMapAccess = errors.New("streamsMap: Error accessing the streams map") -func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap { - // add some tolerance to the maximum incoming streams value - maxStreams := uint32(protocol.MaxIncomingStreams) - maxIncomingStreams := utils.MaxUint32( - maxStreams+protocol.MaxStreamsMinimumIncrement, - uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), - ) +func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) streamManager { sm := streamsMap{ - perspective: pers, - streams: make(map[protocol.StreamID]streamI), - openStreams: make([]protocol.StreamID, 0), - newStream: newStream, - maxIncomingStreams: maxIncomingStreams, + perspective: pers, + streams: make(map[protocol.StreamID]streamI), + newStream: newStream, } sm.nextStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex - nextOddStream := protocol.StreamID(1) - if ver.CryptoStreamID() == protocol.StreamID(1) { - nextOddStream = 3 - } - if pers == protocol.PerspectiveClient { - sm.nextStream = nextOddStream - sm.nextStreamToAccept = 2 + nextClientInitiatedStream := protocol.StreamID(1) + nextServerInitiatedStream := protocol.StreamID(2) + if pers == protocol.PerspectiveServer { + sm.nextStreamToOpen = nextServerInitiatedStream + sm.nextStreamToAccept = nextClientInitiatedStream } else { - sm.nextStream = 2 - sm.nextStreamToAccept = nextOddStream + sm.nextStreamToOpen = nextClientInitiatedStream + sm.nextStreamToAccept = nextServerInitiatedStream } - return &sm } @@ -81,6 +64,23 @@ func (m *streamsMap) streamInitiatedBy(id protocol.StreamID) protocol.Perspectiv return protocol.PerspectiveClient } +func (m *streamsMap) nextStreamID(id protocol.StreamID) protocol.StreamID { + if m.perspective == protocol.PerspectiveServer && id == 0 { + return 1 + } + return id + 2 +} + +func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + // every bidirectional stream is also a receive stream + return m.GetOrOpenStream(id) +} + +func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + // every bidirectional stream is also a send stream + return m.GetOrOpenStream(id) +} + // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { @@ -88,7 +88,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { s, ok := m.streams[id] m.mutex.RUnlock() if ok { - return s, nil // s may be nil + return s, nil } // ... we don't have an existing stream @@ -101,7 +101,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { } if m.perspective == m.streamInitiatedBy(id) { - if id <= m.nextStream { // this is a stream opened by us. Must have been closed already + if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already return nil, nil } return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) @@ -110,14 +110,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { return nil, nil } - // sid is the next stream that will be opened - sid := m.highestStreamOpenedByPeer + 2 - // if there is no stream opened yet, and this is the server, stream 1 should be openend - if sid == 2 && m.perspective == protocol.PerspectiveServer { - sid = 1 - } - - for ; sid <= id; sid += 2 { + for sid := m.nextStreamID(m.highestStreamOpenedByPeer); sid <= id; sid = m.nextStreamID(sid) { if _, err := m.openRemoteStream(sid); err != nil { return nil, err } @@ -128,38 +121,26 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { } func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) { - if m.numIncomingStreams >= m.maxIncomingStreams { - return nil, qerr.TooManyOpenStreams - } if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) } - - m.numIncomingStreams++ if id > m.highestStreamOpenedByPeer { m.highestStreamOpenedByPeer = id } - s := m.newStream(id) m.putStream(s) return s, nil } func (m *streamsMap) openStreamImpl() (streamI, error) { - id := m.nextStream - if m.numOutgoingStreams >= m.maxOutgoingStreams { - return nil, qerr.TooManyOpenStreams - } - - m.numOutgoingStreams++ - m.nextStream += 2 - s := m.newStream(id) + s := m.newStream(m.nextStreamToOpen) m.putStream(s) + m.nextStreamToOpen = m.nextStreamID(m.nextStreamToOpen) return s, nil } // OpenStream opens the next available stream -func (m *streamsMap) OpenStream() (streamI, error) { +func (m *streamsMap) OpenStream() (Stream, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -169,7 +150,7 @@ func (m *streamsMap) OpenStream() (streamI, error) { return m.openStreamImpl() } -func (m *streamsMap) OpenStreamSync() (streamI, error) { +func (m *streamsMap) OpenStreamSync() (Stream, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -190,7 +171,7 @@ func (m *streamsMap) OpenStreamSync() (streamI, error) { // AcceptStream returns the next stream opened by the peer // it blocks until a new stream is opened -func (m *streamsMap) AcceptStream() (streamI, error) { +func (m *streamsMap) AcceptStream() (Stream, error) { m.mutex.Lock() defer m.mutex.Unlock() var str streamI @@ -209,104 +190,24 @@ func (m *streamsMap) AcceptStream() (streamI, error) { return str, nil } -func (m *streamsMap) DeleteClosedStreams() error { +func (m *streamsMap) DeleteStream(id protocol.StreamID) error { m.mutex.Lock() defer m.mutex.Unlock() - - var numDeletedStreams int - // for every closed stream, the streamID is replaced by 0 in the openStreams slice - for i, streamID := range m.openStreams { - str, ok := m.streams[streamID] - if !ok { - return errMapAccess - } - if !str.Finished() { - continue - } - numDeletedStreams++ - m.openStreams[i] = 0 - if m.streamInitiatedBy(streamID) == m.perspective { - m.numOutgoingStreams-- - } else { - m.numIncomingStreams-- - } - delete(m.streams, streamID) + _, ok := m.streams[id] + if !ok { + return errMapAccess } - - if numDeletedStreams == 0 { - return nil - } - - // remove all 0s (representing closed streams) from the openStreams slice - // and adjust the roundRobinIndex - var j int - for i, id := range m.openStreams { - if i != j { - m.openStreams[j] = m.openStreams[i] - } - if id != 0 { - j++ - } else if j < m.roundRobinIndex { - m.roundRobinIndex-- - } - } - m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams] + delete(m.streams, id) m.openStreamOrErrCond.Signal() return nil } -// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false -// It uses a round-robin-like scheduling to ensure that every stream is considered fairly -// It prioritizes the the header-stream (StreamID 3) -func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - numStreams := len(m.streams) - startIndex := m.roundRobinIndex - - for i := 0; i < numStreams; i++ { - streamID := m.openStreams[(i+startIndex)%numStreams] - cont, err := m.iterateFunc(streamID, fn) - if err != nil { - return err - } - m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams - if !cont { - break - } - } - return nil -} - -// Range executes a callback for all streams, in pseudo-random order -func (m *streamsMap) Range(cb func(s streamI)) { - m.mutex.RLock() - defer m.mutex.RUnlock() - - for _, s := range m.streams { - if s != nil { - cb(s) - } - } -} - -func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) { - str, ok := m.streams[streamID] - if !ok { - return true, errMapAccess - } - return fn(str) -} - func (m *streamsMap) putStream(s streamI) error { id := s.StreamID() if _, ok := m.streams[id]; ok { return fmt.Errorf("a stream with ID %d already exists", id) } - m.streams[id] = s - m.openStreams = append(m.openStreams, id) return nil } @@ -316,14 +217,20 @@ func (m *streamsMap) CloseWithError(err error) { m.closeErr = err m.nextStreamOrErrCond.Broadcast() m.openStreamOrErrCond.Broadcast() - for _, s := range m.openStreams { - m.streams[s].Cancel(err) + for _, s := range m.streams { + s.closeForShutdown(err) } } -func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) { +// TODO(#952): this won't be needed when gQUIC supports stateless handshakes +func (m *streamsMap) UpdateLimits(params *handshake.TransportParameters) { m.mutex.Lock() - defer m.mutex.Unlock() - m.maxOutgoingStreams = limit + for id, str := range m.streams { + str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: id, + ByteOffset: params.StreamFlowControlWindow, + }) + } + m.mutex.Unlock() m.openStreamOrErrCond.Broadcast() } diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go new file mode 100644 index 0000000..de52a8f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go @@ -0,0 +1,257 @@ +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type streamsMapLegacy struct { + mutex sync.RWMutex + + perspective protocol.Perspective + + streams map[protocol.StreamID]streamI + + nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() + highestStreamOpenedByPeer protocol.StreamID + nextStreamOrErrCond sync.Cond + openStreamOrErrCond sync.Cond + + closeErr error + nextStreamToAccept protocol.StreamID + + newStream newStreamLambda + + numOutgoingStreams uint32 + numIncomingStreams uint32 + maxIncomingStreams uint32 + maxOutgoingStreams uint32 +} + +var _ streamManager = &streamsMapLegacy{} + +func newStreamsMapLegacy(newStream newStreamLambda, pers protocol.Perspective) streamManager { + // add some tolerance to the maximum incoming streams value + maxStreams := uint32(protocol.MaxIncomingStreams) + maxIncomingStreams := utils.MaxUint32( + maxStreams+protocol.MaxStreamsMinimumIncrement, + uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), + ) + sm := streamsMapLegacy{ + perspective: pers, + streams: make(map[protocol.StreamID]streamI), + newStream: newStream, + maxIncomingStreams: maxIncomingStreams, + } + sm.nextStreamOrErrCond.L = &sm.mutex + sm.openStreamOrErrCond.L = &sm.mutex + + nextServerInitiatedStream := protocol.StreamID(2) + nextClientInitiatedStream := protocol.StreamID(3) + if pers == protocol.PerspectiveServer { + sm.highestStreamOpenedByPeer = 1 + } + if pers == protocol.PerspectiveServer { + sm.nextStreamToOpen = nextServerInitiatedStream + sm.nextStreamToAccept = nextClientInitiatedStream + } else { + sm.nextStreamToOpen = nextClientInitiatedStream + sm.nextStreamToAccept = nextServerInitiatedStream + } + return &sm +} + +// getStreamPerspective says which side should initiate a stream +func (m *streamsMapLegacy) streamInitiatedBy(id protocol.StreamID) protocol.Perspective { + if id%2 == 0 { + return protocol.PerspectiveServer + } + return protocol.PerspectiveClient +} + +func (m *streamsMapLegacy) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + // every bidirectional stream is also a receive stream + return m.GetOrOpenStream(id) +} + +func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + // every bidirectional stream is also a send stream + return m.GetOrOpenStream(id) +} + +// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. +// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. +func (m *streamsMapLegacy) GetOrOpenStream(id protocol.StreamID) (streamI, error) { + m.mutex.RLock() + s, ok := m.streams[id] + m.mutex.RUnlock() + if ok { + return s, nil + } + + // ... we don't have an existing stream + m.mutex.Lock() + defer m.mutex.Unlock() + // We need to check whether another invocation has already created a stream (between RUnlock() and Lock()). + s, ok = m.streams[id] + if ok { + return s, nil + } + + if m.perspective == m.streamInitiatedBy(id) { + if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already + return nil, nil + } + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already + return nil, nil + } + + for sid := m.highestStreamOpenedByPeer + 2; sid <= id; sid += 2 { + if _, err := m.openRemoteStream(sid); err != nil { + return nil, err + } + } + + m.nextStreamOrErrCond.Broadcast() + return m.streams[id], nil +} + +func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, error) { + if m.numIncomingStreams >= m.maxIncomingStreams { + return nil, qerr.TooManyOpenStreams + } + if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) + } + + m.numIncomingStreams++ + if id > m.highestStreamOpenedByPeer { + m.highestStreamOpenedByPeer = id + } + + s := m.newStream(id) + m.putStream(s) + return s, nil +} + +func (m *streamsMapLegacy) openStreamImpl() (streamI, error) { + if m.numOutgoingStreams >= m.maxOutgoingStreams { + return nil, qerr.TooManyOpenStreams + } + + m.numOutgoingStreams++ + s := m.newStream(m.nextStreamToOpen) + m.putStream(s) + m.nextStreamToOpen += 2 + return s, nil +} + +// OpenStream opens the next available stream +func (m *streamsMapLegacy) OpenStream() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + return m.openStreamImpl() +} + +func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + if m.closeErr != nil { + return nil, m.closeErr + } + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.openStreamOrErrCond.Wait() + } +} + +// AcceptStream returns the next stream opened by the peer +// it blocks until a new stream is opened +func (m *streamsMapLegacy) AcceptStream() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + var str streamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStreamToAccept] + if ok { + break + } + m.nextStreamOrErrCond.Wait() + } + m.nextStreamToAccept += 2 + return str, nil +} + +func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + _, ok := m.streams[id] + if !ok { + return errMapAccess + } + delete(m.streams, id) + if m.streamInitiatedBy(id) == m.perspective { + m.numOutgoingStreams-- + } else { + m.numIncomingStreams-- + } + m.openStreamOrErrCond.Signal() + return nil +} + +func (m *streamsMapLegacy) putStream(s streamI) error { + id := s.StreamID() + if _, ok := m.streams[id]; ok { + return fmt.Errorf("a stream with ID %d already exists", id) + } + m.streams[id] = s + return nil +} + +func (m *streamsMapLegacy) CloseWithError(err error) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.closeErr = err + m.nextStreamOrErrCond.Broadcast() + m.openStreamOrErrCond.Broadcast() + for _, s := range m.streams { + s.closeForShutdown(err) + } +} + +// TODO(#952): this won't be needed when gQUIC supports stateless handshakes +func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) { + m.mutex.Lock() + m.maxOutgoingStreams = params.MaxStreams + for id, str := range m.streams { + str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: id, + ByteOffset: params.StreamFlowControlWindow, + }) + } + m.mutex.Unlock() + m.openStreamOrErrCond.Broadcast() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy_test.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy_test.go new file mode 100644 index 0000000..7e9324f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy_test.go @@ -0,0 +1,549 @@ +package quic + +import ( + "errors" + + "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Streams Map (for gQUIC)", func() { + var m *streamsMapLegacy + + newStream := func(id protocol.StreamID) streamI { + str := NewMockStreamI(mockCtrl) + str.EXPECT().StreamID().Return(id).AnyTimes() + return str + } + + setNewStreamsMap := func(p protocol.Perspective) { + m = newStreamsMapLegacy(newStream, p).(*streamsMapLegacy) + } + + deleteStream := func(id protocol.StreamID) { + ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed()) + } + + Context("getting and creating streams", func() { + Context("as a server", func() { + BeforeEach(func() { + setNewStreamsMap(protocol.PerspectiveServer) + }) + + Context("client-side streams", func() { + It("gets new streams", func() { + s, err := m.GetOrOpenStream(3) + Expect(err).NotTo(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(m.streams).To(HaveLen(1)) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + Expect(m.numOutgoingStreams).To(BeZero()) + }) + + It("rejects streams with even IDs", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 6")) + }) + + It("rejects streams with even IDs, which are lower thatn the highest client-side stream", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + _, err = m.GetOrOpenStream(4) + Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 4")) + }) + + It("gets existing streams", func() { + s, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + numStreams := m.numIncomingStreams + s, err = m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) + Expect(m.numIncomingStreams).To(Equal(numStreams)) + }) + + It("returns nil for closed streams", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + deleteStream(5) + s, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s).To(BeNil()) + }) + + It("opens skipped streams", func() { + _, err := m.GetOrOpenStream(7) + Expect(err).NotTo(HaveOccurred()) + Expect(m.streams).To(HaveKey(protocol.StreamID(3))) + Expect(m.streams).To(HaveKey(protocol.StreamID(5))) + Expect(m.streams).To(HaveKey(protocol.StreamID(7))) + }) + + It("doesn't reopen an already closed stream", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + deleteStream(5) + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + + Context("counting streams", func() { + It("errors when too many streams are opened", func() { + for i := uint32(0); i < m.maxIncomingStreams; i++ { + _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.GetOrOpenStream(protocol.StreamID(2*m.maxIncomingStreams + 3)) + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("errors when too many streams are opened implicitely", func() { + _, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 3)) + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("does not error when many streams are opened and closed", func() { + for i := uint32(2); i < 10*m.maxIncomingStreams; i++ { + str, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) + Expect(err).NotTo(HaveOccurred()) + deleteStream(str.StreamID()) + } + }) + }) + }) + + Context("server-side streams", func() { + It("doesn't allow opening streams before receiving the transport parameters", func() { + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("opens a stream 2 first", func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) + s, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) + Expect(m.numIncomingStreams).To(BeZero()) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + }) + + It("returns the error when the streamsMap was closed", func() { + testErr := errors.New("test error") + m.CloseWithError(testErr) + _, err := m.OpenStream() + Expect(err).To(MatchError(testErr)) + }) + + It("doesn't reopen an already closed stream", func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) + deleteStream(2) + Expect(err).ToNot(HaveOccurred()) + str, err = m.GetOrOpenStream(2) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + + Context("counting streams", func() { + const maxOutgoingStreams = 50 + + BeforeEach(func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: maxOutgoingStreams}) + }) + + It("errors when too many streams are opened", func() { + for i := 1; i <= maxOutgoingStreams; i++ { + _, err := m.OpenStream() + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("does not error when many streams are opened and closed", func() { + for i := 2; i < 10*maxOutgoingStreams; i++ { + str, err := m.OpenStream() + Expect(err).NotTo(HaveOccurred()) + deleteStream(str.StreamID()) + } + }) + + It("allows many server- and client-side streams at the same time", func() { + for i := 1; i < maxOutgoingStreams; i++ { + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + } + for i := 0; i < maxOutgoingStreams; i++ { + _, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1)) + Expect(err).ToNot(HaveOccurred()) + } + }) + }) + + Context("opening streams synchronously", func() { + const maxOutgoingStreams = 10 + + BeforeEach(func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: maxOutgoingStreams}) + }) + + openMaxNumStreams := func() { + for i := 1; i <= maxOutgoingStreams; i++ { + _, err := m.OpenStream() + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + } + + It("waits until another stream is closed", func() { + openMaxNumStreams() + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.OpenStreamSync() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + deleteStream(6) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxOutgoingStreams + 2))) + }) + + It("stops waiting when an error is registered", func() { + testErr := errors.New("test error") + openMaxNumStreams() + for _, str := range m.streams { + str.(*MockStreamI).EXPECT().closeForShutdown(testErr) + } + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync() + Expect(err).To(MatchError(testErr)) + close(done) + }() + + Consistently(done).ShouldNot(BeClosed()) + m.CloseWithError(testErr) + Eventually(done).Should(BeClosed()) + }) + + It("immediately returns when OpenStreamSync is called after an error was registered", func() { + testErr := errors.New("test error") + m.CloseWithError(testErr) + _, err := m.OpenStreamSync() + Expect(err).To(MatchError(testErr)) + }) + }) + }) + + Context("accepting streams", func() { + It("does nothing if no stream is opened", func() { + var accepted bool + go func() { + _, _ = m.AcceptStream() + accepted = true + }() + Consistently(func() bool { return accepted }).Should(BeFalse()) + }) + + It("starts with stream 3", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + _, err := m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + }) + + It("returns an implicitly opened stream, if a stream number is skipped", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + }) + + It("returns to multiple accepts", func() { + var str1, str2 Stream + done1 := make(chan struct{}) + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str1, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done1) + }() + go func() { + defer GinkgoRecover() + var err error + str2, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done2) + }() + _, err := m.GetOrOpenStream(5) // opens stream 3 and 5 + Expect(err).ToNot(HaveOccurred()) + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) + Expect(str1.StreamID()).ToNot(Equal(str2.StreamID())) + Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(3 + 5)) + }) + + It("waits until a new stream is available", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + _, err := m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + }) + + It("returns multiple streams on subsequent Accept calls, if available", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(5))) + }) + + It("blocks after accepting a stream", func() { + _, err := m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + str, err := m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, _ = m.AcceptStream() + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + // make the go routine return + str.(*MockStreamI).EXPECT().closeForShutdown(gomock.Any()) + m.CloseWithError(errors.New("shut down")) + Eventually(done).Should(BeClosed()) + }) + + It("stops waiting when an error is registered", func() { + testErr := errors.New("testErr") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.AcceptStream() + Expect(err).To(MatchError(testErr)) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + m.CloseWithError(testErr) + Eventually(done).Should(BeClosed()) + }) + It("immediately returns when Accept is called after an error was registered", func() { + testErr := errors.New("testErr") + m.CloseWithError(testErr) + _, err := m.AcceptStream() + Expect(err).To(MatchError(testErr)) + }) + }) + }) + + Context("as a client", func() { + BeforeEach(func() { + setNewStreamsMap(protocol.PerspectiveClient) + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) + }) + + Context("server-side streams", func() { + It("rejects streams with odd IDs", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5")) + }) + + It("rejects streams with odds IDs, which are lower than the highest server-side stream", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).NotTo(HaveOccurred()) + _, err = m.GetOrOpenStream(5) + Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5")) + }) + + It("gets new streams", func() { + s, err := m.GetOrOpenStream(2) + Expect(err).NotTo(HaveOccurred()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) + Expect(m.streams).To(HaveLen(1)) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + }) + + It("opens skipped streams", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).NotTo(HaveOccurred()) + Expect(m.streams).To(HaveKey(protocol.StreamID(2))) + Expect(m.streams).To(HaveKey(protocol.StreamID(4))) + Expect(m.streams).To(HaveKey(protocol.StreamID(6))) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) + }) + + It("doesn't reopen an already closed stream", func() { + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + deleteStream(3) + Expect(err).ToNot(HaveOccurred()) + str, err = m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + }) + + Context("client-side streams", func() { + It("starts with stream 3", func() { + s, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(BeEquivalentTo(3)) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + Expect(m.numIncomingStreams).To(BeZero()) + }) + + It("opens multiple streams", func() { + s1, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + s2, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(s2.StreamID()).To(Equal(s1.StreamID() + 2)) + }) + + It("doesn't reopen an already closed stream", func() { + _, err := m.GetOrOpenStream(4) + Expect(err).ToNot(HaveOccurred()) + deleteStream(4) + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenStream(4) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + }) + + Context("accepting streams", func() { + It("accepts stream 2 first", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + _, err := m.GetOrOpenStream(2) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) + }) + }) + }) + }) + + Context("deleting streams", func() { + BeforeEach(func() { + setNewStreamsMap(protocol.PerspectiveServer) + }) + + It("deletes an incoming stream", func() { + _, err := m.GetOrOpenStream(5) // open stream 3 and 5 + Expect(err).ToNot(HaveOccurred()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) + err = m.DeleteStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(m.streams).To(HaveLen(1)) + Expect(m.streams).To(HaveKey(protocol.StreamID(5))) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + }) + + It("deletes an outgoing stream", func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) + _, err := m.OpenStream() // open stream 2 + Expect(err).ToNot(HaveOccurred()) + _, err = m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) + err = m.DeleteStream(2) + Expect(err).ToNot(HaveOccurred()) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + }) + + It("errors when the stream doesn't exist", func() { + err := m.DeleteStream(1337) + Expect(err).To(MatchError(errMapAccess)) + }) + }) + + It("sets the flow control limit", func() { + setNewStreamsMap(protocol.PerspectiveServer) + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + m.streams[3].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 3, + ByteOffset: 321, + }) + m.streams[5].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 5, + ByteOffset: 321, + }) + m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321}) + }) +}) diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_test.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_test.go index 99936c9..9a1fa32 100644 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map_test.go +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_test.go @@ -2,64 +2,46 @@ package quic import ( "errors" - "sort" - - "github.com/lucas-clemente/quic-go/internal/mocks" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/qerr" "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("Streams Map", func() { - var ( - m *streamsMap - finishedStreams map[protocol.StreamID]*gomock.Call - ) +var _ = Describe("Streams Map (for IETF QUIC)", func() { + var m *streamsMap newStream := func(id protocol.StreamID) streamI { - str := mocks.NewMockStreamI(mockCtrl) + str := NewMockStreamI(mockCtrl) str.EXPECT().StreamID().Return(id).AnyTimes() - c := str.EXPECT().Finished().Return(false).AnyTimes() - finishedStreams[id] = c return str } - setNewStreamsMap := func(p protocol.Perspective, v protocol.VersionNumber) { - m = newStreamsMap(newStream, p, v) + setNewStreamsMap := func(p protocol.Perspective) { + m = newStreamsMap(newStream, p).(*streamsMap) } - BeforeEach(func() { - finishedStreams = make(map[protocol.StreamID]*gomock.Call) - }) - - AfterEach(func() { - Expect(m.openStreams).To(HaveLen(len(m.streams))) - }) - deleteStream := func(id protocol.StreamID) { - str := m.streams[id] - Expect(str).ToNot(BeNil()) - finishedStreams[id].Return(true) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed()) } Context("getting and creating streams", func() { Context("as a server", func() { BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames) + setNewStreamsMap(protocol.PerspectiveServer) }) Context("client-side streams", func() { It("gets new streams", func() { s, err := m.GetOrOpenStream(1) Expect(err).NotTo(HaveOccurred()) + Expect(s).ToNot(BeNil()) Expect(s.StreamID()).To(Equal(protocol.StreamID(1))) - Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) - Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.streams).To(HaveLen(1)) }) It("rejects streams with even IDs", func() { @@ -77,11 +59,11 @@ var _ = Describe("Streams Map", func() { It("gets existing streams", func() { s, err := m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) - numStreams := m.numIncomingStreams + numStreams := len(m.streams) s, err = m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) - Expect(m.numIncomingStreams).To(Equal(numStreams)) + Expect(m.streams).To(HaveLen(numStreams)) }) It("returns nil for closed streams", func() { @@ -94,11 +76,11 @@ var _ = Describe("Streams Map", func() { }) It("opens skipped streams", func() { - _, err := m.GetOrOpenStream(5) + _, err := m.GetOrOpenStream(7) Expect(err).NotTo(HaveOccurred()) - Expect(m.streams).To(HaveKey(protocol.StreamID(1))) Expect(m.streams).To(HaveKey(protocol.StreamID(3))) Expect(m.streams).To(HaveKey(protocol.StreamID(5))) + Expect(m.streams).To(HaveKey(protocol.StreamID(7))) }) It("doesn't reopen an already closed stream", func() { @@ -110,46 +92,14 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) }) - - Context("counting streams", func() { - It("errors when too many streams are opened", func() { - for i := uint32(0); i < m.maxIncomingStreams; i++ { - _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) - Expect(err).NotTo(HaveOccurred()) - } - _, err := m.GetOrOpenStream(protocol.StreamID(2*m.maxIncomingStreams + 3)) - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) - }) - - It("errors when too many streams are opened implicitely", func() { - _, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 1)) - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) - }) - - It("does not error when many streams are opened and closed", func() { - for i := uint32(2); i < 10*m.maxIncomingStreams; i++ { - str, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) - Expect(err).NotTo(HaveOccurred()) - deleteStream(str.StreamID()) - } - }) - }) }) Context("server-side streams", func() { - It("doesn't allow opening streams before receiving the transport parameters", func() { - _, err := m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) - }) - It("opens a stream 2 first", func() { - m.UpdateMaxStreamLimit(100) s, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) - Expect(m.numIncomingStreams).To(BeZero()) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) }) It("returns the error when the streamsMap was closed", func() { @@ -160,7 +110,6 @@ var _ = Describe("Streams Map", func() { }) It("doesn't reopen an already closed stream", func() { - m.UpdateMaxStreamLimit(100) str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) @@ -171,96 +120,7 @@ var _ = Describe("Streams Map", func() { Expect(str).To(BeNil()) }) - Context("counting streams", func() { - const maxOutgoingStreams = 50 - - BeforeEach(func() { - m.UpdateMaxStreamLimit(maxOutgoingStreams) - }) - - It("errors when too many streams are opened", func() { - for i := 1; i <= maxOutgoingStreams; i++ { - _, err := m.OpenStream() - Expect(err).NotTo(HaveOccurred()) - } - _, err := m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) - }) - - It("does not error when many streams are opened and closed", func() { - for i := 2; i < 10*maxOutgoingStreams; i++ { - str, err := m.OpenStream() - Expect(err).NotTo(HaveOccurred()) - deleteStream(str.StreamID()) - } - }) - - It("allows many server- and client-side streams at the same time", func() { - for i := 1; i < maxOutgoingStreams; i++ { - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - } - for i := 0; i < maxOutgoingStreams; i++ { - _, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1)) - Expect(err).ToNot(HaveOccurred()) - } - }) - }) - Context("opening streams synchronously", func() { - const maxOutgoingStreams = 10 - - BeforeEach(func() { - m.UpdateMaxStreamLimit(maxOutgoingStreams) - }) - - openMaxNumStreams := func() { - for i := 1; i <= maxOutgoingStreams; i++ { - _, err := m.OpenStream() - Expect(err).NotTo(HaveOccurred()) - } - _, err := m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) - } - - It("waits until another stream is closed", func() { - openMaxNumStreams() - var returned bool - var str streamI - go func() { - defer GinkgoRecover() - var err error - str, err = m.OpenStreamSync() - Expect(err).ToNot(HaveOccurred()) - returned = true - }() - - Consistently(func() bool { return returned }).Should(BeFalse()) - deleteStream(6) - Eventually(func() bool { return returned }).Should(BeTrue()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxOutgoingStreams + 2))) - }) - - It("stops waiting when an error is registered", func() { - testErr := errors.New("test error") - openMaxNumStreams() - for _, str := range m.streams { - str.(*mocks.MockStreamI).EXPECT().Cancel(testErr) - } - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := m.OpenStreamSync() - Expect(err).To(MatchError(testErr)) - close(done) - }() - - Consistently(done).ShouldNot(BeClosed()) - m.CloseWithError(testErr) - Eventually(done).Should(BeClosed()) - }) - It("immediately returns when OpenStreamSync is called after an error was registered", func() { testErr := errors.New("test error") m.CloseWithError(testErr) @@ -280,127 +140,131 @@ var _ = Describe("Streams Map", func() { Consistently(func() bool { return accepted }).Should(BeFalse()) }) - It("starts with stream 1, if the crypto stream is stream 0", func() { - setNewStreamsMap(protocol.PerspectiveServer, versionIETFFrames) - var str streamI + It("starts with stream 1", func() { + var str Stream + done := make(chan struct{}) go func() { defer GinkgoRecover() var err error str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) + close(done) }() _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) - Eventually(func() Stream { return str }).ShouldNot(BeNil()) + Eventually(done).Should(BeClosed()) Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) }) - It("starts with stream 3, if the crypto stream is stream 1", func() { - var str streamI + It("returns an implicitly opened stream, if a stream number is skipped", func() { + var str Stream + done := make(chan struct{}) go func() { defer GinkgoRecover() var err error str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) + close(done) }() _, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) - Eventually(func() Stream { return str }).ShouldNot(BeNil()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) - }) - - It("returns an implicitly opened stream, if a stream number is skipped", func() { - var str streamI - go func() { - defer GinkgoRecover() - var err error - str, err = m.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - }() - _, err := m.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - Eventually(func() Stream { return str }).ShouldNot(BeNil()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) }) It("returns to multiple accepts", func() { - var str1, str2 streamI + var str1, str2 Stream + done1 := make(chan struct{}) + done2 := make(chan struct{}) go func() { defer GinkgoRecover() var err error str1, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) + close(done1) }() go func() { defer GinkgoRecover() var err error str2, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) + close(done2) }() - _, err := m.GetOrOpenStream(5) // opens stream 3 and 5 + _, err := m.GetOrOpenStream(3) // opens stream 1 and 3 Expect(err).ToNot(HaveOccurred()) - Eventually(func() streamI { return str1 }).ShouldNot(BeNil()) - Eventually(func() streamI { return str2 }).ShouldNot(BeNil()) + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) Expect(str1.StreamID()).ToNot(Equal(str2.StreamID())) - Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(3 + 5)) + Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(1 + 3)) }) - It("waits a new stream is available", func() { - var str streamI + It("waits until a new stream is available", func() { + var str Stream + done := make(chan struct{}) go func() { defer GinkgoRecover() var err error str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) + close(done) }() - Consistently(func() streamI { return str }).Should(BeNil()) - _, err := m.GetOrOpenStream(3) + Consistently(done).ShouldNot(BeClosed()) + _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) - Eventually(func() streamI { return str }).ShouldNot(BeNil()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) }) It("returns multiple streams on subsequent Accept calls, if available", func() { - var str streamI + var str Stream + done := make(chan struct{}) go func() { defer GinkgoRecover() var err error str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) + close(done) }() - _, err := m.GetOrOpenStream(5) + _, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) - Eventually(func() streamI { return str }).ShouldNot(BeNil()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(5))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) }) It("blocks after accepting a stream", func() { - var accepted bool - _, err := m.GetOrOpenStream(3) + _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) str, err := m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + done := make(chan struct{}) go func() { defer GinkgoRecover() _, _ = m.AcceptStream() - accepted = true + close(done) }() - Consistently(func() bool { return accepted }).Should(BeFalse()) + Consistently(done).ShouldNot(BeClosed()) + // make the go routine return + str.(*MockStreamI).EXPECT().closeForShutdown(gomock.Any()) + m.CloseWithError(errors.New("shut down")) + Eventually(done).Should(BeClosed()) }) It("stops waiting when an error is registered", func() { testErr := errors.New("testErr") - var acceptErr error + done := make(chan struct{}) go func() { - _, acceptErr = m.AcceptStream() + defer GinkgoRecover() + _, err := m.AcceptStream() + Expect(err).To(MatchError(testErr)) + close(done) }() - Consistently(func() error { return acceptErr }).ShouldNot(HaveOccurred()) + Consistently(done).ShouldNot(BeClosed()) m.CloseWithError(testErr) - Eventually(func() error { return acceptErr }).Should(MatchError(testErr)) + Eventually(done).Should(BeClosed()) }) It("immediately returns when Accept is called after an error was registered", func() { @@ -414,8 +278,7 @@ var _ = Describe("Streams Map", func() { Context("as a client", func() { BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames) - m.UpdateMaxStreamLimit(100) + setNewStreamsMap(protocol.PerspectiveClient) }) Context("server-side streams", func() { @@ -424,7 +287,7 @@ var _ = Describe("Streams Map", func() { Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5")) }) - It("rejects streams with odds IDs, which are lower thatn the highest server-side stream", func() { + It("rejects streams with odds IDs, which are lower than the highest server-side stream", func() { _, err := m.GetOrOpenStream(6) Expect(err).NotTo(HaveOccurred()) _, err = m.GetOrOpenStream(5) @@ -435,8 +298,7 @@ var _ = Describe("Streams Map", func() { s, err := m.GetOrOpenStream(2) Expect(err).NotTo(HaveOccurred()) Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) - Expect(m.numOutgoingStreams).To(BeZero()) - Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + Expect(m.streams).To(HaveLen(1)) }) It("opens skipped streams", func() { @@ -445,41 +307,27 @@ var _ = Describe("Streams Map", func() { Expect(m.streams).To(HaveKey(protocol.StreamID(2))) Expect(m.streams).To(HaveKey(protocol.StreamID(4))) Expect(m.streams).To(HaveKey(protocol.StreamID(6))) - Expect(m.numOutgoingStreams).To(BeZero()) - Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) }) It("doesn't reopen an already closed stream", func() { str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) - deleteStream(3) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + deleteStream(1) Expect(err).ToNot(HaveOccurred()) - str, err = m.GetOrOpenStream(3) + str, err = m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) }) }) Context("client-side streams", func() { - It("starts with stream 1, if the crypto stream is stream 0", func() { - setNewStreamsMap(protocol.PerspectiveClient, versionIETFFrames) - m.UpdateMaxStreamLimit(100) + It("starts with stream 1", func() { + setNewStreamsMap(protocol.PerspectiveClient) s, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) Expect(s.StreamID()).To(BeEquivalentTo(1)) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) - Expect(m.numIncomingStreams).To(BeZero()) - }) - - It("starts with stream 3, if the crypto stream is stream 1", func() { - s, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - Expect(s.StreamID()).To(BeEquivalentTo(3)) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) - Expect(m.numIncomingStreams).To(BeZero()) }) It("opens multiple streams", func() { @@ -503,297 +351,65 @@ var _ = Describe("Streams Map", func() { Context("accepting streams", func() { It("accepts stream 2 first", func() { - var str streamI + var str Stream + done := make(chan struct{}) go func() { defer GinkgoRecover() var err error str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) + close(done) }() _, err := m.GetOrOpenStream(2) Expect(err).ToNot(HaveOccurred()) - Eventually(func() streamI { return str }).ShouldNot(BeNil()) + Eventually(done).Should(BeClosed()) Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) }) }) }) }) - Context("DoS mitigation, iterating and deleting", func() { + Context("deleting streams", func() { BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames) + setNewStreamsMap(protocol.PerspectiveServer) }) - closeStream := func(id protocol.StreamID) { - str := m.streams[id] - ExpectWithOffset(1, str).ToNot(BeNil()) - finishedStreams[id].Return(true) - } - - Context("deleting streams", func() { - Context("as a server", func() { - BeforeEach(func() { - m.UpdateMaxStreamLimit(100) - for i := 1; i <= 5; i++ { - if i%2 == 1 { - _, err := m.openRemoteStream(protocol.StreamID(i)) - Expect(err).ToNot(HaveOccurred()) - } else { - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - } - } - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) // 2 and 4 - Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) // 1, 3 and 5 - }) - - It("does not delete streams with Close()", func() { - str, err := m.GetOrOpenStream(55) - Expect(err).ToNot(HaveOccurred()) - str.(*mocks.MockStreamI).EXPECT().Close() - str.Close() - err = m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - str, err = m.GetOrOpenStream(55) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - }) - - It("removes the first stream", func() { - closeStream(1) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.openStreams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) - }) - - It("removes a stream in the middle", func() { - closeStream(3) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) - }) - - It("removes a client-initiated stream", func() { - closeStream(2) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 3, 4, 5})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) - }) - - It("removes a stream at the end", func() { - closeStream(5) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.openStreams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) - }) - - It("removes all streams", func() { - for i := 1; i <= 5; i++ { - closeStream(protocol.StreamID(i)) - } - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(BeEmpty()) - Expect(m.openStreams).To(BeEmpty()) - Expect(m.numOutgoingStreams).To(BeZero()) - Expect(m.numIncomingStreams).To(BeZero()) - }) - }) - - Context("as a client", func() { - BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames) - m.UpdateMaxStreamLimit(100) - for i := 1; i <= 5; i++ { - if i%2 == 0 { - _, err := m.openRemoteStream(protocol.StreamID(i)) - Expect(err).ToNot(HaveOccurred()) - } else { - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - } - } - Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 2, 5, 4, 7})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) // 3, 5 and 7 - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) // 2 and 4 - }) - - It("removes a stream that we initiated", func() { - closeStream(3) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 5, 4, 7})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) - }) - - It("removes a stream that the server initiated", func() { - closeStream(2) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.openStreams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 5, 4, 7})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) - }) - - It("removes all streams", func() { - closeStream(3) - closeStream(2) - closeStream(5) - closeStream(4) - closeStream(7) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(BeEmpty()) - Expect(m.openStreams).To(BeEmpty()) - Expect(m.numOutgoingStreams).To(BeZero()) - Expect(m.numIncomingStreams).To(BeZero()) - }) - }) + It("deletes an incoming stream", func() { + _, err := m.GetOrOpenStream(3) // open stream 1 and 3 + Expect(err).ToNot(HaveOccurred()) + err = m.DeleteStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(m.streams).To(HaveLen(1)) + Expect(m.streams).To(HaveKey(protocol.StreamID(3))) }) - Context("Ranging", func() { - // create 5 streams, ids 4 to 8 - var callbackCalledForStream []protocol.StreamID - callback := func(str streamI) { - callbackCalledForStream = append(callbackCalledForStream, str.StreamID()) - sort.Slice(callbackCalledForStream, func(i, j int) bool { return callbackCalledForStream[i] < callbackCalledForStream[j] }) - } - - BeforeEach(func() { - callbackCalledForStream = callbackCalledForStream[:0] - for i := 4; i <= 8; i++ { - err := m.putStream(&stream{streamID: protocol.StreamID(i)}) - Expect(err).NotTo(HaveOccurred()) - } - }) - - It("ranges over all open streams", func() { - m.Range(callback) - Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) - }) + It("deletes an outgoing stream", func() { + _, err := m.OpenStream() // open stream 2 + Expect(err).ToNot(HaveOccurred()) + _, err = m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + err = m.DeleteStream(2) + Expect(err).ToNot(HaveOccurred()) }) - Context("RoundRobinIterate", func() { - // create 5 streams, ids 4 to 8 - var lambdaCalledForStream []protocol.StreamID - var numIterations int - - BeforeEach(func() { - lambdaCalledForStream = lambdaCalledForStream[:0] - numIterations = 0 - for i := 4; i <= 8; i++ { - err := m.putStream(newStream(protocol.StreamID(i))) - Expect(err).NotTo(HaveOccurred()) - } - }) - - It("executes the lambda exactly once for every stream", func() { - fn := func(str streamI) (bool, error) { - lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) - numIterations++ - return true, nil - } - err := m.RoundRobinIterate(fn) - Expect(err).ToNot(HaveOccurred()) - Expect(numIterations).To(Equal(5)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) - Expect(m.roundRobinIndex).To(BeZero()) - }) - - It("goes around once when starting in the middle", func() { - fn := func(str streamI) (bool, error) { - lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) - numIterations++ - return true, nil - } - m.roundRobinIndex = 3 // pointing to stream 7 - err := m.RoundRobinIterate(fn) - Expect(err).ToNot(HaveOccurred()) - Expect(numIterations).To(Equal(5)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6})) - Expect(m.roundRobinIndex).To(BeEquivalentTo(3)) - }) - - It("picks up at the index+1 where it last stopped", func() { - fn := func(str streamI) (bool, error) { - lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) - numIterations++ - if str.StreamID() == 5 { - return false, nil - } - return true, nil - } - err := m.RoundRobinIterate(fn) - Expect(err).ToNot(HaveOccurred()) - Expect(numIterations).To(Equal(2)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5})) - Expect(m.roundRobinIndex).To(BeEquivalentTo(2)) - numIterations = 0 - lambdaCalledForStream = lambdaCalledForStream[:0] - fn2 := func(str streamI) (bool, error) { - lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) - numIterations++ - if str.StreamID() == 7 { - return false, nil - } - return true, nil - } - err = m.RoundRobinIterate(fn2) - Expect(err).ToNot(HaveOccurred()) - Expect(numIterations).To(Equal(2)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{6, 7})) - }) - - Context("adjusting the RoundRobinIndex when deleting streams", func() { - /* - Index: 0 1 2 3 4 - StreamID: [ 4, 5, 6, 7, 8 ] - */ - - It("adjusts when deleting an element in front", func() { - m.roundRobinIndex = 3 // stream 7 - deleteStream(5) - Expect(m.roundRobinIndex).To(BeEquivalentTo(2)) - }) - - It("doesn't adjust when deleting an element at the back", func() { - m.roundRobinIndex = 1 // stream 5 - deleteStream(7) - Expect(m.roundRobinIndex).To(BeEquivalentTo(1)) - }) - - It("doesn't adjust when deleting the element it is pointing to", func() { - m.roundRobinIndex = 3 // stream 7 - deleteStream(7) - Expect(m.roundRobinIndex).To(BeEquivalentTo(3)) - }) - - It("adjusts when deleting multiple elements", func() { - m.roundRobinIndex = 3 // stream 7 - closeStream(5) - closeStream(6) - closeStream(8) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.roundRobinIndex).To(BeEquivalentTo(1)) - }) - }) + It("errors when the stream doesn't exist", func() { + err := m.DeleteStream(1337) + Expect(err).To(MatchError(errMapAccess)) }) }) + + It("sets the flow control limit", func() { + setNewStreamsMap(protocol.PerspectiveServer) + _, err := m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + m.streams[1].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 1, + ByteOffset: 321, + }) + m.streams[3].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 3, + ByteOffset: 321, + }) + m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321}) + }) }) diff --git a/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go new file mode 100644 index 0000000..ed006aa --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go @@ -0,0 +1,57 @@ +package quic + +import ( + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type windowUpdateQueue struct { + mutex sync.Mutex + + queue map[protocol.StreamID]bool // used as a set + callback func(wire.Frame) + cryptoStream cryptoStreamI + streamGetter streamGetter +} + +func newWindowUpdateQueue(streamGetter streamGetter, cryptoStream cryptoStreamI, cb func(wire.Frame)) *windowUpdateQueue { + return &windowUpdateQueue{ + queue: make(map[protocol.StreamID]bool), + streamGetter: streamGetter, + cryptoStream: cryptoStream, + callback: cb, + } +} + +func (q *windowUpdateQueue) Add(id protocol.StreamID) { + q.mutex.Lock() + q.queue[id] = true + q.mutex.Unlock() +} + +func (q *windowUpdateQueue) QueueAll() { + q.mutex.Lock() + var offset protocol.ByteCount + for id := range q.queue { + if id == q.cryptoStream.StreamID() { + offset = q.cryptoStream.getWindowUpdate() + } else { + str, err := q.streamGetter.GetOrOpenReceiveStream(id) + if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update + continue + } + offset = str.getWindowUpdate() + } + if offset == 0 { // can happen if we received a final offset, right after queueing the window update + continue + } + q.callback(&wire.MaxStreamDataFrame{ + StreamID: id, + ByteOffset: offset, + }) + delete(q.queue, id) + } + q.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/window_update_queue_test.go b/vendor/github.com/lucas-clemente/quic-go/window_update_queue_test.go new file mode 100644 index 0000000..cf0511f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/window_update_queue_test.go @@ -0,0 +1,90 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Window Update Queue", func() { + var ( + q *windowUpdateQueue + streamGetter *MockStreamGetter + queuedFrames []wire.Frame + cryptoStream *MockCryptoStream + ) + + BeforeEach(func() { + streamGetter = NewMockStreamGetter(mockCtrl) + cryptoStream = NewMockCryptoStream(mockCtrl) + cryptoStream.EXPECT().StreamID().Return(protocol.StreamID(0)).AnyTimes() + queuedFrames = queuedFrames[:0] + q = newWindowUpdateQueue(streamGetter, cryptoStream, func(f wire.Frame) { + queuedFrames = append(queuedFrames, f) + }) + }) + + It("adds stream offsets and gets MAX_STREAM_DATA frames", func() { + stream1 := NewMockStreamI(mockCtrl) + stream1.EXPECT().getWindowUpdate().Return(protocol.ByteCount(10)) + stream3 := NewMockStreamI(mockCtrl) + stream3.EXPECT().getWindowUpdate().Return(protocol.ByteCount(30)) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(stream3, nil) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(1)).Return(stream1, nil) + q.Add(3) + q.Add(1) + q.QueueAll() + Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, ByteOffset: 10})) + Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 3, ByteOffset: 30})) + }) + + It("deletes the entry after getting the MAX_STREAM_DATA frame", func() { + stream10 := NewMockStreamI(mockCtrl) + stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(100)) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil) + q.Add(10) + q.QueueAll() + Expect(queuedFrames).To(HaveLen(1)) + q.QueueAll() + Expect(queuedFrames).To(HaveLen(1)) + }) + + It("doesn't queue a MAX_STREAM_DATA for a closed stream", func() { + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(12)).Return(nil, nil) + q.Add(12) + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + }) + + It("doesn't queue a MAX_STREAM_DATA if the flow controller returns an offset of 0", func() { + stream5 := NewMockStreamI(mockCtrl) + stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0)) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(stream5, nil) + q.Add(5) + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + }) + + It("adds MAX_STREAM_DATA frames for the crypto stream", func() { + cryptoStream.EXPECT().getWindowUpdate().Return(protocol.ByteCount(42)) + q.Add(0) + q.QueueAll() + Expect(queuedFrames).To(Equal([]wire.Frame{ + &wire.MaxStreamDataFrame{StreamID: 0, ByteOffset: 42}, + })) + }) + + It("deduplicates", func() { + stream10 := NewMockStreamI(mockCtrl) + stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(200)) + streamGetter.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(10)).Return(stream10, nil) + q.Add(10) + q.Add(10) + q.QueueAll() + Expect(queuedFrames).To(Equal([]wire.Frame{ + &wire.MaxStreamDataFrame{StreamID: 10, ByteOffset: 200}, + })) + }) +})