update quic
This commit is contained in:
parent
3e711e63dc
commit
d0695adfb6
|
@ -421,6 +421,7 @@
|
|||
revision = "393af48d391698c6ae4219566bfbdfef67269997"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/lucas-clemente/quic-go"
|
||||
packages = [
|
||||
".",
|
||||
|
@ -435,8 +436,7 @@
|
|||
"internal/wire",
|
||||
"qerr"
|
||||
]
|
||||
revision = "ded0eb4f6f30a8049bd334a26ff7ff728648fe13"
|
||||
version = "v0.6.0"
|
||||
revision = "15bcc2579f7dab14c84f438741f2b535cf474df9"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
|
@ -753,6 +753,6 @@
|
|||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
inputs-digest = "a11e1692755a705514dbd401ba4795821d1ac221d6f9100124c38a29db98c568"
|
||||
inputs-digest = "97c8282ef9b3abed71907d17ccf38379134714596610880b02d5ca03be634678"
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
|
|
137
Gopkg.toml
137
Gopkg.toml
|
@ -0,0 +1,137 @@
|
|||
# Gopkg.toml example
|
||||
#
|
||||
# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md
|
||||
# for detailed Gopkg.toml documentation.
|
||||
#
|
||||
# required = ["github.com/user/thing/cmd/thing"]
|
||||
# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
|
||||
#
|
||||
# [[constraint]]
|
||||
# name = "github.com/user/project"
|
||||
# version = "1.0.0"
|
||||
#
|
||||
# [[constraint]]
|
||||
# name = "github.com/user/project2"
|
||||
# branch = "dev"
|
||||
# source = "github.com/myfork/project2"
|
||||
#
|
||||
# [[override]]
|
||||
# name = "github.com/x/y"
|
||||
# version = "2.4.0"
|
||||
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/Xe/gopreload"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/Xe/ln"
|
||||
version = "0.1.0"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/Xe/uuid"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/Xe/x"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/asdine/storm"
|
||||
version = "2.0.2"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/brandur/simplebox"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/caarlos0/env"
|
||||
version = "3.2.0"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/dgryski/go-failure"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/dickeyxxx/netrc"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/facebookgo/flagenv"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/golang/protobuf"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/google/gops"
|
||||
version = "0.3.2"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/hashicorp/terraform"
|
||||
version = "0.11.2"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/joho/godotenv"
|
||||
version = "1.2.0"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/jtolds/qod"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/kr/pretty"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/lucas-clemente/quic-go"
|
||||
branch = "master"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/magefile/mage"
|
||||
version = "2.0.1"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/mtneug/pkg"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/olekukonko/tablewriter"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/pkg/errors"
|
||||
version = "0.8.0"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/streamrail/concurrent-map"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/xtaci/kcp-go"
|
||||
version = "3.23.0"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/xtaci/smux"
|
||||
version = "1.0.6"
|
||||
|
||||
[[constraint]]
|
||||
name = "go.uber.org/atomic"
|
||||
version = "1.3.1"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/crypto"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/net"
|
||||
|
||||
[[constraint]]
|
||||
name = "google.golang.org/grpc"
|
||||
version = "1.9.2"
|
||||
|
||||
[[constraint]]
|
||||
name = "gopkg.in/alecthomas/kingpin.v2"
|
||||
version = "2.2.6"
|
|
@ -1,4 +1,5 @@
|
|||
dist: trusty
|
||||
group: travis_latest
|
||||
|
||||
addons:
|
||||
hosts:
|
||||
|
@ -8,6 +9,7 @@ language: go
|
|||
|
||||
go:
|
||||
- 1.9.2
|
||||
- 1.10beta1
|
||||
|
||||
# first part of the GOARCH workaround
|
||||
# setting the GOARCH directly doesn't work, since the value will be overwritten later
|
||||
|
@ -30,6 +32,7 @@ before_install:
|
|||
- export GOARCH=$TRAVIS_GOARCH
|
||||
- go env # for debugging
|
||||
- google-chrome --version
|
||||
- "printf \"quic.clemente.io certificate valid until: \" && openssl x509 -in example/fullchain.pem -enddate -noout | cut -d = -f 2"
|
||||
- "export DISPLAY=:99.0"
|
||||
- "Xvfb $DISPLAY &> /dev/null &"
|
||||
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
# Changelog
|
||||
|
||||
## v0.6.1 (unreleased)
|
||||
## v0.7 (unreleased)
|
||||
|
||||
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
|
||||
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
|
||||
- Expose the `ConnectionState` in the `Session` (experimental API).
|
||||
|
||||
## v0.6.0 (2017-12-12)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
|
@ -11,3 +12,13 @@ func TestCrypto(t *testing.T) {
|
|||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "AckHandler Suite")
|
||||
}
|
||||
|
||||
var mockCtrl *gomock.Controller
|
||||
|
||||
var _ = BeforeEach(func() {
|
||||
mockCtrl = gomock.NewController(GinkgoT())
|
||||
})
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
mockCtrl.Finish()
|
||||
})
|
||||
|
|
|
@ -16,6 +16,7 @@ type SentPacketHandler interface {
|
|||
|
||||
SendingAllowed() bool
|
||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
ShouldSendRetransmittablePacket() bool
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
GetLeastUnacked() protocol.PacketNumber
|
||||
|
@ -26,7 +27,7 @@ type SentPacketHandler interface {
|
|||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
|
||||
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
|
||||
IgnoreBelow(protocol.PacketNumber)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
|
|
|
@ -15,7 +15,8 @@ type Packet struct {
|
|||
Length protocol.ByteCount
|
||||
EncryptionLevel protocol.EncryptionLevel
|
||||
|
||||
SendTime time.Time
|
||||
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
|
||||
sendTime time.Time
|
||||
}
|
||||
|
||||
// GetFramesForRetransmission gets all the frames for retransmission
|
||||
|
|
10
vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go
generated
vendored
10
vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go
generated
vendored
|
@ -34,10 +34,10 @@ func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHand
|
|||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
|
||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
|
||||
if packetNumber > h.largestObserved {
|
||||
h.largestObserved = packetNumber
|
||||
h.largestObservedReceivedTime = time.Now()
|
||||
h.largestObservedReceivedTime = rcvTime
|
||||
}
|
||||
|
||||
if packetNumber < h.ignoreBelow {
|
||||
|
@ -47,7 +47,7 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
|
|||
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
h.maybeQueueAck(packetNumber, shouldInstigateAck)
|
||||
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -58,7 +58,7 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
|
|||
h.packetHistory.DeleteBelow(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
|
||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) {
|
||||
h.packetsReceivedSinceLastAck++
|
||||
|
||||
if shouldInstigateAck {
|
||||
|
@ -86,7 +86,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||
h.ackQueued = true
|
||||
} else {
|
||||
if h.ackAlarm.IsZero() {
|
||||
h.ackAlarm = time.Now().Add(h.ackSendDelay)
|
||||
h.ackAlarm = rcvTime.Add(h.ackSendDelay)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
84
vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler_test.go
generated
vendored
84
vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler_test.go
generated
vendored
|
@ -21,34 +21,36 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
|
||||
Context("accepting packets", func() {
|
||||
It("handles a packet that arrives late", func() {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(1), true)
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(1), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(protocol.PacketNumber(3), true)
|
||||
err = handler.ReceivedPacket(protocol.PacketNumber(3), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(protocol.PacketNumber(2), true)
|
||||
err = handler.ReceivedPacket(protocol.PacketNumber(2), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("saves the time when each packet arrived", func() {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(3), true)
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(3), time.Now(), true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
|
||||
})
|
||||
|
||||
It("updates the largestObserved and the largestObservedReceivedTime", func() {
|
||||
now := time.Now()
|
||||
handler.largestObserved = 3
|
||||
handler.largestObservedReceivedTime = time.Now().Add(-1 * time.Second)
|
||||
err := handler.ReceivedPacket(5, true)
|
||||
handler.largestObservedReceivedTime = now.Add(-1 * time.Second)
|
||||
err := handler.ReceivedPacket(5, now, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
|
||||
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
|
||||
Expect(handler.largestObservedReceivedTime).To(Equal(now))
|
||||
})
|
||||
|
||||
It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() {
|
||||
timestamp := time.Now().Add(-1 * time.Second)
|
||||
now := time.Now()
|
||||
timestamp := now.Add(-1 * time.Second)
|
||||
handler.largestObserved = 5
|
||||
handler.largestObservedReceivedTime = timestamp
|
||||
err := handler.ReceivedPacket(4, true)
|
||||
err := handler.ReceivedPacket(4, now, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
|
||||
Expect(handler.largestObservedReceivedTime).To(Equal(timestamp))
|
||||
|
@ -57,7 +59,7 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
It("passes on errors from receivedPacketHistory", func() {
|
||||
var err error
|
||||
for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ {
|
||||
err = handler.ReceivedPacket(2*i+1, true)
|
||||
err = handler.ReceivedPacket(2*i+1, time.Time{}, true)
|
||||
// this will eventually return an error
|
||||
// details about when exactly the receivedPacketHistory errors are tested there
|
||||
if err != nil {
|
||||
|
@ -72,7 +74,7 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
Context("queueing ACKs", func() {
|
||||
receiveAndAck10Packets := func() {
|
||||
for i := 1; i <= 10; i++ {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
||||
|
@ -80,14 +82,14 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
}
|
||||
|
||||
It("always queues an ACK for the first packet", func() {
|
||||
err := handler.ReceivedPacket(1, false)
|
||||
err := handler.ReceivedPacket(1, time.Time{}, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
||||
})
|
||||
|
||||
It("works with packet number 0", func() {
|
||||
err := handler.ReceivedPacket(0, false)
|
||||
err := handler.ReceivedPacket(0, time.Time{}, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
||||
|
@ -95,11 +97,11 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
|
||||
It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() {
|
||||
receiveAndAck10Packets()
|
||||
err := handler.ReceivedPacket(11, true)
|
||||
err := handler.ReceivedPacket(11, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
Expect(handler.GetAlarmTimeout()).NotTo(BeZero())
|
||||
err = handler.ReceivedPacket(12, true)
|
||||
err = handler.ReceivedPacket(12, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
||||
|
@ -107,11 +109,11 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
|
||||
It("only sets the timer when receiving a retransmittable packets", func() {
|
||||
receiveAndAck10Packets()
|
||||
err := handler.ReceivedPacket(11, false)
|
||||
err := handler.ReceivedPacket(11, time.Time{}, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
Expect(handler.ackAlarm).To(BeZero())
|
||||
err = handler.ReceivedPacket(12, true)
|
||||
err = handler.ReceivedPacket(12, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
Expect(handler.ackAlarm).ToNot(BeZero())
|
||||
|
@ -120,15 +122,15 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
|
||||
It("queues an ACK if it was reported missing before", func() {
|
||||
receiveAndAck10Packets()
|
||||
err := handler.ReceivedPacket(11, true)
|
||||
err := handler.ReceivedPacket(11, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(13, true)
|
||||
err = handler.ReceivedPacket(13, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame() // ACK: 1 and 3, missing: 2
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.HasMissingRanges()).To(BeTrue())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
err = handler.ReceivedPacket(12, false)
|
||||
err = handler.ReceivedPacket(12, time.Time{}, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
})
|
||||
|
@ -136,10 +138,10 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
It("queues an ACK if it creates a new missing range", func() {
|
||||
receiveAndAck10Packets()
|
||||
for i := 11; i < 16; i++ {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := handler.ReceivedPacket(20, true) // we now know that packets 16 to 19 are missing
|
||||
err := handler.ReceivedPacket(20, time.Time{}, true) // we now know that packets 16 to 19 are missing
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
ack := handler.GetAckFrame()
|
||||
|
@ -154,9 +156,9 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("generates a simple ACK frame", func() {
|
||||
err := handler.ReceivedPacket(1, true)
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(2, true)
|
||||
err = handler.ReceivedPacket(2, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
|
@ -166,7 +168,7 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("generates an ACK for packet number 0", func() {
|
||||
err := handler.ReceivedPacket(0, true)
|
||||
err := handler.ReceivedPacket(0, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
|
@ -176,12 +178,12 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("saves the last sent ACK", func() {
|
||||
err := handler.ReceivedPacket(1, true)
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(handler.lastAck).To(Equal(ack))
|
||||
err = handler.ReceivedPacket(2, true)
|
||||
err = handler.ReceivedPacket(2, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackQueued = true
|
||||
ack = handler.GetAckFrame()
|
||||
|
@ -190,9 +192,9 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("generates an ACK frame with missing packets", func() {
|
||||
err := handler.ReceivedPacket(1, true)
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(4, true)
|
||||
err = handler.ReceivedPacket(4, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
|
@ -205,11 +207,11 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("generates an ACK for packet number 0 and other packets", func() {
|
||||
err := handler.ReceivedPacket(0, true)
|
||||
err := handler.ReceivedPacket(0, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(1, true)
|
||||
err = handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(3, true)
|
||||
err = handler.ReceivedPacket(3, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
|
@ -223,15 +225,15 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
|
||||
It("accepts packets below the lower limit", func() {
|
||||
handler.IgnoreBelow(6)
|
||||
err := handler.ReceivedPacket(2, true)
|
||||
err := handler.ReceivedPacket(2, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("doesn't add delayed packets to the packetHistory", func() {
|
||||
handler.IgnoreBelow(7)
|
||||
err := handler.ReceivedPacket(4, true)
|
||||
err := handler.ReceivedPacket(4, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(10, true)
|
||||
err = handler.ReceivedPacket(10, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
|
@ -241,7 +243,7 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
|
||||
It("deletes packets from the packetHistory when a lower limit is set", func() {
|
||||
for i := 1; i <= 12; i++ {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
handler.IgnoreBelow(7)
|
||||
|
@ -256,7 +258,7 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
// TODO: remove this test when dropping support for STOP_WAITINGs
|
||||
It("handles a lower limit of 0", func() {
|
||||
handler.IgnoreBelow(0)
|
||||
err := handler.ReceivedPacket(1337, true)
|
||||
err := handler.ReceivedPacket(1337, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
|
@ -264,7 +266,7 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("resets all counters needed for the ACK queueing decision when sending an ACK", func() {
|
||||
err := handler.ReceivedPacket(1, true)
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackAlarm = time.Now().Add(-time.Minute)
|
||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
||||
|
@ -275,7 +277,7 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("doesn't generate an ACK when none is queued and the timer is not set", func() {
|
||||
err := handler.ReceivedPacket(1, true)
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackQueued = false
|
||||
handler.ackAlarm = time.Time{}
|
||||
|
@ -283,7 +285,7 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() {
|
||||
err := handler.ReceivedPacket(1, true)
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackQueued = false
|
||||
handler.ackAlarm = time.Now().Add(time.Minute)
|
||||
|
@ -291,7 +293,7 @@ var _ = Describe("receivedPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("generates an ACK when the timer has expired", func() {
|
||||
err := handler.ReceivedPacket(1, true)
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackQueued = false
|
||||
handler.ackAlarm = time.Now().Add(-time.Minute)
|
||||
|
|
|
@ -40,6 +40,10 @@ type sentPacketHandler struct {
|
|||
|
||||
largestAcked protocol.PacketNumber
|
||||
largestReceivedPacketWithAck protocol.PacketNumber
|
||||
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
|
||||
// example: we send an ACK for packets 90-100 with packet number 20
|
||||
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
|
||||
lowestPacketNotConfirmedAcked protocol.PacketNumber
|
||||
|
||||
packetHistory *PacketList
|
||||
stopWaitingManager stopWaitingManager
|
||||
|
@ -95,6 +99,13 @@ func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool {
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||
var queue []*Packet
|
||||
for _, packet := range h.retransmissionQueue {
|
||||
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
|
||||
queue = append(queue, packet)
|
||||
}
|
||||
}
|
||||
h.retransmissionQueue = queue
|
||||
h.handshakeComplete = true
|
||||
}
|
||||
|
||||
|
@ -114,11 +125,19 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
|||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
now := time.Now()
|
||||
|
||||
var largestAcked protocol.PacketNumber
|
||||
if len(packet.Frames) > 0 {
|
||||
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
|
||||
largestAcked = ackFrame.LargestAcked
|
||||
}
|
||||
}
|
||||
|
||||
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
|
||||
isRetransmittable := len(packet.Frames) != 0
|
||||
|
||||
if isRetransmittable {
|
||||
packet.SendTime = now
|
||||
packet.sendTime = now
|
||||
packet.largestAcked = largestAcked
|
||||
h.bytesInFlight += packet.Length
|
||||
h.packetHistory.PushBack(*packet)
|
||||
h.numNonRetransmittablePackets = 0
|
||||
|
@ -134,7 +153,7 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
|||
isRetransmittable,
|
||||
)
|
||||
|
||||
h.updateLossDetectionAlarm()
|
||||
h.updateLossDetectionAlarm(now)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -146,14 +165,12 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
// duplicate or out-of-order ACK
|
||||
// if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 {
|
||||
if withPacketNumber <= h.largestReceivedPacketWithAck {
|
||||
utils.Debugf("ignoring ack because duplicate")
|
||||
return ErrDuplicateOrOutOfOrderAck
|
||||
}
|
||||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
|
||||
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
|
||||
if ackFrame.LargestAcked < h.lowestUnacked() {
|
||||
utils.Debugf("ignoring ack because repeated")
|
||||
return nil
|
||||
}
|
||||
h.largestAcked = ackFrame.LargestAcked
|
||||
|
@ -178,13 +195,19 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
if encLevel < p.Value.EncryptionLevel {
|
||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
|
||||
}
|
||||
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||
if p.Value.largestAcked != 0 {
|
||||
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1)
|
||||
}
|
||||
h.onPacketAcked(p)
|
||||
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
||||
}
|
||||
}
|
||||
|
||||
h.detectLostPackets()
|
||||
h.updateLossDetectionAlarm()
|
||||
h.detectLostPackets(rcvTime)
|
||||
h.updateLossDetectionAlarm(rcvTime)
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||
|
@ -192,6 +215,10 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
|
||||
return h.lowestPacketNotConfirmedAcked
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
|
||||
var ackedPackets []*PacketElement
|
||||
ackRangeIndex := 0
|
||||
|
@ -233,7 +260,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
|
|||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
packet := el.Value
|
||||
if packet.PacketNumber == largestAcked {
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime)
|
||||
return true
|
||||
}
|
||||
// Packets are sorted by number, so we can stop searching
|
||||
|
@ -244,7 +271,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
|
|||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
|
||||
// Cancel the alarm if no packets are outstanding
|
||||
if h.packetHistory.Len() == 0 {
|
||||
h.alarm = time.Time{}
|
||||
|
@ -253,19 +280,18 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
|||
|
||||
// TODO(#497): TLP
|
||||
if !h.handshakeComplete {
|
||||
h.alarm = time.Now().Add(h.computeHandshakeTimeout())
|
||||
h.alarm = now.Add(h.computeHandshakeTimeout())
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit timer or time loss detection.
|
||||
h.alarm = h.lossTime
|
||||
} else {
|
||||
// RTO
|
||||
h.alarm = time.Now().Add(h.computeRTOTimeout())
|
||||
h.alarm = now.Add(h.computeRTOTimeout())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) detectLostPackets() {
|
||||
func (h *sentPacketHandler) detectLostPackets(now time.Time) {
|
||||
h.lossTime = time.Time{}
|
||||
now := time.Now()
|
||||
|
||||
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
||||
|
@ -278,7 +304,7 @@ func (h *sentPacketHandler) detectLostPackets() {
|
|||
break
|
||||
}
|
||||
|
||||
timeSinceSent := now.Sub(packet.SendTime)
|
||||
timeSinceSent := now.Sub(packet.sendTime)
|
||||
if timeSinceSent > delayUntilLost {
|
||||
lostPackets = append(lostPackets, el)
|
||||
} else if h.lossTime.IsZero() {
|
||||
|
@ -296,20 +322,22 @@ func (h *sentPacketHandler) detectLostPackets() {
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) OnAlarm() {
|
||||
now := time.Now()
|
||||
|
||||
// TODO(#497): TLP
|
||||
if !h.handshakeComplete {
|
||||
h.queueHandshakePacketsForRetransmission()
|
||||
h.handshakeCount++
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit or time loss detection
|
||||
h.detectLostPackets()
|
||||
h.detectLostPackets(now)
|
||||
} else {
|
||||
// RTO
|
||||
h.retransmitOldestTwoPackets()
|
||||
h.rtoCount++
|
||||
}
|
||||
|
||||
h.updateLossDetectionAlarm()
|
||||
h.updateLossDetectionAlarm(now)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
||||
|
@ -345,12 +373,11 @@ func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFra
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) SendingAllowed() bool {
|
||||
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
|
||||
cwnd := h.congestion.GetCongestionWindow()
|
||||
congestionLimited := h.bytesInFlight > cwnd
|
||||
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
|
||||
if congestionLimited {
|
||||
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
|
||||
h.bytesInFlight,
|
||||
h.congestion.GetCongestionWindow())
|
||||
utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||
}
|
||||
// Workaround for #555:
|
||||
// Always allow sending of retransmissions. This should probably be limited
|
||||
|
|
213
vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler_test.go
generated
vendored
213
vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler_test.go
generated
vendored
|
@ -3,60 +3,15 @@ package ackhandler
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockCongestion struct {
|
||||
argsOnPacketSent []interface{}
|
||||
maybeExitSlowStart bool
|
||||
onRetransmissionTimeout bool
|
||||
getCongestionWindow bool
|
||||
packetsAcked [][]interface{}
|
||||
packetsLost [][]interface{}
|
||||
}
|
||||
|
||||
func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
||||
m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount {
|
||||
m.getCongestionWindow = true
|
||||
return protocol.DefaultTCPMSS
|
||||
}
|
||||
|
||||
func (m *mockCongestion) MaybeExitSlowStart() {
|
||||
m.maybeExitSlowStart = true
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
m.onRetransmissionTimeout = true
|
||||
}
|
||||
|
||||
func (m *mockCongestion) RetransmissionDelay() time.Duration {
|
||||
return defaultRTOTimeout
|
||||
}
|
||||
|
||||
func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not implemented") }
|
||||
func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") }
|
||||
func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") }
|
||||
|
||||
func (m *mockCongestion) OnPacketAcked(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
|
||||
m.packetsAcked = append(m.packetsAcked, []interface{}{n, l, bif})
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnPacketLost(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
|
||||
m.packetsLost = append(m.packetsLost, []interface{}{n, l, bif})
|
||||
}
|
||||
|
||||
func retransmittablePacket(num protocol.PacketNumber) *Packet {
|
||||
return &Packet{
|
||||
PacketNumber: num,
|
||||
|
@ -143,7 +98,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
packet := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1}
|
||||
err := handler.SentPacket(&packet)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.packetHistory.Front().Value.SendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1))
|
||||
Expect(handler.packetHistory.Front().Value.sendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1))
|
||||
})
|
||||
|
||||
It("does not store non-retransmittable packets", func() {
|
||||
|
@ -553,9 +508,9 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
It("computes the RTT", func() {
|
||||
now := time.Now()
|
||||
// First, fake the sent times of the first, second and last packet
|
||||
getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute)
|
||||
getPacketElement(2).Value.SendTime = now.Add(-5 * time.Minute)
|
||||
getPacketElement(6).Value.SendTime = now.Add(-1 * time.Minute)
|
||||
getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute)
|
||||
getPacketElement(2).Value.sendTime = now.Add(-5 * time.Minute)
|
||||
getPacketElement(6).Value.sendTime = now.Add(-1 * time.Minute)
|
||||
// Now, check that the proper times are used when calculating the deltas
|
||||
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -570,12 +525,50 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
|
||||
It("uses the DelayTime in the ack frame", func() {
|
||||
now := time.Now()
|
||||
getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute)
|
||||
getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute)
|
||||
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, DelayTime: 5 * time.Minute}, 1, protocol.EncryptionUnencrypted, time.Now())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
|
||||
})
|
||||
})
|
||||
|
||||
Context("determinining, which ACKs we have received an ACK for", func() {
|
||||
BeforeEach(func() {
|
||||
morePackets := []*Packet{
|
||||
&Packet{PacketNumber: 13, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 80, LargestAcked: 100}, &streamFrame}, Length: 1},
|
||||
&Packet{PacketNumber: 14, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 50, LargestAcked: 200}, &streamFrame}, Length: 1},
|
||||
&Packet{PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1},
|
||||
}
|
||||
for _, packet := range morePackets {
|
||||
err := handler.SentPacket(packet)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
It("determines which ACK we have received an ACK for", func() {
|
||||
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 12}, 1, protocol.EncryptionUnencrypted, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
|
||||
})
|
||||
|
||||
It("doesn't do anything when the acked packet didn't contain an ACK", func() {
|
||||
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 1, protocol.EncryptionUnencrypted, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
|
||||
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 15}, 2, protocol.EncryptionUnencrypted, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
|
||||
})
|
||||
|
||||
It("doesn't decrease the value", func() {
|
||||
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 14, LowestAcked: 14}, 1, protocol.EncryptionUnencrypted, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
|
||||
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 2, protocol.EncryptionUnencrypted, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("Retransmission handling", func() {
|
||||
|
@ -583,13 +576,13 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
packets = []*Packet{
|
||||
{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1},
|
||||
{PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1},
|
||||
{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1},
|
||||
{PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1},
|
||||
{PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1},
|
||||
{PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1},
|
||||
{PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1},
|
||||
{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted},
|
||||
{PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted},
|
||||
{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted},
|
||||
{PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure},
|
||||
{PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure},
|
||||
{PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure},
|
||||
{PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure},
|
||||
}
|
||||
for _, packet := range packets {
|
||||
handler.SentPacket(packet)
|
||||
|
@ -597,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
// Increase RTT, because the tests would be flaky otherwise
|
||||
handler.rttStats.UpdateRTT(time.Minute, 0, time.Now())
|
||||
// Ack a single packet so that we have non-RTO timings
|
||||
handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionUnencrypted, time.Now())
|
||||
handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6)))
|
||||
})
|
||||
|
||||
|
@ -606,7 +599,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("dequeues a packet for retransmission", func() {
|
||||
getPacketElement(1).Value.SendTime = time.Now().Add(-time.Hour)
|
||||
getPacketElement(1).Value.sendTime = time.Now().Add(-time.Hour)
|
||||
handler.OnAlarm()
|
||||
Expect(getPacketElement(1)).To(BeNil())
|
||||
Expect(handler.retransmissionQueue).To(HaveLen(1))
|
||||
|
@ -617,15 +610,33 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
|
||||
})
|
||||
|
||||
Context("StopWaitings", func() {
|
||||
It("gets a StopWaitingFrame", func() {
|
||||
It("deletes non forward-secure packets when the handshake completes", func() {
|
||||
for i := protocol.PacketNumber(1); i <= 7; i++ {
|
||||
if i == 2 { // packet 2 was already acked in BeforeEach
|
||||
continue
|
||||
}
|
||||
handler.queuePacketForRetransmission(getPacketElement(i))
|
||||
}
|
||||
Expect(handler.retransmissionQueue).To(HaveLen(6))
|
||||
handler.SetHandshakeComplete()
|
||||
packet := handler.DequeuePacketForRetransmission()
|
||||
Expect(packet).ToNot(BeNil())
|
||||
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(6)))
|
||||
packet = handler.DequeuePacketForRetransmission()
|
||||
Expect(packet).ToNot(BeNil())
|
||||
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(7)))
|
||||
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
|
||||
})
|
||||
|
||||
Context("STOP_WAITINGs", func() {
|
||||
It("gets a STOP_WAITING frame", func() {
|
||||
ack := wire.AckFrame{LargestAcked: 5, LowestAcked: 5}
|
||||
err := handler.ReceivedAck(&ack, 2, protocol.EncryptionUnencrypted, time.Now())
|
||||
err := handler.ReceivedAck(&ack, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6}))
|
||||
})
|
||||
|
||||
It("gets a StopWaitingFrame after queueing a retransmission", func() {
|
||||
It("gets a STOP_WAITING frame after queueing a retransmission", func() {
|
||||
handler.queuePacketForRetransmission(getPacketElement(5))
|
||||
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6}))
|
||||
})
|
||||
|
@ -662,7 +673,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
|
||||
|
||||
handler.packetHistory.Front().Value.SendTime = time.Now().Add(-time.Hour)
|
||||
handler.packetHistory.Front().Value.sendTime = time.Now().Add(-time.Hour)
|
||||
handler.OnAlarm()
|
||||
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(0)))
|
||||
|
@ -670,15 +681,23 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
|
||||
Context("congestion", func() {
|
||||
var (
|
||||
cong *mockCongestion
|
||||
cong *mocks.MockSendAlgorithm
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
cong = &mockCongestion{}
|
||||
cong = mocks.NewMockSendAlgorithm(mockCtrl)
|
||||
cong.EXPECT().RetransmissionDelay().AnyTimes()
|
||||
handler.congestion = cong
|
||||
})
|
||||
|
||||
It("should call OnSent", func() {
|
||||
cong.EXPECT().OnPacketSent(
|
||||
gomock.Any(),
|
||||
protocol.ByteCount(42),
|
||||
protocol.PacketNumber(1),
|
||||
protocol.ByteCount(42),
|
||||
true,
|
||||
)
|
||||
p := &Packet{
|
||||
PacketNumber: 1,
|
||||
Length: 42,
|
||||
|
@ -686,62 +705,60 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
}
|
||||
err := handler.SentPacket(p)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(42)))
|
||||
Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(42)))
|
||||
Expect(cong.argsOnPacketSent[4]).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should call MaybeExitSlowStart and OnPacketAcked", func() {
|
||||
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
|
||||
cong.EXPECT().MaybeExitSlowStart()
|
||||
cong.EXPECT().OnPacketAcked(
|
||||
protocol.PacketNumber(1),
|
||||
protocol.ByteCount(1),
|
||||
protocol.ByteCount(1),
|
||||
)
|
||||
handler.SentPacket(retransmittablePacket(1))
|
||||
handler.SentPacket(retransmittablePacket(2))
|
||||
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cong.maybeExitSlowStart).To(BeTrue())
|
||||
Expect(cong.packetsAcked).To(BeEquivalentTo([][]interface{}{
|
||||
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(1)},
|
||||
}))
|
||||
Expect(cong.packetsLost).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should call MaybeExitSlowStart and OnPacketLost", func() {
|
||||
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
|
||||
cong.EXPECT().OnRetransmissionTimeout(true).Times(2)
|
||||
cong.EXPECT().OnPacketLost(
|
||||
protocol.PacketNumber(1),
|
||||
protocol.ByteCount(1),
|
||||
protocol.ByteCount(2),
|
||||
)
|
||||
cong.EXPECT().OnPacketLost(
|
||||
protocol.PacketNumber(2),
|
||||
protocol.ByteCount(1),
|
||||
protocol.ByteCount(1),
|
||||
)
|
||||
handler.SentPacket(retransmittablePacket(1))
|
||||
handler.SentPacket(retransmittablePacket(2))
|
||||
handler.SentPacket(retransmittablePacket(3))
|
||||
handler.OnAlarm() // RTO, meaning 2 lost packets
|
||||
Expect(cong.maybeExitSlowStart).To(BeFalse())
|
||||
Expect(cong.onRetransmissionTimeout).To(BeTrue())
|
||||
Expect(cong.packetsAcked).To(BeEmpty())
|
||||
Expect(cong.packetsLost).To(BeEquivalentTo([][]interface{}{
|
||||
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)},
|
||||
{protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(1)},
|
||||
}))
|
||||
})
|
||||
|
||||
It("allows or denies sending based on congestion", func() {
|
||||
Expect(handler.retransmissionQueue).To(BeEmpty())
|
||||
handler.bytesInFlight = 100
|
||||
cong.EXPECT().GetCongestionWindow().Return(protocol.MaxByteCount)
|
||||
Expect(handler.SendingAllowed()).To(BeTrue())
|
||||
err := handler.SentPacket(&Packet{
|
||||
PacketNumber: 1,
|
||||
Frames: []wire.Frame{&wire.PingFrame{}},
|
||||
Length: protocol.DefaultTCPMSS + 1,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cong.EXPECT().GetCongestionWindow().Return(protocol.ByteCount(0))
|
||||
Expect(handler.SendingAllowed()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("allows or denies sending based on the number of tracked packets", func() {
|
||||
cong.EXPECT().GetCongestionWindow().Return(protocol.MaxByteCount).AnyTimes()
|
||||
Expect(handler.SendingAllowed()).To(BeTrue())
|
||||
handler.retransmissionQueue = make([]*Packet, protocol.MaxTrackedSentPackets)
|
||||
Expect(handler.SendingAllowed()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("allows sending if there are retransmisisons outstanding", func() {
|
||||
err := handler.SentPacket(&Packet{
|
||||
PacketNumber: 1,
|
||||
Frames: []wire.Frame{&wire.PingFrame{}},
|
||||
Length: protocol.DefaultTCPMSS + 1,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
handler.bytesInFlight = 100
|
||||
cong.EXPECT().GetCongestionWindow().Return(protocol.ByteCount(0)).AnyTimes()
|
||||
Expect(handler.SendingAllowed()).To(BeFalse())
|
||||
handler.retransmissionQueue = []*Packet{nil}
|
||||
Expect(handler.SendingAllowed()).To(BeTrue())
|
||||
|
@ -799,7 +816,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Expect(handler.lossTime.Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute))
|
||||
Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute))
|
||||
|
||||
handler.packetHistory.Front().Value.SendTime = time.Now().Add(-2 * time.Hour)
|
||||
handler.packetHistory.Front().Value.sendTime = time.Now().Add(-2 * time.Hour)
|
||||
handler.OnAlarm()
|
||||
Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil())
|
||||
})
|
||||
|
@ -843,7 +860,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
err = handler.SentPacket(handshakePacket(4))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now().Add(time.Hour))
|
||||
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.lossTime.IsZero()).To(BeTrue())
|
||||
handshakeTimeout := handler.computeHandshakeTimeout()
|
||||
|
|
|
@ -60,33 +60,15 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
|||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddrNonFWSecure(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func DialNonFWSecure(
|
||||
func Dial(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
) (Session, error) {
|
||||
connID, err := generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -115,31 +97,11 @@ func DialNonFWSecure(
|
|||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
go c.listen()
|
||||
|
||||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session.(NonFWSession), nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := sess.WaitUntilHandshakeComplete(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
|
@ -199,6 +161,7 @@ func (c *client) dialGQUIC() error {
|
|||
if err := c.createNewGQUICSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
return c.establishSecureConnection()
|
||||
}
|
||||
|
||||
|
@ -224,6 +187,7 @@ func (c *client) dialTLS() error {
|
|||
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
if err != handshake.ErrCloseSessionForRetry {
|
||||
return err
|
||||
|
@ -267,14 +231,8 @@ func (c *client) establishSecureConnection() error {
|
|||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case ev := <-c.session.handshakeStatus():
|
||||
if ev.err != nil {
|
||||
return ev.err
|
||||
}
|
||||
if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure {
|
||||
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
|
||||
}
|
||||
return nil
|
||||
case err := <-c.session.handshakeStatus():
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -100,57 +101,7 @@ var _ = Describe("Client", func() {
|
|||
generateConnectionID = origGenerateConnectionID
|
||||
})
|
||||
|
||||
It("dials non-forward-secure", func() {
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
s, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
close(dialed)
|
||||
}()
|
||||
Consistently(dialed).ShouldNot(BeClosed())
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("dials a non-forward-secure address", func() {
|
||||
serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server, err := net.ListenUDP("udp", serverAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
for {
|
||||
_, clientAddr, err := server.ReadFromUDP(make([]byte, 200))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}()
|
||||
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
close(dialed)
|
||||
}()
|
||||
Consistently(dialed).ShouldNot(BeClosed())
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
server.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("Dial only returns after the handshake is complete", func() {
|
||||
It("returns after the handshake is complete", func() {
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -160,13 +111,14 @@ var _ = Describe("Client", func() {
|
|||
Expect(s).ToNot(BeNil())
|
||||
close(dialed)
|
||||
}()
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
Consistently(dialed).ShouldNot(BeClosed())
|
||||
close(sess.handshakeComplete)
|
||||
close(sess.handshakeChan)
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("resolves the address", func() {
|
||||
if os.Getenv("APPVEYOR") == "True" {
|
||||
Skip("This test is flaky on AppVeyor.")
|
||||
}
|
||||
closeErr := errors.New("peer doesn't reply")
|
||||
remoteAddrChan := make(chan string)
|
||||
newClientSession = func(
|
||||
|
@ -245,22 +197,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sess.handshakeChan <- handshakeEvent{err: testErr}
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
||||
testErr := errors.New("late handshake error")
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
sess.handshakeComplete <- testErr
|
||||
sess.handshakeChan <- testErr
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -305,7 +242,7 @@ var _ = Describe("Client", func() {
|
|||
) (packetHandler, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
|
@ -331,7 +268,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
Expect(config.Versions).To(ContainElement(newVersion))
|
||||
sessionChan := make(chan *mockSession)
|
||||
handshakeChan := make(chan handshakeEvent)
|
||||
handshakeChan := make(chan error)
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ string,
|
||||
|
@ -382,7 +319,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(negotiatedVersions).To(ContainElement(newVersion))
|
||||
Expect(initialVersion).To(Equal(actualInitialVersion))
|
||||
|
||||
handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
close(handshakeChan)
|
||||
Eventually(established).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStreamI interface {
|
||||
StreamID() protocol.StreamID
|
||||
io.Reader
|
||||
io.Writer
|
||||
handleStreamFrame(*wire.StreamFrame) error
|
||||
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
closeForShutdown(error)
|
||||
setReadOffset(protocol.ByteCount)
|
||||
// methods needed for flow control
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||
}
|
||||
|
||||
type cryptoStream struct {
|
||||
*stream
|
||||
}
|
||||
|
||||
var _ cryptoStreamI = &cryptoStream{}
|
||||
|
||||
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
|
||||
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||
return &cryptoStream{str}
|
||||
}
|
||||
|
||||
// SetReadOffset sets the read offset.
|
||||
// It is only needed for the crypto stream.
|
||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
|
||||
s.receiveStream.readOffset = offset
|
||||
s.receiveStream.frameQueue.readPosition = offset
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Crypto Stream", func() {
|
||||
var (
|
||||
str *cryptoStream
|
||||
mockSender *MockStreamSender
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
mockSender = NewMockStreamSender(mockCtrl)
|
||||
str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStream)
|
||||
})
|
||||
|
||||
It("sets the read offset", func() {
|
||||
str.setReadOffset(0x42)
|
||||
Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42)))
|
||||
Expect(str.receiveStream.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42)))
|
||||
})
|
||||
})
|
|
@ -97,45 +97,44 @@ func (c *client) handleHeaderStream() {
|
|||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
||||
|
||||
var lastStream protocol.StreamID
|
||||
var err error
|
||||
for err == nil {
|
||||
err = c.readResponse(h2framer, decoder)
|
||||
}
|
||||
utils.Debugf("Error handling header stream: %s", err)
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
||||
// stop all running request
|
||||
close(c.headerErrored)
|
||||
}
|
||||
|
||||
for {
|
||||
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
||||
frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
||||
break
|
||||
return err
|
||||
}
|
||||
lastStream = protocol.StreamID(frame.Header().StreamID)
|
||||
hframe, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok {
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
|
||||
break
|
||||
return errors.New("not a headers frame")
|
||||
}
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
|
||||
break
|
||||
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
c.mutex.RUnlock()
|
||||
if !ok {
|
||||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
|
||||
break
|
||||
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
||||
}
|
||||
|
||||
rsp, err := responseFromHeaders(mhframe)
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
|
||||
return err
|
||||
}
|
||||
responseChan <- rsp
|
||||
}
|
||||
|
||||
// stop all running request
|
||||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
|
||||
close(c.headerErrored)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Roundtrip executes a request and returns a response
|
||||
|
|
|
@ -188,41 +188,31 @@ var _ = Describe("Client", func() {
|
|||
close(done)
|
||||
})
|
||||
|
||||
It("closes the quic client when encountering an error on the header stream", func(done Done) {
|
||||
It("closes the quic client when encountering an error on the header stream", func() {
|
||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||
var doReturned bool
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(client.headerErr))
|
||||
Expect(rsp).To(BeNil())
|
||||
doReturned = true
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
||||
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
|
||||
close(done)
|
||||
}, 2)
|
||||
})
|
||||
|
||||
It("returns subsequent request if there was an error on the header stream before", func(done Done) {
|
||||
expectedErr := qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
||||
It("returns subsequent request if there was an error on the header stream before", func() {
|
||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
|
||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||
var firstReqReturned bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(expectedErr))
|
||||
firstReqReturned = true
|
||||
}()
|
||||
|
||||
Eventually(func() bool { return firstReqReturned }).Should(BeTrue())
|
||||
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
||||
// now that the first request failed due to an error on the header stream, try another request
|
||||
_, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(expectedErr))
|
||||
close(done)
|
||||
_, nextErr := client.RoundTrip(request)
|
||||
Expect(nextErr).To(MatchError(err))
|
||||
})
|
||||
|
||||
It("blocks if no stream is available", func() {
|
||||
|
@ -479,16 +469,9 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("errors if the H2 frame is not a HeadersFrame", func() {
|
||||
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
|
||||
|
||||
var handlerReturned bool
|
||||
go func() {
|
||||
client.handleHeaderStream()
|
||||
handlerReturned = true
|
||||
}()
|
||||
|
||||
Eventually(client.headerErrored).Should(BeClosed())
|
||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
|
||||
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("errors if it can't read the HPACK encoded header fields", func() {
|
||||
|
@ -497,16 +480,26 @@ var _ = Describe("Client", func() {
|
|||
EndHeaders: true,
|
||||
BlockFragment: []byte("invalid HPACK data"),
|
||||
})
|
||||
|
||||
var handlerReturned bool
|
||||
go func() {
|
||||
client.handleHeaderStream()
|
||||
handlerReturned = true
|
||||
}()
|
||||
|
||||
Eventually(client.headerErrored).Should(BeClosed())
|
||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")))
|
||||
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
||||
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
|
||||
})
|
||||
|
||||
It("errors if the stream cannot be found", func() {
|
||||
var headers bytes.Buffer
|
||||
enc := hpack.NewEncoder(&headers)
|
||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
|
||||
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: 1337,
|
||||
EndHeaders: true,
|
||||
BlockFragment: headers.Bytes(),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client.handleHeaderStream()
|
||||
Eventually(client.headerErrored).Should(BeClosed())
|
||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
||||
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -29,6 +30,8 @@ type mockStream struct {
|
|||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
var _ quic.Stream = &mockStream{}
|
||||
|
||||
func newMockStream(id protocol.StreamID) *mockStream {
|
||||
s := &mockStream{
|
||||
id: id,
|
||||
|
@ -39,7 +42,8 @@ func newMockStream(id protocol.StreamID) *mockStream {
|
|||
}
|
||||
|
||||
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
|
||||
func (s *mockStream) Reset(error) { s.reset = true }
|
||||
func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil }
|
||||
func (s *mockStream) CancelWrite(quic.ErrorCode) error { panic("not implemented") }
|
||||
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
|
||||
func (s mockStream) StreamID() protocol.StreamID { return s.id }
|
||||
func (s *mockStream) Context() context.Context { return s.ctx }
|
||||
|
|
|
@ -50,6 +50,7 @@ type Server struct {
|
|||
|
||||
listenerMutex sync.Mutex
|
||||
listener quic.Listener
|
||||
closed bool
|
||||
|
||||
supportedVersionsAsString string
|
||||
}
|
||||
|
@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
s.listenerMutex.Lock()
|
||||
if s.closed {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("Server is already closed")
|
||||
}
|
||||
if s.listener != nil {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("ListenAndServe may only be called once")
|
||||
|
@ -223,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
}
|
||||
if responseWriter.dataStream != nil {
|
||||
if !streamEnded && !reqBody.requestRead {
|
||||
responseWriter.dataStream.Reset(nil)
|
||||
// in gQUIC, the error code doesn't matter, so just use 0 here
|
||||
responseWriter.dataStream.CancelRead(0)
|
||||
}
|
||||
responseWriter.dataStream.Close()
|
||||
}
|
||||
|
@ -241,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
func (s *Server) Close() error {
|
||||
s.listenerMutex.Lock()
|
||||
defer s.listenerMutex.Unlock()
|
||||
s.closed = true
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
|
|
|
@ -70,6 +70,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
|
|||
func (s *mockSession) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") }
|
||||
|
||||
var _ = Describe("H2 server", func() {
|
||||
var (
|
||||
|
@ -410,6 +411,13 @@ var _ = Describe("H2 server", func() {
|
|||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors when ListenAndServer is called after Close", func() {
|
||||
serv := &Server{Server: &http.Server{}}
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
err := serv.ListenAndServe()
|
||||
Expect(err).To(MatchError("Server is already closed"))
|
||||
})
|
||||
|
||||
Context("ListenAndServe", func() {
|
||||
BeforeEach(func() {
|
||||
s.Server.Addr = "localhost:0"
|
||||
|
|
|
@ -19,20 +19,42 @@ type VersionNumber = protocol.VersionNumber
|
|||
// A Cookie can be used to verify the ownership of the client address.
|
||||
type Cookie = handshake.Cookie
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
type ConnectionState = handshake.ConnectionState
|
||||
|
||||
// An ErrorCode is an application-defined error code.
|
||||
type ErrorCode = protocol.ApplicationErrorCode
|
||||
|
||||
// Stream is the interface implemented by QUIC streams
|
||||
type Stream interface {
|
||||
// StreamID returns the stream ID.
|
||||
StreamID() StreamID
|
||||
// Read reads data from the stream.
|
||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Reader
|
||||
// Write writes data to the stream.
|
||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Writer
|
||||
// Close closes the write-direction of the stream.
|
||||
// Future calls to Write are not permitted after calling Close.
|
||||
// It must not be called concurrently with Write.
|
||||
// It must not be called after calling CancelWrite.
|
||||
io.Closer
|
||||
StreamID() StreamID
|
||||
// Reset closes the stream with an error.
|
||||
Reset(error)
|
||||
// CancelWrite aborts sending on this stream.
|
||||
// It must not be called after Close.
|
||||
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||
// Write will unblock immediately, and future calls to Write will fail.
|
||||
CancelWrite(ErrorCode) error
|
||||
// CancelRead aborts receiving on this stream.
|
||||
// It will ask the peer to stop transmitting stream data.
|
||||
// Read will unblock immediately, and future Read calls will fail.
|
||||
CancelRead(ErrorCode) error
|
||||
// The context is canceled as soon as the write-side of the stream is closed.
|
||||
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
|
@ -53,6 +75,41 @@ type Stream interface {
|
|||
SetDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A ReceiveStream is a unidirectional Receive Stream.
|
||||
type ReceiveStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Read
|
||||
io.Reader
|
||||
// see Stream.CancelRead
|
||||
CancelRead(ErrorCode) error
|
||||
// see Stream.SetReadDealine
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A SendStream is a unidirectional Send Stream.
|
||||
type SendStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Write
|
||||
io.Writer
|
||||
// see Stream.Close
|
||||
io.Closer
|
||||
// see Stream.CancelWrite
|
||||
CancelWrite(ErrorCode) error
|
||||
// see Stream.Context
|
||||
Context() context.Context
|
||||
// see Stream.SetWriteDeadline
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// StreamError is returned by Read and Write when the peer cancels the stream.
|
||||
type StreamError interface {
|
||||
error
|
||||
Canceled() bool
|
||||
ErrorCode() ErrorCode
|
||||
}
|
||||
|
||||
// A Session is a QUIC connection between two peers.
|
||||
type Session interface {
|
||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
|
@ -74,13 +131,9 @@ type Session interface {
|
|||
// The context is cancelled when the session is closed.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
Context() context.Context
|
||||
}
|
||||
|
||||
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
|
||||
// The communication is encrypted, but not yet forward secure.
|
||||
type NonFWSession interface {
|
||||
Session
|
||||
WaitUntilHandshakeComplete() error
|
||||
// ConnectionState returns basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
ConnectionState() ConnectionState
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
|
|
|
@ -18,6 +18,7 @@ type CertManager interface {
|
|||
GetLeafCertHash() (uint64, error)
|
||||
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
||||
Verify(hostname string) error
|
||||
GetChain() []*x509.Certificate
|
||||
}
|
||||
|
||||
type certManager struct {
|
||||
|
@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *certManager) GetChain() []*x509.Certificate {
|
||||
return c.chain
|
||||
}
|
||||
|
||||
func (c *certManager) GetCommonCertificateHashes() []byte {
|
||||
return getCommonCertificateHashes()
|
||||
}
|
||||
|
|
74
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
74
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
|
@ -10,35 +10,30 @@ import (
|
|||
)
|
||||
|
||||
type baseFlowController struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
// for sending data
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
|
||||
lastWindowUpdateTime time.Time
|
||||
|
||||
// for receiving data
|
||||
mutex sync.RWMutex
|
||||
bytesRead protocol.ByteCount
|
||||
highestReceived protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount
|
||||
receiveWindowIncrement protocol.ByteCount
|
||||
maxReceiveWindowIncrement protocol.ByteCount
|
||||
receiveWindowSize protocol.ByteCount
|
||||
maxReceiveWindowSize protocol.ByteCount
|
||||
|
||||
epochStartTime time.Time
|
||||
epochStartOffset protocol.ByteCount
|
||||
rttStats *congestion.RTTStats
|
||||
}
|
||||
|
||||
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.bytesSent += n
|
||||
}
|
||||
|
||||
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
||||
// it returns true if the window was actually updated
|
||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if offset > c.sendWindow {
|
||||
c.sendWindow = offset
|
||||
}
|
||||
|
@ -57,52 +52,55 @@ func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
|
|||
defer c.mutex.Unlock()
|
||||
|
||||
// pretend we sent a WindowUpdate when reading the first byte
|
||||
// this way auto-tuning of the window increment already works for the first WindowUpdate
|
||||
// this way auto-tuning of the window size already works for the first WindowUpdate
|
||||
if c.bytesRead == 0 {
|
||||
c.lastWindowUpdateTime = time.Now()
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
c.bytesRead += n
|
||||
}
|
||||
|
||||
func (c *baseFlowController) hasWindowUpdate() bool {
|
||||
bytesRemaining := c.receiveWindow - c.bytesRead
|
||||
// update the window when more than the threshold was consumed
|
||||
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
|
||||
}
|
||||
|
||||
// getWindowUpdate updates the receive window, if necessary
|
||||
// it returns the new offset
|
||||
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
|
||||
diff := c.receiveWindow - c.bytesRead
|
||||
// update the window when more than half of it was already consumed
|
||||
if diff >= (c.receiveWindowIncrement / 2) {
|
||||
if !c.hasWindowUpdate() {
|
||||
return 0
|
||||
}
|
||||
|
||||
c.maybeAdjustWindowIncrement()
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
|
||||
c.lastWindowUpdateTime = time.Now()
|
||||
c.maybeAdjustWindowSize()
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowSize
|
||||
return c.receiveWindow
|
||||
}
|
||||
|
||||
func (c *baseFlowController) IsBlocked() bool {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
return c.sendWindowSize() == 0
|
||||
}
|
||||
|
||||
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
|
||||
func (c *baseFlowController) maybeAdjustWindowIncrement() {
|
||||
if c.lastWindowUpdateTime.IsZero() {
|
||||
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
|
||||
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
|
||||
func (c *baseFlowController) maybeAdjustWindowSize() {
|
||||
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
|
||||
// don't do anything if less than half the window has been consumed
|
||||
if bytesReadInEpoch <= c.receiveWindowSize/2 {
|
||||
return
|
||||
}
|
||||
|
||||
rtt := c.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
|
||||
// interval between the window updates is sufficiently large, no need to increase the increment
|
||||
if timeSinceLastWindowUpdate >= 2*rtt {
|
||||
return
|
||||
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
|
||||
if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
||||
// window is consumed too fast, try to increase the window size
|
||||
c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
||||
}
|
||||
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
|
||||
func (c *baseFlowController) startNewAutoTuningEpoch() {
|
||||
c.epochStartTime = time.Now()
|
||||
c.epochStartOffset = c.bytesRead
|
||||
}
|
||||
|
||||
func (c *baseFlowController) checkFlowControlViolation() bool {
|
||||
|
|
181
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller_test.go
generated
vendored
181
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller_test.go
generated
vendored
|
@ -1,6 +1,8 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
|
@ -9,6 +11,16 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// on the CIs, the timing is a lot less precise, so scale every duration by this factor
|
||||
func scaleDuration(t time.Duration) time.Duration {
|
||||
scaleFactor := 1
|
||||
if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
|
||||
scaleFactor = f
|
||||
}
|
||||
Expect(scaleFactor).ToNot(BeZero())
|
||||
return time.Duration(scaleFactor) * t
|
||||
}
|
||||
|
||||
var _ = Describe("Base Flow controller", func() {
|
||||
var controller *baseFlowController
|
||||
|
||||
|
@ -49,22 +61,18 @@ var _ = Describe("Base Flow controller", func() {
|
|||
controller.UpdateSendWindow(10)
|
||||
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20)))
|
||||
})
|
||||
|
||||
It("says when it's blocked", func() {
|
||||
controller.UpdateSendWindow(100)
|
||||
Expect(controller.IsBlocked()).To(BeFalse())
|
||||
controller.AddBytesSent(100)
|
||||
Expect(controller.IsBlocked()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("receive flow control", func() {
|
||||
var receiveWindow protocol.ByteCount = 10000
|
||||
var receiveWindowIncrement protocol.ByteCount = 600
|
||||
var (
|
||||
receiveWindow protocol.ByteCount = 10000
|
||||
receiveWindowSize protocol.ByteCount = 1000
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
controller.bytesRead = receiveWindow - receiveWindowSize
|
||||
controller.receiveWindow = receiveWindow
|
||||
controller.receiveWindowIncrement = receiveWindowIncrement
|
||||
controller.receiveWindowSize = receiveWindowSize
|
||||
})
|
||||
|
||||
It("adds bytes read", func() {
|
||||
|
@ -74,31 +82,30 @@ var _ = Describe("Base Flow controller", func() {
|
|||
})
|
||||
|
||||
It("triggers a window update when necessary", func() {
|
||||
controller.lastWindowUpdateTime = time.Now().Add(-time.Hour)
|
||||
readPosition := receiveWindow - receiveWindowIncrement/2 + 1
|
||||
bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold + 1 // consumed 1 byte more than the threshold
|
||||
bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
|
||||
readPosition := receiveWindow - bytesRemaining
|
||||
controller.bytesRead = readPosition
|
||||
offset := controller.getWindowUpdate()
|
||||
Expect(offset).To(Equal(readPosition + receiveWindowIncrement))
|
||||
Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowIncrement))
|
||||
Expect(controller.lastWindowUpdateTime).To(BeTemporally("~", time.Now(), 20*time.Millisecond))
|
||||
Expect(offset).To(Equal(readPosition + receiveWindowSize))
|
||||
Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowSize))
|
||||
})
|
||||
|
||||
It("doesn't trigger a window update when not necessary", func() {
|
||||
lastWindowUpdateTime := time.Now().Add(-time.Hour)
|
||||
controller.lastWindowUpdateTime = lastWindowUpdateTime
|
||||
readPosition := receiveWindow - receiveWindow/2 - 1
|
||||
bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold - 1 // consumed 1 byte less than the threshold
|
||||
bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
|
||||
readPosition := receiveWindow - bytesRemaining
|
||||
controller.bytesRead = readPosition
|
||||
offset := controller.getWindowUpdate()
|
||||
Expect(offset).To(BeZero())
|
||||
Expect(controller.lastWindowUpdateTime).To(Equal(lastWindowUpdateTime))
|
||||
})
|
||||
|
||||
Context("receive window increment auto-tuning", func() {
|
||||
var oldIncrement protocol.ByteCount
|
||||
Context("receive window size auto-tuning", func() {
|
||||
var oldWindowSize protocol.ByteCount
|
||||
|
||||
BeforeEach(func() {
|
||||
oldIncrement = controller.receiveWindowIncrement
|
||||
controller.maxReceiveWindowIncrement = 3000
|
||||
oldWindowSize = controller.receiveWindowSize
|
||||
controller.maxReceiveWindowSize = 5000
|
||||
})
|
||||
|
||||
// update the congestion such that it returns a given value for the smoothed RTT
|
||||
|
@ -107,72 +114,98 @@ var _ = Describe("Base Flow controller", func() {
|
|||
Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked
|
||||
}
|
||||
|
||||
It("doesn't increase the increment for a new stream", func() {
|
||||
controller.maybeAdjustWindowIncrement()
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
|
||||
It("doesn't increase the window size for a new stream", func() {
|
||||
controller.maybeAdjustWindowSize()
|
||||
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
|
||||
})
|
||||
|
||||
It("doesn't increase the increment when no RTT estimate is available", func() {
|
||||
It("doesn't increase the window size when no RTT estimate is available", func() {
|
||||
setRtt(0)
|
||||
controller.lastWindowUpdateTime = time.Now()
|
||||
controller.maybeAdjustWindowIncrement()
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
|
||||
controller.startNewAutoTuningEpoch()
|
||||
controller.AddBytesRead(400)
|
||||
offset := controller.getWindowUpdate()
|
||||
Expect(offset).ToNot(BeZero()) // make sure a window update is sent
|
||||
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
|
||||
})
|
||||
|
||||
It("increases the increment when the last WindowUpdate was sent less than two RTTs ago", func() {
|
||||
setRtt(20 * time.Millisecond)
|
||||
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
|
||||
controller.maybeAdjustWindowIncrement()
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement))
|
||||
})
|
||||
|
||||
It("doesn't increase the increase increment when the last WindowUpdate was sent more than two RTTs ago", func() {
|
||||
setRtt(20 * time.Millisecond)
|
||||
controller.lastWindowUpdateTime = time.Now().Add(-45 * time.Millisecond)
|
||||
controller.maybeAdjustWindowIncrement()
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
|
||||
})
|
||||
|
||||
It("doesn't increase the increment to a value higher than the maxReceiveWindowIncrement", func() {
|
||||
setRtt(20 * time.Millisecond)
|
||||
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
|
||||
controller.maybeAdjustWindowIncrement()
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) // 1200
|
||||
// because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowIncrement() multiple times and get an increase of the increment every time
|
||||
controller.maybeAdjustWindowIncrement()
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(2 * 2 * oldIncrement)) // 2400
|
||||
controller.maybeAdjustWindowIncrement()
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(controller.maxReceiveWindowIncrement)) // 3000
|
||||
controller.maybeAdjustWindowIncrement()
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(controller.maxReceiveWindowIncrement)) // 3000
|
||||
})
|
||||
|
||||
It("returns the new increment when updating the window", func() {
|
||||
setRtt(20 * time.Millisecond)
|
||||
controller.AddBytesRead(9900) // receive window is 10000
|
||||
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
|
||||
It("increases the window size if read so fast that the window would be consumed in less than 4 RTTs", func() {
|
||||
bytesRead := controller.bytesRead
|
||||
rtt := scaleDuration(20 * time.Millisecond)
|
||||
setRtt(rtt)
|
||||
// consume more than 2/3 of the window...
|
||||
dataRead := receiveWindowSize*2/3 + 1
|
||||
// ... in 4*2/3 of the RTT
|
||||
controller.epochStartOffset = controller.bytesRead
|
||||
controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
|
||||
controller.AddBytesRead(dataRead)
|
||||
offset := controller.getWindowUpdate()
|
||||
Expect(offset).ToNot(BeZero())
|
||||
newIncrement := controller.receiveWindowIncrement
|
||||
Expect(newIncrement).To(Equal(2 * oldIncrement))
|
||||
Expect(offset).To(Equal(protocol.ByteCount(9900 + newIncrement)))
|
||||
// check that the window size was increased
|
||||
newWindowSize := controller.receiveWindowSize
|
||||
Expect(newWindowSize).To(Equal(2 * oldWindowSize))
|
||||
// check that the new window size was used to increase the offset
|
||||
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
|
||||
})
|
||||
|
||||
It("increases the increment sent in the first WindowUpdate, if data is read fast enough", func() {
|
||||
setRtt(20 * time.Millisecond)
|
||||
controller.AddBytesRead(9900)
|
||||
It("doesn't increase the window size if data is read so fast that the window would be consumed in less than 4 RTTs, but less than half the window has been read", func() {
|
||||
// this test only makes sense if a window update is triggered before half of the window has been consumed
|
||||
Expect(protocol.WindowUpdateThreshold).To(BeNumerically(">", 1/3))
|
||||
bytesRead := controller.bytesRead
|
||||
rtt := scaleDuration(20 * time.Millisecond)
|
||||
setRtt(rtt)
|
||||
// consume more than 2/3 of the window...
|
||||
dataRead := receiveWindowSize*1/3 + 1
|
||||
// ... in 4*2/3 of the RTT
|
||||
controller.epochStartOffset = controller.bytesRead
|
||||
controller.epochStartTime = time.Now().Add(-rtt * 4 * 1 / 3)
|
||||
controller.AddBytesRead(dataRead)
|
||||
offset := controller.getWindowUpdate()
|
||||
Expect(offset).ToNot(BeZero())
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement))
|
||||
// check that the window size was not increased
|
||||
newWindowSize := controller.receiveWindowSize
|
||||
Expect(newWindowSize).To(Equal(oldWindowSize))
|
||||
// check that the new window size was used to increase the offset
|
||||
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
|
||||
})
|
||||
|
||||
It("doesn't increamse the increment sent in the first WindowUpdate, if data is read slowly", func() {
|
||||
setRtt(5 * time.Millisecond)
|
||||
controller.AddBytesRead(9900)
|
||||
time.Sleep(15 * time.Millisecond) // more than 2x RTT
|
||||
It("doesn't increase the window size if read too slowly", func() {
|
||||
bytesRead := controller.bytesRead
|
||||
rtt := scaleDuration(20 * time.Millisecond)
|
||||
setRtt(rtt)
|
||||
// consume less than 2/3 of the window...
|
||||
dataRead := receiveWindowSize*2/3 - 1
|
||||
// ... in 4*2/3 of the RTT
|
||||
controller.epochStartOffset = controller.bytesRead
|
||||
controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
|
||||
controller.AddBytesRead(dataRead)
|
||||
offset := controller.getWindowUpdate()
|
||||
Expect(offset).ToNot(BeZero())
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
|
||||
// check that the window size was not increased
|
||||
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
|
||||
// check that the new window size was used to increase the offset
|
||||
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + oldWindowSize)))
|
||||
})
|
||||
|
||||
It("doesn't increase the window size to a value higher than the maxReceiveWindowSize", func() {
|
||||
resetEpoch := func() {
|
||||
// make sure the next call to maybeAdjustWindowSize will increase the window
|
||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
||||
controller.epochStartOffset = controller.bytesRead
|
||||
controller.AddBytesRead(controller.receiveWindowSize/2 + 1)
|
||||
}
|
||||
setRtt(scaleDuration(20 * time.Millisecond))
|
||||
resetEpoch()
|
||||
controller.maybeAdjustWindowSize()
|
||||
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) // 2000
|
||||
// because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowSize() multiple times and get an increase of the window size every time
|
||||
resetEpoch()
|
||||
controller.maybeAdjustWindowSize()
|
||||
Expect(controller.receiveWindowSize).To(Equal(2 * 2 * oldWindowSize)) // 4000
|
||||
resetEpoch()
|
||||
controller.maybeAdjustWindowSize()
|
||||
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
|
||||
controller.maybeAdjustWindowSize()
|
||||
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -2,7 +2,6 @@ package flowcontrol
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -11,6 +10,7 @@ import (
|
|||
)
|
||||
|
||||
type connectionFlowController struct {
|
||||
lastBlockedAt protocol.ByteCount
|
||||
baseFlowController
|
||||
}
|
||||
|
||||
|
@ -27,19 +27,27 @@ func NewConnectionFlowController(
|
|||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowIncrement: receiveWindow,
|
||||
maxReceiveWindowIncrement: maxReceiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
return c.baseFlowController.sendWindowSize()
|
||||
}
|
||||
|
||||
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||
// For every offset, it only returns true once.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||
return false, 0
|
||||
}
|
||||
c.lastBlockedAt = c.sendWindow
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
|
||||
c.mutex.Lock()
|
||||
|
@ -54,24 +62,22 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
|
|||
|
||||
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
oldWindowIncrement := c.receiveWindowIncrement
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if oldWindowIncrement < c.receiveWindowIncrement {
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10))
|
||||
if oldWindowSize < c.receiveWindowSize {
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
||||
|
||||
// EnsureMinimumWindowIncrement sets a minimum window increment
|
||||
// EnsureMinimumWindowSize sets a minimum window size
|
||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||
func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
|
||||
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if inc > c.receiveWindowIncrement {
|
||||
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
|
||||
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
|
||||
if inc > c.receiveWindowSize {
|
||||
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
|
|
@ -32,12 +32,12 @@ var _ = Describe("Connection Flow controller", func() {
|
|||
|
||||
fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, rttStats).(*connectionFlowController)
|
||||
Expect(fc.receiveWindow).To(Equal(receiveWindow))
|
||||
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow))
|
||||
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
|
||||
})
|
||||
})
|
||||
|
||||
Context("receive flow control", func() {
|
||||
It("increases the highestReceived by a given increment", func() {
|
||||
It("increases the highestReceived by a given window size", func() {
|
||||
controller.highestReceived = 1337
|
||||
controller.IncrementHighestReceived(123)
|
||||
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123)))
|
||||
|
@ -46,64 +46,96 @@ var _ = Describe("Connection Flow controller", func() {
|
|||
Context("getting window updates", func() {
|
||||
BeforeEach(func() {
|
||||
controller.receiveWindow = 100
|
||||
controller.receiveWindowIncrement = 60
|
||||
controller.maxReceiveWindowIncrement = 1000
|
||||
controller.receiveWindowSize = 60
|
||||
controller.maxReceiveWindowSize = 1000
|
||||
controller.bytesRead = 100 - 60
|
||||
})
|
||||
|
||||
It("gets a window update", func() {
|
||||
controller.AddBytesRead(80)
|
||||
windowSize := controller.receiveWindowSize
|
||||
oldOffset := controller.bytesRead
|
||||
dataRead := windowSize/2 - 1 // make sure not to trigger auto-tuning
|
||||
controller.AddBytesRead(dataRead)
|
||||
offset := controller.GetWindowUpdate()
|
||||
Expect(offset).To(Equal(protocol.ByteCount(80 + 60)))
|
||||
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + 60)))
|
||||
})
|
||||
|
||||
It("autotunes the window", func() {
|
||||
controller.AddBytesRead(80)
|
||||
setRtt(20 * time.Millisecond)
|
||||
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
|
||||
oldOffset := controller.bytesRead
|
||||
oldWindowSize := controller.receiveWindowSize
|
||||
rtt := scaleDuration(20 * time.Millisecond)
|
||||
setRtt(rtt)
|
||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
||||
controller.epochStartOffset = oldOffset
|
||||
dataRead := oldWindowSize/2 + 1
|
||||
controller.AddBytesRead(dataRead)
|
||||
offset := controller.GetWindowUpdate()
|
||||
Expect(offset).To(Equal(protocol.ByteCount(80 + 2*60)))
|
||||
newWindowSize := controller.receiveWindowSize
|
||||
Expect(newWindowSize).To(Equal(2 * oldWindowSize))
|
||||
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + newWindowSize)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("setting the minimum increment", func() {
|
||||
Context("send flow control", func() {
|
||||
It("says when it's blocked", func() {
|
||||
controller.UpdateSendWindow(100)
|
||||
Expect(controller.IsNewlyBlocked()).To(BeFalse())
|
||||
controller.AddBytesSent(100)
|
||||
blocked, offset := controller.IsNewlyBlocked()
|
||||
Expect(blocked).To(BeTrue())
|
||||
Expect(offset).To(Equal(protocol.ByteCount(100)))
|
||||
})
|
||||
|
||||
It("doesn't say that it's newly blocked multiple times for the same offset", func() {
|
||||
controller.UpdateSendWindow(100)
|
||||
controller.AddBytesSent(100)
|
||||
newlyBlocked, offset := controller.IsNewlyBlocked()
|
||||
Expect(newlyBlocked).To(BeTrue())
|
||||
Expect(offset).To(Equal(protocol.ByteCount(100)))
|
||||
newlyBlocked, _ = controller.IsNewlyBlocked()
|
||||
Expect(newlyBlocked).To(BeFalse())
|
||||
controller.UpdateSendWindow(150)
|
||||
controller.AddBytesSent(150)
|
||||
newlyBlocked, offset = controller.IsNewlyBlocked()
|
||||
Expect(newlyBlocked).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("setting the minimum window size", func() {
|
||||
var (
|
||||
oldIncrement protocol.ByteCount
|
||||
oldWindowSize protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount = 10000
|
||||
receiveWindowIncrement protocol.ByteCount = 600
|
||||
receiveWindowSize protocol.ByteCount = 1000
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
controller.bytesRead = receiveWindowSize - receiveWindowSize
|
||||
controller.receiveWindow = receiveWindow
|
||||
controller.receiveWindowIncrement = receiveWindowIncrement
|
||||
oldIncrement = controller.receiveWindowIncrement
|
||||
controller.maxReceiveWindowIncrement = 3000
|
||||
controller.receiveWindowSize = receiveWindowSize
|
||||
oldWindowSize = controller.receiveWindowSize
|
||||
controller.maxReceiveWindowSize = 3000
|
||||
})
|
||||
|
||||
It("sets the minimum window increment", func() {
|
||||
controller.EnsureMinimumWindowIncrement(1000)
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(1000)))
|
||||
It("sets the minimum window window size", func() {
|
||||
controller.EnsureMinimumWindowSize(1800)
|
||||
Expect(controller.receiveWindowSize).To(Equal(protocol.ByteCount(1800)))
|
||||
})
|
||||
|
||||
It("doesn't reduce the window increment", func() {
|
||||
controller.EnsureMinimumWindowIncrement(1)
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
|
||||
It("doesn't reduce the window window size", func() {
|
||||
controller.EnsureMinimumWindowSize(1)
|
||||
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
|
||||
})
|
||||
|
||||
It("doens't increase the increment beyond the maxReceiveWindowIncrement", func() {
|
||||
max := controller.maxReceiveWindowIncrement
|
||||
controller.EnsureMinimumWindowIncrement(2 * max)
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(max))
|
||||
It("doens't increase the window size beyond the maxReceiveWindowSize", func() {
|
||||
max := controller.maxReceiveWindowSize
|
||||
controller.EnsureMinimumWindowSize(2 * max)
|
||||
Expect(controller.receiveWindowSize).To(Equal(max))
|
||||
})
|
||||
|
||||
It("doesn't auto-tune the window after the increment was increased", func() {
|
||||
setRtt(20 * time.Millisecond)
|
||||
controller.bytesRead = 9900 // receive window is 10000
|
||||
controller.lastWindowUpdateTime = time.Now().Add(-20 * time.Millisecond)
|
||||
controller.EnsureMinimumWindowIncrement(912)
|
||||
offset := controller.getWindowUpdate()
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(912))) // no auto-tuning
|
||||
Expect(offset).To(Equal(protocol.ByteCount(9900 + 912)))
|
||||
It("starts a new epoch after the window size was increased", func() {
|
||||
controller.EnsureMinimumWindowSize(1912)
|
||||
Expect(controller.epochStartTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -5,7 +5,6 @@ import "github.com/lucas-clemente/quic-go/internal/protocol"
|
|||
type flowController interface {
|
||||
// for sending
|
||||
SendWindowSize() protocol.ByteCount
|
||||
IsBlocked() bool
|
||||
UpdateSendWindow(protocol.ByteCount)
|
||||
AddBytesSent(protocol.ByteCount)
|
||||
// for receiving
|
||||
|
@ -16,22 +15,28 @@ type flowController interface {
|
|||
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||
type StreamFlowController interface {
|
||||
flowController
|
||||
// for sending
|
||||
IsBlocked() (bool, protocol.ByteCount)
|
||||
// for receiving
|
||||
// UpdateHighestReceived should be called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
|
||||
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||
// HasWindowUpdate says if it is necessary to update the window
|
||||
HasWindowUpdate() bool
|
||||
}
|
||||
|
||||
// The ConnectionFlowController is the flow controller for the connection.
|
||||
type ConnectionFlowController interface {
|
||||
flowController
|
||||
// for sending
|
||||
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||
}
|
||||
|
||||
type connectionFlowControllerI interface {
|
||||
ConnectionFlowController
|
||||
// The following two methods are not supposed to be called from outside this packet, but are needed internally
|
||||
// for sending
|
||||
EnsureMinimumWindowIncrement(protocol.ByteCount)
|
||||
EnsureMinimumWindowSize(protocol.ByteCount)
|
||||
// for receiving
|
||||
IncrementHighestReceived(protocol.ByteCount) error
|
||||
}
|
||||
|
|
38
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
38
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
|
@ -39,8 +39,8 @@ func NewStreamFlowController(
|
|||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowIncrement: receiveWindow,
|
||||
maxReceiveWindowIncrement: maxReceiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
sendWindow: initialSendWindow,
|
||||
},
|
||||
}
|
||||
|
@ -102,9 +102,6 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
|||
}
|
||||
|
||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
window := c.baseFlowController.sendWindowSize()
|
||||
if c.contributesToConnection {
|
||||
window = utils.MinByteCount(window, c.connection.SendWindowSize())
|
||||
|
@ -112,22 +109,39 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
|||
return window
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// IsBlocked says if it is blocked by stream-level flow control.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 {
|
||||
return false, 0
|
||||
}
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
func (c *streamFlowController) HasWindowUpdate() bool {
|
||||
c.mutex.Lock()
|
||||
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
return hasWindowUpdate
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
|
||||
c.mutex.Lock()
|
||||
// if we already received the final offset for this stream, the peer won't need any additional flow control credit
|
||||
if c.receivedFinalOffset {
|
||||
c.mutex.Unlock()
|
||||
return 0
|
||||
}
|
||||
|
||||
oldWindowIncrement := c.receiveWindowIncrement
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10))
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
if c.contributesToConnection {
|
||||
c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier))
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
streamID: 10,
|
||||
connection: NewConnectionFlowController(1000, 1000, rttStats).(*connectionFlowController),
|
||||
}
|
||||
controller.maxReceiveWindowIncrement = 10000
|
||||
controller.maxReceiveWindowSize = 10000
|
||||
controller.rttStats = rttStats
|
||||
})
|
||||
|
||||
|
@ -35,7 +35,7 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats).(*streamFlowController)
|
||||
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
|
||||
Expect(fc.receiveWindow).To(Equal(receiveWindow))
|
||||
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow))
|
||||
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
|
||||
Expect(fc.sendWindow).To(Equal(sendWindow))
|
||||
Expect(fc.contributesToConnection).To(BeTrue())
|
||||
})
|
||||
|
@ -44,11 +44,11 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
Context("receiving data", func() {
|
||||
Context("registering received offsets", func() {
|
||||
var receiveWindow protocol.ByteCount = 10000
|
||||
var receiveWindowIncrement protocol.ByteCount = 600
|
||||
var receiveWindowSize protocol.ByteCount = 600
|
||||
|
||||
BeforeEach(func() {
|
||||
controller.receiveWindow = receiveWindow
|
||||
controller.receiveWindowIncrement = receiveWindowIncrement
|
||||
controller.receiveWindowSize = receiveWindowSize
|
||||
})
|
||||
|
||||
It("updates the highestReceived", func() {
|
||||
|
@ -157,7 +157,7 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
})
|
||||
|
||||
Context("generating window updates", func() {
|
||||
var oldIncrement protocol.ByteCount
|
||||
var oldWindowSize protocol.ByteCount
|
||||
|
||||
// update the congestion such that it returns a given value for the smoothed RTT
|
||||
setRtt := func(t time.Duration) {
|
||||
|
@ -167,37 +167,51 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
controller.receiveWindow = 100
|
||||
controller.receiveWindowIncrement = 60
|
||||
controller.connection.(*connectionFlowController).receiveWindowIncrement = 120
|
||||
oldIncrement = controller.receiveWindowIncrement
|
||||
controller.receiveWindowSize = 60
|
||||
controller.bytesRead = 100 - 60
|
||||
controller.connection.(*connectionFlowController).receiveWindowSize = 120
|
||||
oldWindowSize = controller.receiveWindowSize
|
||||
})
|
||||
|
||||
It("tells if it has window updates", func() {
|
||||
Expect(controller.HasWindowUpdate()).To(BeFalse())
|
||||
controller.AddBytesRead(30)
|
||||
Expect(controller.HasWindowUpdate()).To(BeTrue())
|
||||
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
|
||||
Expect(controller.HasWindowUpdate()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("tells the connection flow controller when the window was autotuned", func() {
|
||||
oldOffset := controller.bytesRead
|
||||
controller.contributesToConnection = true
|
||||
controller.AddBytesRead(75)
|
||||
setRtt(20 * time.Millisecond)
|
||||
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
|
||||
setRtt(scaleDuration(20 * time.Millisecond))
|
||||
controller.epochStartOffset = oldOffset
|
||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
||||
controller.AddBytesRead(55)
|
||||
offset := controller.GetWindowUpdate()
|
||||
Expect(offset).To(Equal(protocol.ByteCount(75 + 2*60)))
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement))
|
||||
Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(float64(controller.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier)))
|
||||
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + 55 + 2*oldWindowSize)))
|
||||
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
|
||||
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)))
|
||||
})
|
||||
|
||||
It("doesn't tell the connection flow controller if it doesn't contribute", func() {
|
||||
oldOffset := controller.bytesRead
|
||||
controller.contributesToConnection = false
|
||||
controller.AddBytesRead(75)
|
||||
setRtt(20 * time.Millisecond)
|
||||
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
|
||||
setRtt(scaleDuration(20 * time.Millisecond))
|
||||
controller.epochStartOffset = oldOffset
|
||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
||||
controller.AddBytesRead(55)
|
||||
offset := controller.GetWindowUpdate()
|
||||
Expect(offset).ToNot(BeZero())
|
||||
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement))
|
||||
Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(120))) // unchanged
|
||||
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
|
||||
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(2 * oldWindowSize))) // unchanged
|
||||
})
|
||||
|
||||
It("doesn't increase the window after a final offset was already received", func() {
|
||||
controller.AddBytesRead(80)
|
||||
controller.AddBytesRead(30)
|
||||
err := controller.UpdateHighestReceived(90, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(controller.HasWindowUpdate()).To(BeFalse())
|
||||
offset := controller.GetWindowUpdate()
|
||||
Expect(offset).To(BeZero())
|
||||
})
|
||||
|
@ -231,7 +245,8 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
controller.connection.UpdateSendWindow(50)
|
||||
controller.UpdateSendWindow(100)
|
||||
controller.AddBytesSent(50)
|
||||
Expect(controller.connection.IsBlocked()).To(BeTrue())
|
||||
blocked, _ := controller.connection.IsNewlyBlocked()
|
||||
Expect(blocked).To(BeTrue())
|
||||
Expect(controller.IsBlocked()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
|
23
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
23
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
|
@ -52,7 +52,7 @@ type cryptoSetupClient struct {
|
|||
forwardSecureAEAD crypto.AEAD
|
||||
|
||||
paramsChan chan<- TransportParameters
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
handshakeEvent chan<- struct{}
|
||||
|
||||
params *TransportParameters
|
||||
}
|
||||
|
@ -74,7 +74,7 @@ func NewCryptoSetupClient(
|
|||
tlsConfig *tls.Config,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
|
@ -93,7 +93,7 @@ func NewCryptoSetupClient(
|
|||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
paramsChan: paramsChan,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
initialVersion: initialVersion,
|
||||
negotiatedVersions: negotiatedVersions,
|
||||
divNonceChan: make(chan []byte),
|
||||
|
@ -159,8 +159,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
|||
}
|
||||
// blocks until the session has received the parameters
|
||||
h.paramsChan <- *params
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
close(h.aeadChanged)
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeEvent)
|
||||
default:
|
||||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
|
@ -381,6 +381,15 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
|
|||
h.divNonceChan <- data
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return ConnectionState{
|
||||
HandshakeComplete: h.forwardSecureAEAD != nil,
|
||||
PeerCertificates: h.certManager.GetChain(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sendCHLO() error {
|
||||
h.clientHelloCounter++
|
||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
||||
|
@ -496,10 +505,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.aeadChanged <- protocol.EncryptionSecure
|
||||
h.handshakeEvent <- struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
171
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client_test.go
generated
vendored
171
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client_test.go
generated
vendored
|
@ -2,6 +2,7 @@ package handshake
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
. "github.com/onsi/ginkgo"
|
||||
|
@ -34,6 +36,8 @@ type mockCertManager struct {
|
|||
|
||||
commonCertificateHashes []byte
|
||||
|
||||
chain []*x509.Certificate
|
||||
|
||||
leafCert []byte
|
||||
leafCertHash uint64
|
||||
leafCertHashError error
|
||||
|
@ -45,6 +49,8 @@ type mockCertManager struct {
|
|||
verifyCalled bool
|
||||
}
|
||||
|
||||
var _ crypto.CertManager = &mockCertManager{}
|
||||
|
||||
func (m *mockCertManager) SetData(data []byte) error {
|
||||
m.setDataCalledWith = data
|
||||
return m.setDataError
|
||||
|
@ -72,6 +78,10 @@ func (m *mockCertManager) Verify(hostname string) error {
|
|||
return m.verifyError
|
||||
}
|
||||
|
||||
func (m *mockCertManager) GetChain() []*x509.Certificate {
|
||||
return m.chain
|
||||
}
|
||||
|
||||
var _ = Describe("Client Crypto Setup", func() {
|
||||
var (
|
||||
cs *cryptoSetupClient
|
||||
|
@ -79,7 +89,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
stream *mockStream
|
||||
keyDerivationCalledWith *keyDerivationValues
|
||||
shloMap map[Tag][]byte
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
handshakeEvent chan struct{}
|
||||
paramsChan chan TransportParameters
|
||||
)
|
||||
|
||||
|
@ -108,7 +118,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
version := protocol.Version39
|
||||
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
|
||||
paramsChan = make(chan TransportParameters, 1)
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent = make(chan struct{}, 2)
|
||||
csInt, err := NewCryptoSetupClient(
|
||||
stream,
|
||||
"hostname",
|
||||
|
@ -117,7 +127,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
nil,
|
||||
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
protocol.Version39,
|
||||
nil,
|
||||
)
|
||||
|
@ -130,10 +140,6 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
cs.cryptoStream = stream
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
close(stream.unblockRead)
|
||||
})
|
||||
|
||||
Context("Reading REJ", func() {
|
||||
var tagMap map[Tag][]byte
|
||||
|
||||
|
@ -158,8 +164,17 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
stk := []byte("foobar")
|
||||
tagMap[TagSTK] = stk
|
||||
HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead)
|
||||
go cs.HandleCryptoStream()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() []byte { return cs.stk }).Should(Equal(stk))
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("saves the proof", func() {
|
||||
|
@ -380,22 +395,22 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
cs.receivedSecurePacket = false
|
||||
_, err := cs.handleSHLOMessage(shloMap)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")))
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects SHLOs without a PUBS", func() {
|
||||
delete(shloMap, TagPUBS)
|
||||
_, err := cs.handleSHLOMessage(shloMap)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")))
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects SHLOs without a version list", func() {
|
||||
delete(shloMap, TagVER)
|
||||
_, err := cs.handleSHLOMessage(shloMap)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")))
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("accepts a SHLO after a version negotiation", func() {
|
||||
|
@ -430,28 +445,38 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
Expect(params.IdleTimeout).To(Equal(13 * time.Second))
|
||||
})
|
||||
|
||||
It("closes the aeadChanged when receiving an SHLO", func() {
|
||||
It("closes the handshakeEvent chan when receiving an SHLO", func() {
|
||||
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
Eventually(aeadChanged).Should(BeClosed())
|
||||
Eventually(handshakeEvent).Should(Receive())
|
||||
Eventually(handshakeEvent).Should(BeClosed())
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("passes the transport parameters on the channel", func() {
|
||||
shloMap[TagSFCW] = []byte{0x0d, 0x00, 0xdf, 0xba}
|
||||
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
|
||||
close(done)
|
||||
}()
|
||||
var params TransportParameters
|
||||
Eventually(paramsChan).Should(Receive(¶ms))
|
||||
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xbadf000d)))
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors if it can't read a connection parameter", func() {
|
||||
|
@ -637,9 +662,9 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert))
|
||||
Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce))
|
||||
Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("uses the server nonce, if the server sent one", func() {
|
||||
|
@ -649,51 +674,64 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...)))
|
||||
Expect(aeadChanged).To(Receive())
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() {
|
||||
err := cs.maybeUpgradeCrypto()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.secureAEAD).To(BeNil())
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
cs.serverVerified = true
|
||||
// make sure we really had all necessary values before, and only serverVerified was missing
|
||||
err = cs.maybeUpgradeCrypto()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("tries to escalate before reading a handshake message", func() {
|
||||
Expect(cs.secureAEAD).To(BeNil())
|
||||
cs.serverVerified = true
|
||||
go cs.HandleCryptoStream()
|
||||
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cs.HandleCryptoStream()
|
||||
Fail("HandleCryptoStream should not have returned")
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(handshakeEvent).Should(Receive())
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("tries to escalate the crypto after receiving a diversification nonce", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
|
||||
close(done)
|
||||
}()
|
||||
cs.diversificationNonce = nil
|
||||
cs.serverVerified = true
|
||||
Expect(cs.secureAEAD).To(BeNil())
|
||||
cs.SetDiversificationNonce([]byte("div"))
|
||||
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Eventually(handshakeEvent).Should(Receive())
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
close(done)
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("null encryption", func() {
|
||||
|
@ -813,6 +851,22 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("reporting the connection state", func() {
|
||||
It("reports the connection state before the handshake completes", func() {
|
||||
chain := []*x509.Certificate{testdata.GetCertificate().Leaf}
|
||||
certManager.chain = chain
|
||||
state := cs.ConnectionState()
|
||||
Expect(state.HandshakeComplete).To(BeFalse())
|
||||
Expect(state.PeerCertificates).To(Equal(chain))
|
||||
})
|
||||
|
||||
It("reports the connection state after the handshake completes", func() {
|
||||
doSHLO()
|
||||
state := cs.ConnectionState()
|
||||
Expect(state.HandshakeComplete).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("forcing encryption levels", func() {
|
||||
It("forces null encryption", func() {
|
||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(4), []byte{}).Return([]byte("foobar unencrypted"))
|
||||
|
@ -862,32 +916,51 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
|
||||
Context("Diversification Nonces", func() {
|
||||
It("sets a diversification nonce", func() {
|
||||
go cs.HandleCryptoStream()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
|
||||
close(done)
|
||||
}()
|
||||
nonce := []byte("foobar")
|
||||
cs.SetDiversificationNonce(nonce)
|
||||
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't do anything when called multiple times with the same nonce", func(done Done) {
|
||||
go cs.HandleCryptoStream()
|
||||
It("doesn't do anything when called multiple times with the same nonce", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
|
||||
close(done)
|
||||
}()
|
||||
nonce := []byte("foobar")
|
||||
cs.SetDiversificationNonce(nonce)
|
||||
cs.SetDiversificationNonce(nonce)
|
||||
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
|
||||
close(done)
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects a different diversification nonce", func() {
|
||||
var err error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
err = cs.HandleCryptoStream()
|
||||
defer GinkgoRecover()
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(errConflictingDiversificationNonces))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
nonce1 := []byte("foobar")
|
||||
nonce2 := []byte("raboof")
|
||||
cs.SetDiversificationNonce(nonce1)
|
||||
cs.SetDiversificationNonce(nonce2)
|
||||
Eventually(func() error { return err }).Should(MatchError(errConflictingDiversificationNonces))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
|
|
29
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
29
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
|
@ -23,6 +23,8 @@ type KeyExchangeFunction func() crypto.KeyExchange
|
|||
|
||||
// The CryptoSetupServer handles all things crypto for the Session
|
||||
type cryptoSetupServer struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
connID protocol.ConnectionID
|
||||
remoteAddr net.Addr
|
||||
scfg *ServerConfig
|
||||
|
@ -42,7 +44,7 @@ type cryptoSetupServer struct {
|
|||
|
||||
receivedParams bool
|
||||
paramsChan chan<- TransportParameters
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
handshakeEvent chan<- struct{}
|
||||
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
keyExchange KeyExchangeFunction
|
||||
|
@ -51,7 +53,7 @@ type cryptoSetupServer struct {
|
|||
|
||||
params *TransportParameters
|
||||
|
||||
mutex sync.RWMutex
|
||||
sni string // need to fill out the ConnectionState
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupServer{}
|
||||
|
@ -76,7 +78,7 @@ func NewCryptoSetup(
|
|||
supportedVersions []protocol.VersionNumber,
|
||||
acceptSTK func(net.Addr, *Cookie) bool,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
|
@ -96,7 +98,7 @@ func NewCryptoSetup(
|
|||
acceptSTKCallback: acceptSTK,
|
||||
sentSHLO: make(chan struct{}),
|
||||
paramsChan: paramsChan,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -139,6 +141,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
|
|||
if sni == "" {
|
||||
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
||||
}
|
||||
h.sni = sni
|
||||
|
||||
// prevent version downgrade attacks
|
||||
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
|
||||
|
@ -182,7 +185,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
|
|||
if _, err := h.cryptoStream.Write(reply); err != nil {
|
||||
return false, err
|
||||
}
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.sentSHLO)
|
||||
return true, nil
|
||||
}
|
||||
|
@ -206,9 +209,9 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
|
|||
if err == nil {
|
||||
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
|
||||
h.receivedForwardSecurePacket = true
|
||||
// wait until protocol.EncryptionForwardSecure was sent on the aeadChan
|
||||
// wait for the send on the handshakeEvent chan
|
||||
<-h.sentSHLO
|
||||
close(h.aeadChanged)
|
||||
close(h.handshakeEvent)
|
||||
}
|
||||
return res, protocol.EncryptionForwardSecure, nil
|
||||
}
|
||||
|
@ -396,8 +399,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.aeadChanged <- protocol.EncryptionSecure
|
||||
h.handshakeEvent <- struct{}{}
|
||||
|
||||
// Generate a new curve instance to derive the forward secure key
|
||||
var fsNonce bytes.Buffer
|
||||
|
@ -454,6 +456,15 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
|
|||
panic("not needed for cryptoSetupServer")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return ConnectionState{
|
||||
ServerName: h.sni,
|
||||
HandshakeComplete: h.receivedForwardSecurePacket,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
||||
if len(nonce) != 32 {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
|
||||
|
|
66
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server_test.go
generated
vendored
66
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server_test.go
generated
vendored
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
|
@ -63,35 +64,36 @@ func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte
|
|||
}
|
||||
|
||||
type mockStream struct {
|
||||
unblockRead chan struct{} // close this chan to unblock Read
|
||||
unblockRead chan struct{}
|
||||
dataToRead bytes.Buffer
|
||||
dataWritten bytes.Buffer
|
||||
}
|
||||
|
||||
var _ io.ReadWriter = &mockStream{}
|
||||
|
||||
var errMockStreamClosing = errors.New("mock stream closing")
|
||||
|
||||
func newMockStream() *mockStream {
|
||||
return &mockStream{unblockRead: make(chan struct{})}
|
||||
}
|
||||
|
||||
// call Close to make Read return
|
||||
func (s *mockStream) Read(p []byte) (int, error) {
|
||||
n, _ := s.dataToRead.Read(p)
|
||||
if n == 0 { // block if there's no data
|
||||
<-s.unblockRead
|
||||
return 0, errMockStreamClosing
|
||||
}
|
||||
return n, nil // never return an EOF
|
||||
}
|
||||
|
||||
func (s *mockStream) ReadByte() (byte, error) {
|
||||
return s.dataToRead.ReadByte()
|
||||
}
|
||||
|
||||
func (s *mockStream) Write(p []byte) (int, error) {
|
||||
return s.dataWritten.Write(p)
|
||||
}
|
||||
|
||||
func (s *mockStream) Close() error { panic("not implemented") }
|
||||
func (s *mockStream) Reset(error) { panic("not implemented") }
|
||||
func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") }
|
||||
func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") }
|
||||
func (s *mockStream) close() {
|
||||
close(s.unblockRead)
|
||||
}
|
||||
|
||||
type mockCookieProtector struct {
|
||||
data []byte
|
||||
|
@ -122,7 +124,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
cs *cryptoSetupServer
|
||||
stream *mockStream
|
||||
paramsChan chan TransportParameters
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
handshakeEvent chan struct{}
|
||||
nonce32 []byte
|
||||
versionTag []byte
|
||||
validSTK []byte
|
||||
|
@ -144,7 +146,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
|
||||
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
|
||||
paramsChan = make(chan TransportParameters, 1)
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent = make(chan struct{}, 2)
|
||||
stream = newMockStream()
|
||||
kex = &mockKEX{}
|
||||
signer = &mockSigner{}
|
||||
|
@ -168,7 +170,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
supportedVersions,
|
||||
nil,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cs = csInt.(*cryptoSetupServer)
|
||||
|
@ -183,10 +185,6 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
cs.cryptoStream = stream
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
close(stream.unblockRead)
|
||||
})
|
||||
|
||||
Context("diversification nonce", func() {
|
||||
BeforeEach(func() {
|
||||
cs.secureAEAD = mockcrypto.NewMockAEAD(mockCtrl)
|
||||
|
@ -345,10 +343,10 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
err := cs.HandleCryptoStream()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects client nonces that have the wrong length", func() {
|
||||
|
@ -379,9 +377,9 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
|
||||
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("recognizes inchoate CHLOs missing SCID", func() {
|
||||
|
@ -537,7 +535,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
TagKEXS: kexs,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||
close(cs.sentSHLO)
|
||||
}
|
||||
|
||||
|
@ -659,7 +657,26 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
|
||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(aeadChanged).To(BeClosed())
|
||||
Expect(handshakeEvent).To(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("reporting the connection state", func() {
|
||||
It("reports before the handshake completes", func() {
|
||||
cs.sni = "server name"
|
||||
state := cs.ConnectionState()
|
||||
Expect(state.HandshakeComplete).To(BeFalse())
|
||||
Expect(state.ServerName).To(Equal("server name"))
|
||||
})
|
||||
|
||||
It("reports after the handshake completes", func() {
|
||||
doCHLO()
|
||||
// receive a forward secure packet
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
|
||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
state := cs.ConnectionState()
|
||||
Expect(state.HandshakeComplete).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -723,6 +740,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(done).To(BeFalse())
|
||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
|
||||
Expect(cs.sni).To(Equal("foo"))
|
||||
})
|
||||
|
||||
It("works with proper STK", func() {
|
||||
|
|
26
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
26
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
|
@ -28,7 +28,7 @@ type cryptoSetupTLS struct {
|
|||
|
||||
tls MintTLS
|
||||
cryptoStream *CryptoStreamConn
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
handshakeEvent chan<- struct{}
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||
|
@ -36,7 +36,7 @@ func NewCryptoSetupTLSServer(
|
|||
tls MintTLS,
|
||||
cryptoStream *CryptoStreamConn,
|
||||
nullAEAD crypto.AEAD,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
version protocol.VersionNumber,
|
||||
) CryptoSetup {
|
||||
return &cryptoSetupTLS{
|
||||
|
@ -45,7 +45,7 @@ func NewCryptoSetupTLSServer(
|
|||
nullAEAD: nullAEAD,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -54,7 +54,7 @@ func NewCryptoSetupTLSClient(
|
|||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
hostname string,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
tls MintTLS,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
|
@ -68,7 +68,7 @@ func NewCryptoSetupTLSClient(
|
|||
tls: tls,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -102,9 +102,8 @@ handshakeLoop:
|
|||
h.aead = aead
|
||||
h.mutex.Unlock()
|
||||
|
||||
// signal to the outside world that the handshake completed
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
close(h.aeadChanged)
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -165,3 +164,14 @@ func (h *cryptoSetupTLS) DiversificationNonce() []byte {
|
|||
func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) {
|
||||
panic("diversification nonce not needed for TLS")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
mintConnState := h.tls.ConnectionState()
|
||||
return ConnectionState{
|
||||
// TODO: set the ServerName, once mint exports it
|
||||
HandshakeComplete: h.aead != nil,
|
||||
PeerCertificates: mintConnState.PeerCertificates,
|
||||
}
|
||||
}
|
||||
|
|
41
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls_test.go
generated
vendored
41
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls_test.go
generated
vendored
|
@ -21,16 +21,16 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
|
|||
var _ = Describe("TLS Crypto Setup", func() {
|
||||
var (
|
||||
cs *cryptoSetupTLS
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
handshakeEvent chan struct{}
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent = make(chan struct{}, 2)
|
||||
cs = NewCryptoSetupTLSServer(
|
||||
nil,
|
||||
NewCryptoStreamConn(nil),
|
||||
nil, // AEAD
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
protocol.VersionTLS,
|
||||
).(*cryptoSetupTLS)
|
||||
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
|
||||
|
@ -51,8 +51,8 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
Expect(aeadChanged).To(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
Expect(handshakeEvent).To(BeClosed())
|
||||
})
|
||||
|
||||
It("handshakes until it is connected", func() {
|
||||
|
@ -63,7 +63,30 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(aeadChanged).To(Receive())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
})
|
||||
|
||||
Context("reporting the handshake state", func() {
|
||||
It("reports before the handshake compeletes", func() {
|
||||
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
|
||||
state := cs.ConnectionState()
|
||||
Expect(state.HandshakeComplete).To(BeFalse())
|
||||
Expect(state.PeerCertificates).To(BeNil())
|
||||
})
|
||||
|
||||
It("reports after the handshake completes", func() {
|
||||
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
state := cs.ConnectionState()
|
||||
Expect(state.HandshakeComplete).To(BeTrue())
|
||||
Expect(state.PeerCertificates).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("escalating crypto", func() {
|
||||
|
@ -181,16 +204,16 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
var _ = Describe("TLS Crypto Setup, for the client", func() {
|
||||
var (
|
||||
cs *cryptoSetupTLS
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
handshakeEvent chan struct{}
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent = make(chan struct{})
|
||||
csInt, err := NewCryptoSetupTLSClient(
|
||||
nil,
|
||||
0,
|
||||
"quic.clemente.io",
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
nil, // mintTLS
|
||||
protocol.VersionTLS,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"io"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
|
@ -29,6 +30,7 @@ type MintTLS interface {
|
|||
// additional methods
|
||||
Handshake() mint.Alert
|
||||
State() mint.State
|
||||
ConnectionState() mint.ConnectionState
|
||||
|
||||
SetCryptoStream(io.ReadWriter)
|
||||
SetExtensionHandler(mint.AppExtensionHandler) error
|
||||
|
@ -41,8 +43,17 @@ type CryptoSetup interface {
|
|||
// TODO: clean up this interface
|
||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
type ConnectionState struct {
|
||||
HandshakeComplete bool // handshake is complete
|
||||
ServerName string // server name requested by client, if any (server side only)
|
||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ type extensionHandlerClient struct {
|
|||
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
|
||||
var _ TLSExtensionHandler = &extensionHandlerClient{}
|
||||
|
||||
// NewExtensionHandlerClient creates a new extension handler for the client.
|
||||
func NewExtensionHandlerClient(
|
||||
params *TransportParameters,
|
||||
initialVersion protocol.VersionNumber,
|
||||
|
@ -57,7 +58,10 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
|
|||
|
||||
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
||||
ext := &tlsExtensionBody{}
|
||||
found, _ := el.Find(ext)
|
||||
found, err := el.Find(ext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket {
|
||||
if found {
|
||||
|
|
|
@ -39,7 +39,8 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(el).To(HaveLen(1))
|
||||
ext := &tlsExtensionBody{}
|
||||
found := el.Find(ext)
|
||||
found, err := el.Find(ext)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(found).To(BeTrue())
|
||||
chtp := &clientHelloTransportParameters{}
|
||||
_, err = syntax.Unmarshal(ext.data, chtp)
|
||||
|
|
|
@ -24,6 +24,7 @@ type extensionHandlerServer struct {
|
|||
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
|
||||
var _ TLSExtensionHandler = &extensionHandlerServer{}
|
||||
|
||||
// NewExtensionHandlerServer creates a new extension handler for the server
|
||||
func NewExtensionHandlerServer(
|
||||
params *TransportParameters,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
|
@ -66,7 +67,10 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
|
|||
|
||||
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
||||
ext := &tlsExtensionBody{}
|
||||
found, _ := el.Find(ext)
|
||||
found, err := el.Find(ext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hType != mint.HandshakeTypeClientHello {
|
||||
if found {
|
||||
|
|
|
@ -48,7 +48,8 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(el).To(HaveLen(1))
|
||||
ext := &tlsExtensionBody{}
|
||||
found := el.Find(ext)
|
||||
found, err := el.Find(ext)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(found).To(BeTrue())
|
||||
eetp := &encryptedExtensionsTransportParameters{}
|
||||
_, err = syntax.Unmarshal(ext.data, eetp)
|
||||
|
|
|
@ -64,6 +64,9 @@ type ByteCount uint64
|
|||
// MaxByteCount is the maximum value of a ByteCount
|
||||
const MaxByteCount = ByteCount(1<<62 - 1)
|
||||
|
||||
// An ApplicationErrorCode is an application-defined error code.
|
||||
type ApplicationErrorCode uint16
|
||||
|
||||
// MaxReceivePacketSize maximum packet size of any QUIC packet, based on
|
||||
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
|
||||
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
|
||||
|
|
9
vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go
generated
vendored
9
vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go
generated
vendored
|
@ -56,6 +56,9 @@ const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15
|
|||
// This is the value that Chromium is using
|
||||
const ConnectionFlowControlMultiplier = 1.5
|
||||
|
||||
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
|
||||
const WindowUpdateThreshold = 0.25
|
||||
|
||||
// MaxIncomingStreams is the maximum number of streams that a peer may open
|
||||
const MaxIncomingStreams = 100
|
||||
|
||||
|
@ -122,3 +125,9 @@ const ClosedSessionDeleteTimeout = time.Minute
|
|||
|
||||
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
|
||||
const NumCachedCertificates = 128
|
||||
|
||||
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
|
||||
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
|
||||
// 1. it reduces the framing overhead
|
||||
// 2. it reduces the head-of-line blocking, when a packet is lost
|
||||
const MinStreamFrameSize ByteCount = 128
|
||||
|
|
|
@ -139,7 +139,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
|
|||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (f *AckFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return f.minLengthLegacy(version)
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount
|
|||
length += utils.VarIntLen(uint64(f.LargestAcked - lowestInFirstRange))
|
||||
|
||||
if !f.HasMissingRanges() {
|
||||
return length, nil
|
||||
return length
|
||||
}
|
||||
var lowest protocol.PacketNumber
|
||||
for i, ackRange := range f.AckRanges {
|
||||
|
@ -169,7 +169,7 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount
|
|||
length += utils.VarIntLen(uint64(ackRange.Last - ackRange.First))
|
||||
lowest = ackRange.First
|
||||
}
|
||||
return length, nil
|
||||
return length
|
||||
}
|
||||
|
||||
// HasMissingRanges returns if this frame reports any missing packets
|
||||
|
|
|
@ -308,7 +308,7 @@ func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp
|
||||
length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked))
|
||||
|
||||
|
@ -320,7 +320,7 @@ func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount
|
|||
length += missingSequenceNumberDeltaLen
|
||||
}
|
||||
// we don't write
|
||||
return length, nil
|
||||
return length
|
||||
}
|
||||
|
||||
// numWritableNackRanges calculates the number of ACK blocks that are about to be written
|
||||
|
|
|
@ -4,17 +4,26 @@ import (
|
|||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A BlockedFrame is a BLOCKED frame
|
||||
type BlockedFrame struct{}
|
||||
type BlockedFrame struct {
|
||||
Offset protocol.ByteCount
|
||||
}
|
||||
|
||||
// ParseBlockedFrame parses a BLOCKED frame
|
||||
func ParseBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*BlockedFrame, error) {
|
||||
func ParseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) {
|
||||
if _, err := r.ReadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &BlockedFrame{}, nil
|
||||
offset, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &BlockedFrame{
|
||||
Offset: protocol.ByteCount(offset),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
|
@ -23,13 +32,14 @@ func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er
|
|||
}
|
||||
typeByte := uint8(0x08)
|
||||
b.WriteByte(typeByte)
|
||||
utils.WriteVarInt(b, uint64(f.Offset))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
if !version.UsesIETFFrameFormat() { // writing this frame would result in a legacy BLOCKED being written, which is longer
|
||||
return 1 + 4, nil
|
||||
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return 1 + 4
|
||||
}
|
||||
return 1, nil
|
||||
return 1 + utils.VarIntLen(uint64(f.Offset))
|
||||
}
|
||||
|
|
|
@ -2,8 +2,10 @@ package wire
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -12,30 +14,41 @@ import (
|
|||
var _ = Describe("BLOCKED frame", func() {
|
||||
Context("when parsing", func() {
|
||||
It("accepts sample frame", func() {
|
||||
b := bytes.NewReader([]byte{0x08})
|
||||
_, err := ParseBlockedFrame(b, protocol.VersionWhatever)
|
||||
data := []byte{0x08}
|
||||
data = append(data, encodeVarInt(0x12345678)...)
|
||||
b := bytes.NewReader(data)
|
||||
frame, err := ParseBlockedFrame(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame.Offset).To(Equal(protocol.ByteCount(0x12345678)))
|
||||
Expect(b.Len()).To(BeZero())
|
||||
})
|
||||
|
||||
It("errors on EOFs", func() {
|
||||
_, err := ParseBlockedFrame(bytes.NewReader(nil), protocol.VersionWhatever)
|
||||
Expect(err).To(HaveOccurred())
|
||||
data := []byte{0x08}
|
||||
data = append(data, encodeVarInt(0x12345678)...)
|
||||
_, err := ParseBlockedFrame(bytes.NewReader(data), versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := range data {
|
||||
_, err := ParseBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("when writing", func() {
|
||||
It("writes a sample frame", func() {
|
||||
b := &bytes.Buffer{}
|
||||
frame := BlockedFrame{}
|
||||
frame := BlockedFrame{Offset: 0xdeadbeef}
|
||||
err := frame.Write(b, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x08}))
|
||||
expected := []byte{0x08}
|
||||
expected = append(expected, encodeVarInt(0xdeadbeef)...)
|
||||
Expect(b.Bytes()).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("has the correct min length", func() {
|
||||
frame := BlockedFrame{}
|
||||
Expect(frame.MinLength(versionIETFFrames)).To(Equal(protocol.ByteCount(1)))
|
||||
frame := BlockedFrame{Offset: 0x12345}
|
||||
Expect(frame.MinLength(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0x12345)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
6
vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go
generated
vendored
6
vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go
generated
vendored
|
@ -68,11 +68,11 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber)
|
|||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if version.UsesIETFFrameFormat() {
|
||||
return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)), nil
|
||||
return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
|
||||
}
|
||||
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil
|
||||
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase))
|
||||
}
|
||||
|
||||
// Write writes an CONNECTION_CLOSE frame.
|
||||
|
|
|
@ -9,5 +9,5 @@ import (
|
|||
// A Frame in QUIC
|
||||
type Frame interface {
|
||||
Write(b *bytes.Buffer, version protocol.VersionNumber) error
|
||||
MinLength(version protocol.VersionNumber) (protocol.ByteCount, error)
|
||||
MinLength(version protocol.VersionNumber) protocol.ByteCount
|
||||
}
|
||||
|
|
|
@ -63,6 +63,6 @@ func (f *GoawayFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
|||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *GoawayFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)), nil
|
||||
func (f *GoawayFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase))
|
||||
}
|
||||
|
|
|
@ -43,9 +43,9 @@ func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er
|
|||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if !version.UsesIETFFrameFormat() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer
|
||||
return 1 + 4 + 8, nil
|
||||
return 1 + 4 + 8
|
||||
}
|
||||
return 1 + utils.VarIntLen(uint64(f.ByteOffset)), nil
|
||||
return 1 + utils.VarIntLen(uint64(f.ByteOffset))
|
||||
}
|
||||
|
|
6
vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go
generated
vendored
6
vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go
generated
vendored
|
@ -51,10 +51,10 @@ func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumb
|
|||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
// writing this frame would result in a gQUIC WINDOW_UPDATE being written, which has a different length
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return 1 + 4 + 8, nil
|
||||
return 1 + 4 + 8
|
||||
}
|
||||
return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset)), nil
|
||||
return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset))
|
||||
}
|
||||
|
|
37
vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go
generated
vendored
Normal file
37
vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go
generated
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A MaxStreamIDFrame is a MAX_STREAM_ID frame
|
||||
type MaxStreamIDFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
}
|
||||
|
||||
// ParseMaxStreamIDFrame parses a MAX_STREAM_ID frame
|
||||
func ParseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) {
|
||||
// read the Type byte
|
||||
if _, err := r.ReadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
streamID, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil
|
||||
}
|
||||
|
||||
func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||
b.WriteByte(0x6)
|
||||
utils.WriteVarInt(b, uint64(f.StreamID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *MaxStreamIDFrame) MinLength(protocol.VersionNumber) protocol.ByteCount {
|
||||
return 1 + utils.VarIntLen(uint64(f.StreamID))
|
||||
}
|
51
vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame_test.go
generated
vendored
Normal file
51
vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame_test.go
generated
vendored
Normal file
|
@ -0,0 +1,51 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("MAX_STREAM_ID frame", func() {
|
||||
Context("parsing", func() {
|
||||
It("accepts sample frame", func() {
|
||||
data := []byte{0x6}
|
||||
data = append(data, encodeVarInt(0xdecafbad)...)
|
||||
b := bytes.NewReader(data)
|
||||
f, err := ParseMaxStreamIDFrame(b, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(f.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
|
||||
Expect(b.Len()).To(BeZero())
|
||||
})
|
||||
|
||||
It("errors on EOFs", func() {
|
||||
data := []byte{0x06}
|
||||
data = append(data, encodeVarInt(0xdeadbeefcafe13)...)
|
||||
_, err := ParseMaxStreamIDFrame(bytes.NewReader(data), protocol.VersionWhatever)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
for i := range data {
|
||||
_, err := ParseMaxStreamIDFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever)
|
||||
Expect(err).To(HaveOccurred())
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("writing", func() {
|
||||
It("writes a sample frame", func() {
|
||||
b := &bytes.Buffer{}
|
||||
frame := MaxStreamIDFrame{StreamID: 0x12345678}
|
||||
frame.Write(b, protocol.VersionWhatever)
|
||||
expected := []byte{0x6}
|
||||
expected = append(expected, encodeVarInt(0x12345678)...)
|
||||
Expect(b.Bytes()).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("has the correct min length", func() {
|
||||
frame := MaxStreamIDFrame{StreamID: 0x1337}
|
||||
Expect(frame.MinLength(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337)))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -28,6 +28,6 @@ func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
|
|||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *PingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return 1, nil
|
||||
func (f *PingFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
return 1
|
||||
}
|
||||
|
|
|
@ -7,10 +7,12 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A RstStreamFrame in QUIC
|
||||
// A RstStreamFrame is a RST_STREAM frame in QUIC
|
||||
type RstStreamFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
ErrorCode uint32
|
||||
// The error code is a uint32 in gQUIC, but a uint16 in IETF QUIC.
|
||||
// protocol.ApplicaitonErrorCode is a uint16, so larger values in gQUIC frames will be truncated.
|
||||
ErrorCode protocol.ApplicationErrorCode
|
||||
ByteOffset protocol.ByteCount
|
||||
}
|
||||
|
||||
|
@ -21,7 +23,7 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS
|
|||
}
|
||||
|
||||
var streamID protocol.StreamID
|
||||
var errorCode uint32
|
||||
var errorCode uint16
|
||||
var byteOffset protocol.ByteCount
|
||||
if version.UsesIETFFrameFormat() {
|
||||
sid, err := utils.ReadVarInt(r)
|
||||
|
@ -29,11 +31,10 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS
|
|||
return nil, err
|
||||
}
|
||||
streamID = protocol.StreamID(sid)
|
||||
ec, err := utils.BigEndian.ReadUint16(r)
|
||||
errorCode, err = utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
errorCode = uint32(ec)
|
||||
bo, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -54,12 +55,12 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
errorCode = uint32(ec)
|
||||
errorCode = uint16(ec)
|
||||
}
|
||||
|
||||
return &RstStreamFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: errorCode,
|
||||
ErrorCode: protocol.ApplicationErrorCode(errorCode),
|
||||
ByteOffset: byteOffset,
|
||||
}, nil
|
||||
}
|
||||
|
@ -74,15 +75,15 @@ func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber)
|
|||
} else {
|
||||
utils.BigEndian.WriteUint32(b, uint32(f.StreamID))
|
||||
utils.BigEndian.WriteUint64(b, uint64(f.ByteOffset))
|
||||
utils.BigEndian.WriteUint32(b, f.ErrorCode)
|
||||
utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if version.UsesIETFFrameFormat() {
|
||||
return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset)), nil
|
||||
return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset))
|
||||
}
|
||||
return 1 + 4 + 8 + 4, nil
|
||||
return 1 + 4 + 8 + 4
|
||||
}
|
||||
|
|
10
vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame_test.go
generated
vendored
10
vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame_test.go
generated
vendored
|
@ -22,7 +22,7 @@ var _ = Describe("RST_STREAM frame", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
|
||||
Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x987654321)))
|
||||
Expect(frame.ErrorCode).To(Equal(uint32(0x1337)))
|
||||
Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0x1337)))
|
||||
})
|
||||
|
||||
It("errors on EOFs", func() {
|
||||
|
@ -44,13 +44,13 @@ var _ = Describe("RST_STREAM frame", func() {
|
|||
b := bytes.NewReader([]byte{0x1,
|
||||
0xde, 0xad, 0xbe, 0xef, // stream id
|
||||
0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, // byte offset
|
||||
0x34, 0x12, 0x37, 0x13, // error code
|
||||
0x0, 0x0, 0xca, 0xfe, // error code
|
||||
})
|
||||
frame, err := ParseRstStreamFrame(b, versionBigEndian)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
|
||||
Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x8877665544332211)))
|
||||
Expect(frame.ErrorCode).To(Equal(uint32(0x34123713)))
|
||||
Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0xcafe)))
|
||||
})
|
||||
|
||||
It("errors on EOFs", func() {
|
||||
|
@ -103,7 +103,7 @@ var _ = Describe("RST_STREAM frame", func() {
|
|||
frame := RstStreamFrame{
|
||||
StreamID: 0x1337,
|
||||
ByteOffset: 0x11223344decafbad,
|
||||
ErrorCode: 0xdeadbeef,
|
||||
ErrorCode: 0xcafe,
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
err := frame.Write(b, versionBigEndian)
|
||||
|
@ -111,7 +111,7 @@ var _ = Describe("RST_STREAM frame", func() {
|
|||
Expect(b.Bytes()).To(Equal([]byte{0x01,
|
||||
0x0, 0x0, 0x13, 0x37, // stream id
|
||||
0x11, 0x22, 0x33, 0x44, 0xde, 0xca, 0xfb, 0xad, // byte offset
|
||||
0xde, 0xad, 0xbe, 0xef, // error code
|
||||
0x0, 0x0, 0xca, 0xfe, // error code
|
||||
}))
|
||||
})
|
||||
|
||||
|
|
47
vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go
generated
vendored
Normal file
47
vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go
generated
vendored
Normal file
|
@ -0,0 +1,47 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A StopSendingFrame is a STOP_SENDING frame
|
||||
type StopSendingFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
ErrorCode protocol.ApplicationErrorCode
|
||||
}
|
||||
|
||||
// ParseStopSendingFrame parses a STOP_SENDING frame
|
||||
func ParseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) {
|
||||
if _, err := r.ReadByte(); err != nil { // read the TypeByte
|
||||
return nil, err
|
||||
}
|
||||
|
||||
streamID, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
errorCode, err := utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &StopSendingFrame{
|
||||
StreamID: protocol.StreamID(streamID),
|
||||
ErrorCode: protocol.ApplicationErrorCode(errorCode),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *StopSendingFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2
|
||||
}
|
||||
|
||||
func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||
b.WriteByte(0x0c)
|
||||
utils.WriteVarInt(b, uint64(f.StreamID))
|
||||
utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode))
|
||||
return nil
|
||||
}
|
63
vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame_test.go
generated
vendored
Normal file
63
vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame_test.go
generated
vendored
Normal file
|
@ -0,0 +1,63 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("STOP_SENDING frame", func() {
|
||||
Context("when parsing", func() {
|
||||
It("parses a sample frame", func() {
|
||||
data := []byte{0x0c}
|
||||
data = append(data, encodeVarInt(0xdecafbad)...) // stream ID
|
||||
data = append(data, []byte{0x13, 0x37}...) // error code
|
||||
b := bytes.NewReader(data)
|
||||
frame, err := ParseStopSendingFrame(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
|
||||
Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0x1337)))
|
||||
Expect(b.Len()).To(BeZero())
|
||||
})
|
||||
|
||||
It("errors on EOFs", func() {
|
||||
data := []byte{0x0c}
|
||||
data = append(data, encodeVarInt(0xdecafbad)...) // stream ID
|
||||
data = append(data, []byte{0x13, 0x37}...) // error code
|
||||
_, err := ParseStopSendingFrame(bytes.NewReader(data), versionIETFFrames)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
for i := range data {
|
||||
_, err := ParseStopSendingFrame(bytes.NewReader(data[:i]), versionIETFFrames)
|
||||
Expect(err).To(HaveOccurred())
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("when writing", func() {
|
||||
It("writes", func() {
|
||||
frame := &StopSendingFrame{
|
||||
StreamID: 0xdeadbeefcafe,
|
||||
ErrorCode: 0x10,
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
err := frame.Write(buf, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{0x0c}
|
||||
expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
|
||||
expected = append(expected, []byte{0x0, 0x10}...)
|
||||
Expect(buf.Bytes()).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("has the correct min length", func() {
|
||||
frame := &StopSendingFrame{
|
||||
StreamID: 0xdeadbeef,
|
||||
ErrorCode: 0x10,
|
||||
}
|
||||
Expect(frame.MinLength(versionIETFFrames)).To(Equal(1 + 2 + utils.VarIntLen(0xdeadbeef)))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -22,7 +22,10 @@ var (
|
|||
errPacketNumberLenNotSet = errors.New("StopWaitingFrame: PacketNumberLen not set")
|
||||
)
|
||||
|
||||
func (f *StopWaitingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||
func (f *StopWaitingFrame) Write(b *bytes.Buffer, v protocol.VersionNumber) error {
|
||||
if v.UsesIETFFrameFormat() {
|
||||
return errors.New("STOP_WAITING not defined in IETF QUIC")
|
||||
}
|
||||
// make sure the PacketNumber was set
|
||||
if f.PacketNumber == protocol.PacketNumber(0) {
|
||||
return errPacketNumberNotSet
|
||||
|
@ -49,14 +52,8 @@ func (f *StopWaitingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) erro
|
|||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *StopWaitingFrame) MinLength(_ protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
minLength := protocol.ByteCount(1) // typeByte
|
||||
|
||||
if f.PacketNumberLen == protocol.PacketNumberLenInvalid {
|
||||
return 0, errPacketNumberLenNotSet
|
||||
}
|
||||
minLength += protocol.ByteCount(f.PacketNumberLen)
|
||||
return minLength, nil
|
||||
func (f *StopWaitingFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
return 1 + protocol.ByteCount(f.PacketNumberLen)
|
||||
}
|
||||
|
||||
// ParseStopWaitingFrame parses a StopWaiting frame
|
||||
|
|
31
vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame_test.go
generated
vendored
31
vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame_test.go
generated
vendored
|
@ -84,7 +84,7 @@ var _ = Describe("StopWaitingFrame", func() {
|
|||
LeastUnacked: 10,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
err := frame.Write(b, protocol.VersionWhatever)
|
||||
err := frame.Write(b, versionBigEndian)
|
||||
Expect(err).To(MatchError(errPacketNumberNotSet))
|
||||
})
|
||||
|
||||
|
@ -94,7 +94,7 @@ var _ = Describe("StopWaitingFrame", func() {
|
|||
LeastUnacked: 10,
|
||||
PacketNumber: 13,
|
||||
}
|
||||
err := frame.Write(b, protocol.VersionWhatever)
|
||||
err := frame.Write(b, versionBigEndian)
|
||||
Expect(err).To(MatchError(errPacketNumberLenNotSet))
|
||||
})
|
||||
|
||||
|
@ -105,10 +105,21 @@ var _ = Describe("StopWaitingFrame", func() {
|
|||
PacketNumber: 5,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
err := frame.Write(b, protocol.VersionWhatever)
|
||||
err := frame.Write(b, versionBigEndian)
|
||||
Expect(err).To(MatchError(errLeastUnackedHigherThanPacketNumber))
|
||||
})
|
||||
|
||||
It("refuses to write for IETF QUIC", func() {
|
||||
b := &bytes.Buffer{}
|
||||
frame := &StopWaitingFrame{
|
||||
LeastUnacked: 10,
|
||||
PacketNumber: 13,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
}
|
||||
err := frame.Write(b, versionIETFFrames)
|
||||
Expect(err).To(MatchError("STOP_WAITING not defined in IETF QUIC"))
|
||||
})
|
||||
|
||||
Context("LeastUnackedDelta length", func() {
|
||||
Context("in big endian", func() {
|
||||
It("writes a 1-byte LeastUnackedDelta", func() {
|
||||
|
@ -176,18 +187,10 @@ var _ = Describe("StopWaitingFrame", func() {
|
|||
Expect(frame.MinLength(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(length + 1)))
|
||||
}
|
||||
})
|
||||
|
||||
It("errors when packetNumberLen is not set", func() {
|
||||
frame := &StopWaitingFrame{
|
||||
LeastUnacked: 10,
|
||||
}
|
||||
_, err := frame.MinLength(0)
|
||||
Expect(err).To(MatchError(errPacketNumberLenNotSet))
|
||||
})
|
||||
})
|
||||
|
||||
Context("self consistency", func() {
|
||||
It("reads a stop waiting frame that it wrote", func() {
|
||||
It("reads a STOP_WAITING frame that it wrote", func() {
|
||||
packetNumber := protocol.PacketNumber(13)
|
||||
frame := &StopWaitingFrame{
|
||||
LeastUnacked: 10,
|
||||
|
@ -195,9 +198,9 @@ var _ = Describe("StopWaitingFrame", func() {
|
|||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
err := frame.Write(b, protocol.VersionWhatever)
|
||||
err := frame.Write(b, versionBigEndian)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, protocol.VersionWhatever)
|
||||
readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, versionBigEndian)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(readframe.LeastUnacked).To(Equal(frame.LeastUnacked))
|
||||
})
|
||||
|
|
19
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go
generated
vendored
19
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go
generated
vendored
|
@ -10,10 +10,11 @@ import (
|
|||
// A StreamBlockedFrame in QUIC
|
||||
type StreamBlockedFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
Offset protocol.ByteCount
|
||||
}
|
||||
|
||||
// ParseStreamBlockedFrame parses a STREAM_BLOCKED frame
|
||||
func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamBlockedFrame, error) {
|
||||
func ParseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) {
|
||||
if _, err := r.ReadByte(); err != nil { // read the TypeByte
|
||||
return nil, err
|
||||
}
|
||||
|
@ -21,7 +22,14 @@ func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StreamBlockedFrame{StreamID: protocol.StreamID(sid)}, nil
|
||||
offset, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StreamBlockedFrame{
|
||||
StreamID: protocol.StreamID(sid),
|
||||
Offset: protocol.ByteCount(offset),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Write writes a STREAM_BLOCKED frame
|
||||
|
@ -31,13 +39,14 @@ func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumb
|
|||
}
|
||||
b.WriteByte(0x09)
|
||||
utils.WriteVarInt(b, uint64(f.StreamID))
|
||||
utils.WriteVarInt(b, uint64(f.Offset))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return 1 + 4, nil
|
||||
return 1 + 4
|
||||
}
|
||||
return 1 + utils.VarIntLen(uint64(f.StreamID)), nil
|
||||
return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.Offset))
|
||||
}
|
||||
|
|
10
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame_test.go
generated
vendored
10
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame_test.go
generated
vendored
|
@ -14,17 +14,20 @@ var _ = Describe("STREAM_BLOCKED frame", func() {
|
|||
Context("parsing", func() {
|
||||
It("accepts sample frame", func() {
|
||||
data := []byte{0x9}
|
||||
data = append(data, encodeVarInt(0xdeadbeef)...)
|
||||
data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID
|
||||
data = append(data, encodeVarInt(0xdecafbad)...) // offset
|
||||
b := bytes.NewReader(data)
|
||||
frame, err := ParseStreamBlockedFrame(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
|
||||
Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad)))
|
||||
Expect(b.Len()).To(BeZero())
|
||||
})
|
||||
|
||||
It("errors on EOFs", func() {
|
||||
data := []byte{0x9}
|
||||
data = append(data, encodeVarInt(0xdeadbeef)...)
|
||||
data = append(data, encodeVarInt(0xc0010ff)...)
|
||||
_, err := ParseStreamBlockedFrame(bytes.NewReader(data), versionIETFFrames)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
for i := range data {
|
||||
|
@ -38,19 +41,22 @@ var _ = Describe("STREAM_BLOCKED frame", func() {
|
|||
It("has proper min length", func() {
|
||||
f := &StreamBlockedFrame{
|
||||
StreamID: 0x1337,
|
||||
Offset: 0xdeadbeef,
|
||||
}
|
||||
Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337)))
|
||||
Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337) + utils.VarIntLen(0xdeadbeef)))
|
||||
})
|
||||
|
||||
It("writes a sample frame", func() {
|
||||
b := &bytes.Buffer{}
|
||||
f := &StreamBlockedFrame{
|
||||
StreamID: 0xdecafbad,
|
||||
Offset: 0x1337,
|
||||
}
|
||||
err := f.Write(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{0x9}
|
||||
expected = append(expected, encodeVarInt(uint64(f.StreamID))...)
|
||||
expected = append(expected, encodeVarInt(uint64(f.Offset))...)
|
||||
Expect(b.Bytes()).To(Equal(expected))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -117,7 +117,7 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err
|
|||
|
||||
// MinLength returns the length of the header of a StreamFrame
|
||||
// the total length of the frame is frame.MinLength() + frame.DataLen()
|
||||
func (f *StreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (f *StreamFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return f.minLengthLegacy(version)
|
||||
}
|
||||
|
@ -128,5 +128,5 @@ func (f *StreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCo
|
|||
if f.DataLenPresent {
|
||||
length += utils.VarIntLen(uint64(f.DataLen()))
|
||||
}
|
||||
return length, nil
|
||||
return length
|
||||
}
|
||||
|
|
|
@ -183,12 +183,12 @@ func (f *StreamFrame) getOffsetLength() protocol.ByteCount {
|
|||
return 8
|
||||
}
|
||||
|
||||
func (f *StreamFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (f *StreamFrame) minLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength()
|
||||
if f.DataLenPresent {
|
||||
length += 2
|
||||
}
|
||||
return length, nil
|
||||
return length
|
||||
}
|
||||
|
||||
// DataLen gives the length of data in bytes
|
||||
|
|
8
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy_test.go
generated
vendored
8
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy_test.go
generated
vendored
|
@ -210,7 +210,7 @@ var _ = Describe("STREAM frame (for gQUIC)", func() {
|
|||
}
|
||||
err := f.Write(b, versionBigEndian)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
minLength, _ := f.MinLength(0)
|
||||
minLength := f.MinLength(0)
|
||||
Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0x20)))
|
||||
Expect(b.Bytes()[minLength-2 : minLength]).To(Equal([]byte{0x13, 0x37}))
|
||||
})
|
||||
|
@ -229,9 +229,9 @@ var _ = Describe("STREAM frame (for gQUIC)", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0)))
|
||||
Expect(b.Bytes()[1 : b.Len()-dataLen]).ToNot(ContainSubstring(string([]byte{0x37, 0x13})))
|
||||
minLength, _ := f.MinLength(versionBigEndian)
|
||||
minLength := f.MinLength(versionBigEndian)
|
||||
f.DataLenPresent = true
|
||||
minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian)
|
||||
minLengthWithoutDataLen := f.MinLength(versionBigEndian)
|
||||
Expect(minLength).To(Equal(minLengthWithoutDataLen - 2))
|
||||
})
|
||||
|
||||
|
@ -242,7 +242,7 @@ var _ = Describe("STREAM frame (for gQUIC)", func() {
|
|||
DataLenPresent: false,
|
||||
Offset: 0xdeadbeef,
|
||||
}
|
||||
minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian)
|
||||
minLengthWithoutDataLen := f.MinLength(versionBigEndian)
|
||||
f.DataLenPresent = true
|
||||
Expect(f.MinLength(versionBigEndian)).To(Equal(minLengthWithoutDataLen + 2))
|
||||
})
|
||||
|
|
37
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go
generated
vendored
Normal file
37
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go
generated
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A StreamIDBlockedFrame is a STREAM_ID_BLOCKED frame
|
||||
type StreamIDBlockedFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
}
|
||||
|
||||
// ParseStreamIDBlockedFrame parses a STREAM_ID_BLOCKED frame
|
||||
func ParseStreamIDBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamIDBlockedFrame, error) {
|
||||
if _, err := r.ReadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
streamID, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StreamIDBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil
|
||||
}
|
||||
|
||||
func (f *StreamIDBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||
typeByte := uint8(0x0a)
|
||||
b.WriteByte(typeByte)
|
||||
utils.WriteVarInt(b, uint64(f.StreamID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *StreamIDBlockedFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
return 1 + utils.VarIntLen(uint64(f.StreamID))
|
||||
}
|
53
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame_test.go
generated
vendored
Normal file
53
vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame_test.go
generated
vendored
Normal file
|
@ -0,0 +1,53 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("STREAM_ID_BLOCKED frame", func() {
|
||||
Context("parsing", func() {
|
||||
It("accepts sample frame", func() {
|
||||
expected := []byte{0xa}
|
||||
expected = append(expected, encodeVarInt(0xdecafbad)...)
|
||||
b := bytes.NewReader(expected)
|
||||
frame, err := ParseStreamIDBlockedFrame(b, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
|
||||
Expect(b.Len()).To(BeZero())
|
||||
})
|
||||
|
||||
It("errors on EOFs", func() {
|
||||
data := []byte{0xa}
|
||||
data = append(data, encodeVarInt(0x12345678)...)
|
||||
_, err := ParseStreamIDBlockedFrame(bytes.NewReader(data), versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := range data {
|
||||
_, err := ParseStreamIDBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("writing", func() {
|
||||
It("writes a sample frame", func() {
|
||||
b := &bytes.Buffer{}
|
||||
frame := StreamIDBlockedFrame{StreamID: 0xdeadbeefcafe}
|
||||
err := frame.Write(b, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{0xa}
|
||||
expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
|
||||
Expect(b.Bytes()).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("has the correct min length", func() {
|
||||
frame := StreamIDBlockedFrame{StreamID: 0x123456}
|
||||
Expect(frame.MinLength(0)).To(Equal(protocol.ByteCount(1) + utils.VarIntLen(0x123456)))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -56,6 +56,10 @@ func (mc *mintController) State() mint.State {
|
|||
return mc.conn.State().HandshakeState
|
||||
}
|
||||
|
||||
func (mc *mintController) ConnectionState() mint.ConnectionState {
|
||||
return mc.conn.State()
|
||||
}
|
||||
|
||||
func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
|
||||
mc.csc.SetStream(stream)
|
||||
}
|
||||
|
@ -73,6 +77,7 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
|
|||
},
|
||||
}
|
||||
if tlsConf != nil {
|
||||
mconf.ServerName = tlsConf.ServerName
|
||||
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
|
||||
for i, certChain := range tlsConf.Certificates {
|
||||
mconf.Certificates[i] = &mint.Certificate{
|
||||
|
@ -87,6 +92,13 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
|
|||
mconf.Certificates[i].Chain[j] = c
|
||||
}
|
||||
}
|
||||
switch tlsConf.ClientAuth {
|
||||
case tls.NoClientCert:
|
||||
case tls.RequireAnyClientCert:
|
||||
mconf.RequireClientAuth = true
|
||||
default:
|
||||
return nil, errors.New("mint currently only support ClientAuthType RequireAnyClientCert")
|
||||
}
|
||||
}
|
||||
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
|
||||
return nil, err
|
||||
|
@ -128,14 +140,14 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio
|
|||
|
||||
// packUnencryptedPacket provides a low-overhead way to pack a packet.
|
||||
// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
|
||||
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) {
|
||||
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) {
|
||||
raw := getPacketBuffer()
|
||||
buffer := bytes.NewBuffer(raw)
|
||||
if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadStartIndex := buffer.Len()
|
||||
if err := sf.Write(buffer, hdr.Version); err != nil {
|
||||
if err := f.Write(buffer, hdr.Version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw = raw[0:buffer.Len()]
|
||||
|
@ -144,7 +156,7 @@ func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFr
|
|||
if utils.Debug() {
|
||||
utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted)
|
||||
hdr.Log()
|
||||
wire.LogFrame(sf, true)
|
||||
wire.LogFrame(f, true)
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
|
|
@ -2,9 +2,11 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -33,6 +35,45 @@ var _ = Describe("Packing and unpacking Initial packets", func() {
|
|||
hdr.Raw = buf.Bytes()
|
||||
})
|
||||
|
||||
Context("generating a mint.Config", func() {
|
||||
It("sets non-blocking mode", func() {
|
||||
mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(mintConf.NonBlocking).To(BeTrue())
|
||||
})
|
||||
|
||||
It("sets the server name", func() {
|
||||
conf := &tls.Config{ServerName: "www.example.com"}
|
||||
mintConf, err := tlsToMintConfig(conf, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(mintConf.ServerName).To(Equal("www.example.com"))
|
||||
})
|
||||
|
||||
It("sets the certificate chain", func() {
|
||||
tlsConf := testdata.GetTLSConfig()
|
||||
mintConf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(mintConf.Certificates).ToNot(BeEmpty())
|
||||
Expect(mintConf.Certificates).To(HaveLen(len(tlsConf.Certificates)))
|
||||
})
|
||||
|
||||
It("requires client authentication", func() {
|
||||
mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(mintConf.RequireClientAuth).To(BeFalse())
|
||||
conf := &tls.Config{ClientAuth: tls.RequireAnyClientCert}
|
||||
mintConf, err = tlsToMintConfig(conf, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(mintConf.RequireClientAuth).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects unsupported client auth types", func() {
|
||||
conf := &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert}
|
||||
_, err := tlsToMintConfig(conf, protocol.PerspectiveClient)
|
||||
Expect(err).To(MatchError("mint currently only support ClientAuthType RequireAnyClientCert"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("unpacking", func() {
|
||||
packPacket := func(frames []wire.Frame) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
|
|
141
vendor/github.com/lucas-clemente/quic-go/mock_crypto_stream_test.go
generated
vendored
Normal file
141
vendor/github.com/lucas-clemente/quic-go/mock_crypto_stream_test.go
generated
vendored
Normal file
|
@ -0,0 +1,141 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: CryptoStream)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
wire "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// MockCryptoStream is a mock of CryptoStream interface
|
||||
type MockCryptoStream struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCryptoStreamMockRecorder
|
||||
}
|
||||
|
||||
// MockCryptoStreamMockRecorder is the mock recorder for MockCryptoStream
|
||||
type MockCryptoStreamMockRecorder struct {
|
||||
mock *MockCryptoStream
|
||||
}
|
||||
|
||||
// NewMockCryptoStream creates a new mock instance
|
||||
func NewMockCryptoStream(ctrl *gomock.Controller) *MockCryptoStream {
|
||||
mock := &MockCryptoStream{ctrl: ctrl}
|
||||
mock.recorder = &MockCryptoStreamMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Read mocks base method
|
||||
func (m *MockCryptoStream) Read(arg0 []byte) (int, error) {
|
||||
ret := m.ctrl.Call(m, "Read", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read
|
||||
func (mr *MockCryptoStreamMockRecorder) Read(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockCryptoStream)(nil).Read), arg0)
|
||||
}
|
||||
|
||||
// StreamID mocks base method
|
||||
func (m *MockCryptoStream) StreamID() protocol.StreamID {
|
||||
ret := m.ctrl.Call(m, "StreamID")
|
||||
ret0, _ := ret[0].(protocol.StreamID)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StreamID indicates an expected call of StreamID
|
||||
func (mr *MockCryptoStreamMockRecorder) StreamID() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockCryptoStream)(nil).StreamID))
|
||||
}
|
||||
|
||||
// Write mocks base method
|
||||
func (m *MockCryptoStream) Write(arg0 []byte) (int, error) {
|
||||
ret := m.ctrl.Call(m, "Write", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Write indicates an expected call of Write
|
||||
func (mr *MockCryptoStreamMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0)
|
||||
}
|
||||
|
||||
// closeForShutdown mocks base method
|
||||
func (m *MockCryptoStream) closeForShutdown(arg0 error) {
|
||||
m.ctrl.Call(m, "closeForShutdown", arg0)
|
||||
}
|
||||
|
||||
// closeForShutdown indicates an expected call of closeForShutdown
|
||||
func (mr *MockCryptoStreamMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockCryptoStream)(nil).closeForShutdown), arg0)
|
||||
}
|
||||
|
||||
// getWindowUpdate mocks base method
|
||||
func (m *MockCryptoStream) getWindowUpdate() protocol.ByteCount {
|
||||
ret := m.ctrl.Call(m, "getWindowUpdate")
|
||||
ret0, _ := ret[0].(protocol.ByteCount)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// getWindowUpdate indicates an expected call of getWindowUpdate
|
||||
func (mr *MockCryptoStreamMockRecorder) getWindowUpdate() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockCryptoStream)(nil).getWindowUpdate))
|
||||
}
|
||||
|
||||
// handleMaxStreamDataFrame mocks base method
|
||||
func (m *MockCryptoStream) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
|
||||
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
|
||||
}
|
||||
|
||||
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
|
||||
func (mr *MockCryptoStreamMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleMaxStreamDataFrame), arg0)
|
||||
}
|
||||
|
||||
// handleStreamFrame mocks base method
|
||||
func (m *MockCryptoStream) handleStreamFrame(arg0 *wire.StreamFrame) error {
|
||||
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// handleStreamFrame indicates an expected call of handleStreamFrame
|
||||
func (mr *MockCryptoStreamMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0)
|
||||
}
|
||||
|
||||
// popStreamFrame mocks base method
|
||||
func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
|
||||
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
|
||||
ret0, _ := ret[0].(*wire.StreamFrame)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// popStreamFrame indicates an expected call of popStreamFrame
|
||||
func (mr *MockCryptoStreamMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).popStreamFrame), arg0)
|
||||
}
|
||||
|
||||
// setReadOffset mocks base method
|
||||
func (m *MockCryptoStream) setReadOffset(arg0 protocol.ByteCount) {
|
||||
m.ctrl.Call(m, "setReadOffset", arg0)
|
||||
}
|
||||
|
||||
// setReadOffset indicates an expected call of setReadOffset
|
||||
func (mr *MockCryptoStreamMockRecorder) setReadOffset(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setReadOffset", reflect.TypeOf((*MockCryptoStream)(nil).setReadOffset), arg0)
|
||||
}
|
132
vendor/github.com/lucas-clemente/quic-go/mock_receive_stream_internal_test.go
generated
vendored
Normal file
132
vendor/github.com/lucas-clemente/quic-go/mock_receive_stream_internal_test.go
generated
vendored
Normal file
|
@ -0,0 +1,132 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: ReceiveStreamI)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
wire "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// MockReceiveStreamI is a mock of ReceiveStreamI interface
|
||||
type MockReceiveStreamI struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockReceiveStreamIMockRecorder
|
||||
}
|
||||
|
||||
// MockReceiveStreamIMockRecorder is the mock recorder for MockReceiveStreamI
|
||||
type MockReceiveStreamIMockRecorder struct {
|
||||
mock *MockReceiveStreamI
|
||||
}
|
||||
|
||||
// NewMockReceiveStreamI creates a new mock instance
|
||||
func NewMockReceiveStreamI(ctrl *gomock.Controller) *MockReceiveStreamI {
|
||||
mock := &MockReceiveStreamI{ctrl: ctrl}
|
||||
mock.recorder = &MockReceiveStreamIMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockReceiveStreamI) EXPECT() *MockReceiveStreamIMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CancelRead mocks base method
|
||||
func (m *MockReceiveStreamI) CancelRead(arg0 protocol.ApplicationErrorCode) error {
|
||||
ret := m.ctrl.Call(m, "CancelRead", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CancelRead indicates an expected call of CancelRead
|
||||
func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0)
|
||||
}
|
||||
|
||||
// Read mocks base method
|
||||
func (m *MockReceiveStreamI) Read(arg0 []byte) (int, error) {
|
||||
ret := m.ctrl.Call(m, "Read", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read
|
||||
func (mr *MockReceiveStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), arg0)
|
||||
}
|
||||
|
||||
// SetReadDeadline mocks base method
|
||||
func (m *MockReceiveStreamI) SetReadDeadline(arg0 time.Time) error {
|
||||
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetReadDeadline indicates an expected call of SetReadDeadline
|
||||
func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), arg0)
|
||||
}
|
||||
|
||||
// StreamID mocks base method
|
||||
func (m *MockReceiveStreamI) StreamID() protocol.StreamID {
|
||||
ret := m.ctrl.Call(m, "StreamID")
|
||||
ret0, _ := ret[0].(protocol.StreamID)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StreamID indicates an expected call of StreamID
|
||||
func (mr *MockReceiveStreamIMockRecorder) StreamID() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockReceiveStreamI)(nil).StreamID))
|
||||
}
|
||||
|
||||
// closeForShutdown mocks base method
|
||||
func (m *MockReceiveStreamI) closeForShutdown(arg0 error) {
|
||||
m.ctrl.Call(m, "closeForShutdown", arg0)
|
||||
}
|
||||
|
||||
// closeForShutdown indicates an expected call of closeForShutdown
|
||||
func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0)
|
||||
}
|
||||
|
||||
// getWindowUpdate mocks base method
|
||||
func (m *MockReceiveStreamI) getWindowUpdate() protocol.ByteCount {
|
||||
ret := m.ctrl.Call(m, "getWindowUpdate")
|
||||
ret0, _ := ret[0].(protocol.ByteCount)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// getWindowUpdate indicates an expected call of getWindowUpdate
|
||||
func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate))
|
||||
}
|
||||
|
||||
// handleRstStreamFrame mocks base method
|
||||
func (m *MockReceiveStreamI) handleRstStreamFrame(arg0 *wire.RstStreamFrame) error {
|
||||
ret := m.ctrl.Call(m, "handleRstStreamFrame", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// handleRstStreamFrame indicates an expected call of handleRstStreamFrame
|
||||
func (mr *MockReceiveStreamIMockRecorder) handleRstStreamFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleRstStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleRstStreamFrame), arg0)
|
||||
}
|
||||
|
||||
// handleStreamFrame mocks base method
|
||||
func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error {
|
||||
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// handleStreamFrame indicates an expected call of handleStreamFrame
|
||||
func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0)
|
||||
}
|
154
vendor/github.com/lucas-clemente/quic-go/mock_send_stream_internal_test.go
generated
vendored
Normal file
154
vendor/github.com/lucas-clemente/quic-go/mock_send_stream_internal_test.go
generated
vendored
Normal file
|
@ -0,0 +1,154 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: SendStreamI)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
wire "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// MockSendStreamI is a mock of SendStreamI interface
|
||||
type MockSendStreamI struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSendStreamIMockRecorder
|
||||
}
|
||||
|
||||
// MockSendStreamIMockRecorder is the mock recorder for MockSendStreamI
|
||||
type MockSendStreamIMockRecorder struct {
|
||||
mock *MockSendStreamI
|
||||
}
|
||||
|
||||
// NewMockSendStreamI creates a new mock instance
|
||||
func NewMockSendStreamI(ctrl *gomock.Controller) *MockSendStreamI {
|
||||
mock := &MockSendStreamI{ctrl: ctrl}
|
||||
mock.recorder = &MockSendStreamIMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CancelWrite mocks base method
|
||||
func (m *MockSendStreamI) CancelWrite(arg0 protocol.ApplicationErrorCode) error {
|
||||
ret := m.ctrl.Call(m, "CancelWrite", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CancelWrite indicates an expected call of CancelWrite
|
||||
func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockSendStreamI) Close() error {
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockSendStreamIMockRecorder) Close() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close))
|
||||
}
|
||||
|
||||
// Context mocks base method
|
||||
func (m *MockSendStreamI) Context() context.Context {
|
||||
ret := m.ctrl.Call(m, "Context")
|
||||
ret0, _ := ret[0].(context.Context)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Context indicates an expected call of Context
|
||||
func (mr *MockSendStreamIMockRecorder) Context() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context))
|
||||
}
|
||||
|
||||
// SetWriteDeadline mocks base method
|
||||
func (m *MockSendStreamI) SetWriteDeadline(arg0 time.Time) error {
|
||||
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetWriteDeadline indicates an expected call of SetWriteDeadline
|
||||
func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), arg0)
|
||||
}
|
||||
|
||||
// StreamID mocks base method
|
||||
func (m *MockSendStreamI) StreamID() protocol.StreamID {
|
||||
ret := m.ctrl.Call(m, "StreamID")
|
||||
ret0, _ := ret[0].(protocol.StreamID)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StreamID indicates an expected call of StreamID
|
||||
func (mr *MockSendStreamIMockRecorder) StreamID() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID))
|
||||
}
|
||||
|
||||
// Write mocks base method
|
||||
func (m *MockSendStreamI) Write(arg0 []byte) (int, error) {
|
||||
ret := m.ctrl.Call(m, "Write", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Write indicates an expected call of Write
|
||||
func (mr *MockSendStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), arg0)
|
||||
}
|
||||
|
||||
// closeForShutdown mocks base method
|
||||
func (m *MockSendStreamI) closeForShutdown(arg0 error) {
|
||||
m.ctrl.Call(m, "closeForShutdown", arg0)
|
||||
}
|
||||
|
||||
// closeForShutdown indicates an expected call of closeForShutdown
|
||||
func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0)
|
||||
}
|
||||
|
||||
// handleMaxStreamDataFrame mocks base method
|
||||
func (m *MockSendStreamI) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
|
||||
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
|
||||
}
|
||||
|
||||
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
|
||||
func (mr *MockSendStreamIMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleMaxStreamDataFrame), arg0)
|
||||
}
|
||||
|
||||
// handleStopSendingFrame mocks base method
|
||||
func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) {
|
||||
m.ctrl.Call(m, "handleStopSendingFrame", arg0)
|
||||
}
|
||||
|
||||
// handleStopSendingFrame indicates an expected call of handleStopSendingFrame
|
||||
func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0)
|
||||
}
|
||||
|
||||
// popStreamFrame mocks base method
|
||||
func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
|
||||
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
|
||||
ret0, _ := ret[0].(*wire.StreamFrame)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// popStreamFrame indicates an expected call of popStreamFrame
|
||||
func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0)
|
||||
}
|
72
vendor/github.com/lucas-clemente/quic-go/mock_stream_frame_source_test.go
generated
vendored
Normal file
72
vendor/github.com/lucas-clemente/quic-go/mock_stream_frame_source_test.go
generated
vendored
Normal file
|
@ -0,0 +1,72 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamFrameSource)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
wire "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// MockStreamFrameSource is a mock of StreamFrameSource interface
|
||||
type MockStreamFrameSource struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStreamFrameSourceMockRecorder
|
||||
}
|
||||
|
||||
// MockStreamFrameSourceMockRecorder is the mock recorder for MockStreamFrameSource
|
||||
type MockStreamFrameSourceMockRecorder struct {
|
||||
mock *MockStreamFrameSource
|
||||
}
|
||||
|
||||
// NewMockStreamFrameSource creates a new mock instance
|
||||
func NewMockStreamFrameSource(ctrl *gomock.Controller) *MockStreamFrameSource {
|
||||
mock := &MockStreamFrameSource{ctrl: ctrl}
|
||||
mock.recorder = &MockStreamFrameSourceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockStreamFrameSource) EXPECT() *MockStreamFrameSourceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// HasCryptoStreamData mocks base method
|
||||
func (m *MockStreamFrameSource) HasCryptoStreamData() bool {
|
||||
ret := m.ctrl.Call(m, "HasCryptoStreamData")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// HasCryptoStreamData indicates an expected call of HasCryptoStreamData
|
||||
func (mr *MockStreamFrameSourceMockRecorder) HasCryptoStreamData() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCryptoStreamData", reflect.TypeOf((*MockStreamFrameSource)(nil).HasCryptoStreamData))
|
||||
}
|
||||
|
||||
// PopCryptoStreamFrame mocks base method
|
||||
func (m *MockStreamFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame {
|
||||
ret := m.ctrl.Call(m, "PopCryptoStreamFrame", arg0)
|
||||
ret0, _ := ret[0].(*wire.StreamFrame)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PopCryptoStreamFrame indicates an expected call of PopCryptoStreamFrame
|
||||
func (mr *MockStreamFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockStreamFrameSource)(nil).PopCryptoStreamFrame), arg0)
|
||||
}
|
||||
|
||||
// PopStreamFrames mocks base method
|
||||
func (m *MockStreamFrameSource) PopStreamFrames(arg0 protocol.ByteCount) []*wire.StreamFrame {
|
||||
ret := m.ctrl.Call(m, "PopStreamFrames", arg0)
|
||||
ret0, _ := ret[0].([]*wire.StreamFrame)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PopStreamFrames indicates an expected call of PopStreamFrames
|
||||
func (mr *MockStreamFrameSourceMockRecorder) PopStreamFrames(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopStreamFrames", reflect.TypeOf((*MockStreamFrameSource)(nil).PopStreamFrames), arg0)
|
||||
}
|
61
vendor/github.com/lucas-clemente/quic-go/mock_stream_getter_test.go
generated
vendored
Normal file
61
vendor/github.com/lucas-clemente/quic-go/mock_stream_getter_test.go
generated
vendored
Normal file
|
@ -0,0 +1,61 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamGetter)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockStreamGetter is a mock of StreamGetter interface
|
||||
type MockStreamGetter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStreamGetterMockRecorder
|
||||
}
|
||||
|
||||
// MockStreamGetterMockRecorder is the mock recorder for MockStreamGetter
|
||||
type MockStreamGetterMockRecorder struct {
|
||||
mock *MockStreamGetter
|
||||
}
|
||||
|
||||
// NewMockStreamGetter creates a new mock instance
|
||||
func NewMockStreamGetter(ctrl *gomock.Controller) *MockStreamGetter {
|
||||
mock := &MockStreamGetter{ctrl: ctrl}
|
||||
mock.recorder = &MockStreamGetterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetOrOpenReceiveStream mocks base method
|
||||
func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) {
|
||||
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
|
||||
ret0, _ := ret[0].(receiveStreamI)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream
|
||||
func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0)
|
||||
}
|
||||
|
||||
// GetOrOpenSendStream mocks base method
|
||||
func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
|
||||
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
|
||||
ret0, _ := ret[0].(sendStreamI)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream
|
||||
func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0)
|
||||
}
|
239
vendor/github.com/lucas-clemente/quic-go/mock_stream_internal_test.go
generated
vendored
Normal file
239
vendor/github.com/lucas-clemente/quic-go/mock_stream_internal_test.go
generated
vendored
Normal file
|
@ -0,0 +1,239 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamI)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
wire "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// MockStreamI is a mock of StreamI interface
|
||||
type MockStreamI struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStreamIMockRecorder
|
||||
}
|
||||
|
||||
// MockStreamIMockRecorder is the mock recorder for MockStreamI
|
||||
type MockStreamIMockRecorder struct {
|
||||
mock *MockStreamI
|
||||
}
|
||||
|
||||
// NewMockStreamI creates a new mock instance
|
||||
func NewMockStreamI(ctrl *gomock.Controller) *MockStreamI {
|
||||
mock := &MockStreamI{ctrl: ctrl}
|
||||
mock.recorder = &MockStreamIMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockStreamI) EXPECT() *MockStreamIMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CancelRead mocks base method
|
||||
func (m *MockStreamI) CancelRead(arg0 protocol.ApplicationErrorCode) error {
|
||||
ret := m.ctrl.Call(m, "CancelRead", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CancelRead indicates an expected call of CancelRead
|
||||
func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0)
|
||||
}
|
||||
|
||||
// CancelWrite mocks base method
|
||||
func (m *MockStreamI) CancelWrite(arg0 protocol.ApplicationErrorCode) error {
|
||||
ret := m.ctrl.Call(m, "CancelWrite", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CancelWrite indicates an expected call of CancelWrite
|
||||
func (mr *MockStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockStreamI) Close() error {
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockStreamIMockRecorder) Close() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStreamI)(nil).Close))
|
||||
}
|
||||
|
||||
// Context mocks base method
|
||||
func (m *MockStreamI) Context() context.Context {
|
||||
ret := m.ctrl.Call(m, "Context")
|
||||
ret0, _ := ret[0].(context.Context)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Context indicates an expected call of Context
|
||||
func (mr *MockStreamIMockRecorder) Context() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStreamI)(nil).Context))
|
||||
}
|
||||
|
||||
// Read mocks base method
|
||||
func (m *MockStreamI) Read(arg0 []byte) (int, error) {
|
||||
ret := m.ctrl.Call(m, "Read", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read
|
||||
func (mr *MockStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0)
|
||||
}
|
||||
|
||||
// SetDeadline mocks base method
|
||||
func (m *MockStreamI) SetDeadline(arg0 time.Time) error {
|
||||
ret := m.ctrl.Call(m, "SetDeadline", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetDeadline indicates an expected call of SetDeadline
|
||||
func (mr *MockStreamIMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), arg0)
|
||||
}
|
||||
|
||||
// SetReadDeadline mocks base method
|
||||
func (m *MockStreamI) SetReadDeadline(arg0 time.Time) error {
|
||||
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetReadDeadline indicates an expected call of SetReadDeadline
|
||||
func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), arg0)
|
||||
}
|
||||
|
||||
// SetWriteDeadline mocks base method
|
||||
func (m *MockStreamI) SetWriteDeadline(arg0 time.Time) error {
|
||||
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetWriteDeadline indicates an expected call of SetWriteDeadline
|
||||
func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0)
|
||||
}
|
||||
|
||||
// StreamID mocks base method
|
||||
func (m *MockStreamI) StreamID() protocol.StreamID {
|
||||
ret := m.ctrl.Call(m, "StreamID")
|
||||
ret0, _ := ret[0].(protocol.StreamID)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StreamID indicates an expected call of StreamID
|
||||
func (mr *MockStreamIMockRecorder) StreamID() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID))
|
||||
}
|
||||
|
||||
// Write mocks base method
|
||||
func (m *MockStreamI) Write(arg0 []byte) (int, error) {
|
||||
ret := m.ctrl.Call(m, "Write", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Write indicates an expected call of Write
|
||||
func (mr *MockStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), arg0)
|
||||
}
|
||||
|
||||
// closeForShutdown mocks base method
|
||||
func (m *MockStreamI) closeForShutdown(arg0 error) {
|
||||
m.ctrl.Call(m, "closeForShutdown", arg0)
|
||||
}
|
||||
|
||||
// closeForShutdown indicates an expected call of closeForShutdown
|
||||
func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0)
|
||||
}
|
||||
|
||||
// getWindowUpdate mocks base method
|
||||
func (m *MockStreamI) getWindowUpdate() protocol.ByteCount {
|
||||
ret := m.ctrl.Call(m, "getWindowUpdate")
|
||||
ret0, _ := ret[0].(protocol.ByteCount)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// getWindowUpdate indicates an expected call of getWindowUpdate
|
||||
func (mr *MockStreamIMockRecorder) getWindowUpdate() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).getWindowUpdate))
|
||||
}
|
||||
|
||||
// handleMaxStreamDataFrame mocks base method
|
||||
func (m *MockStreamI) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
|
||||
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
|
||||
}
|
||||
|
||||
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
|
||||
func (mr *MockStreamIMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockStreamI)(nil).handleMaxStreamDataFrame), arg0)
|
||||
}
|
||||
|
||||
// handleRstStreamFrame mocks base method
|
||||
func (m *MockStreamI) handleRstStreamFrame(arg0 *wire.RstStreamFrame) error {
|
||||
ret := m.ctrl.Call(m, "handleRstStreamFrame", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// handleRstStreamFrame indicates an expected call of handleRstStreamFrame
|
||||
func (mr *MockStreamIMockRecorder) handleRstStreamFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleRstStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleRstStreamFrame), arg0)
|
||||
}
|
||||
|
||||
// handleStopSendingFrame mocks base method
|
||||
func (m *MockStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) {
|
||||
m.ctrl.Call(m, "handleStopSendingFrame", arg0)
|
||||
}
|
||||
|
||||
// handleStopSendingFrame indicates an expected call of handleStopSendingFrame
|
||||
func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0)
|
||||
}
|
||||
|
||||
// handleStreamFrame mocks base method
|
||||
func (m *MockStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error {
|
||||
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// handleStreamFrame indicates an expected call of handleStreamFrame
|
||||
func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0)
|
||||
}
|
||||
|
||||
// popStreamFrame mocks base method
|
||||
func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
|
||||
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
|
||||
ret0, _ := ret[0].(*wire.StreamFrame)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// popStreamFrame indicates an expected call of popStreamFrame
|
||||
func (mr *MockStreamIMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), arg0)
|
||||
}
|
146
vendor/github.com/lucas-clemente/quic-go/mock_stream_manager_test.go
generated
vendored
Normal file
146
vendor/github.com/lucas-clemente/quic-go/mock_stream_manager_test.go
generated
vendored
Normal file
|
@ -0,0 +1,146 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamManager)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockStreamManager is a mock of StreamManager interface
|
||||
type MockStreamManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStreamManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager
|
||||
type MockStreamManagerMockRecorder struct {
|
||||
mock *MockStreamManager
|
||||
}
|
||||
|
||||
// NewMockStreamManager creates a new mock instance
|
||||
func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager {
|
||||
mock := &MockStreamManager{ctrl: ctrl}
|
||||
mock.recorder = &MockStreamManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcceptStream mocks base method
|
||||
func (m *MockStreamManager) AcceptStream() (Stream, error) {
|
||||
ret := m.ctrl.Call(m, "AcceptStream")
|
||||
ret0, _ := ret[0].(Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptStream indicates an expected call of AcceptStream
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream))
|
||||
}
|
||||
|
||||
// CloseWithError mocks base method
|
||||
func (m *MockStreamManager) CloseWithError(arg0 error) {
|
||||
m.ctrl.Call(m, "CloseWithError", arg0)
|
||||
}
|
||||
|
||||
// CloseWithError indicates an expected call of CloseWithError
|
||||
func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0)
|
||||
}
|
||||
|
||||
// DeleteStream mocks base method
|
||||
func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error {
|
||||
ret := m.ctrl.Call(m, "DeleteStream", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteStream indicates an expected call of DeleteStream
|
||||
func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0)
|
||||
}
|
||||
|
||||
// GetOrOpenReceiveStream mocks base method
|
||||
func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) {
|
||||
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
|
||||
ret0, _ := ret[0].(receiveStreamI)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream
|
||||
func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0)
|
||||
}
|
||||
|
||||
// GetOrOpenSendStream mocks base method
|
||||
func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
|
||||
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
|
||||
ret0, _ := ret[0].(sendStreamI)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream
|
||||
func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
|
||||
}
|
||||
|
||||
// GetOrOpenStream mocks base method
|
||||
func (m *MockStreamManager) GetOrOpenStream(arg0 protocol.StreamID) (streamI, error) {
|
||||
ret := m.ctrl.Call(m, "GetOrOpenStream", arg0)
|
||||
ret0, _ := ret[0].(streamI)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrOpenStream indicates an expected call of GetOrOpenStream
|
||||
func (mr *MockStreamManagerMockRecorder) GetOrOpenStream(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenStream), arg0)
|
||||
}
|
||||
|
||||
// OpenStream mocks base method
|
||||
func (m *MockStreamManager) OpenStream() (Stream, error) {
|
||||
ret := m.ctrl.Call(m, "OpenStream")
|
||||
ret0, _ := ret[0].(Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenStream indicates an expected call of OpenStream
|
||||
func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream))
|
||||
}
|
||||
|
||||
// OpenStreamSync mocks base method
|
||||
func (m *MockStreamManager) OpenStreamSync() (Stream, error) {
|
||||
ret := m.ctrl.Call(m, "OpenStreamSync")
|
||||
ret0, _ := ret[0].(Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenStreamSync indicates an expected call of OpenStreamSync
|
||||
func (mr *MockStreamManagerMockRecorder) OpenStreamSync() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync))
|
||||
}
|
||||
|
||||
// UpdateLimits mocks base method
|
||||
func (m *MockStreamManager) UpdateLimits(arg0 *handshake.TransportParameters) {
|
||||
m.ctrl.Call(m, "UpdateLimits", arg0)
|
||||
}
|
||||
|
||||
// UpdateLimits indicates an expected call of UpdateLimits
|
||||
func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0)
|
||||
}
|
76
vendor/github.com/lucas-clemente/quic-go/mock_stream_sender_test.go
generated
vendored
Normal file
76
vendor/github.com/lucas-clemente/quic-go/mock_stream_sender_test.go
generated
vendored
Normal file
|
@ -0,0 +1,76 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamSender)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
wire "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// MockStreamSender is a mock of StreamSender interface
|
||||
type MockStreamSender struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStreamSenderMockRecorder
|
||||
}
|
||||
|
||||
// MockStreamSenderMockRecorder is the mock recorder for MockStreamSender
|
||||
type MockStreamSenderMockRecorder struct {
|
||||
mock *MockStreamSender
|
||||
}
|
||||
|
||||
// NewMockStreamSender creates a new mock instance
|
||||
func NewMockStreamSender(ctrl *gomock.Controller) *MockStreamSender {
|
||||
mock := &MockStreamSender{ctrl: ctrl}
|
||||
mock.recorder = &MockStreamSenderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// onHasStreamData mocks base method
|
||||
func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) {
|
||||
m.ctrl.Call(m, "onHasStreamData", arg0)
|
||||
}
|
||||
|
||||
// onHasStreamData indicates an expected call of onHasStreamData
|
||||
func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0)
|
||||
}
|
||||
|
||||
// onHasWindowUpdate mocks base method
|
||||
func (m *MockStreamSender) onHasWindowUpdate(arg0 protocol.StreamID) {
|
||||
m.ctrl.Call(m, "onHasWindowUpdate", arg0)
|
||||
}
|
||||
|
||||
// onHasWindowUpdate indicates an expected call of onHasWindowUpdate
|
||||
func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0)
|
||||
}
|
||||
|
||||
// onStreamCompleted mocks base method
|
||||
func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) {
|
||||
m.ctrl.Call(m, "onStreamCompleted", arg0)
|
||||
}
|
||||
|
||||
// onStreamCompleted indicates an expected call of onStreamCompleted
|
||||
func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0)
|
||||
}
|
||||
|
||||
// queueControlFrame mocks base method
|
||||
func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) {
|
||||
m.ctrl.Call(m, "queueControlFrame", arg0)
|
||||
}
|
||||
|
||||
// queueControlFrame indicates an expected call of queueControlFrame
|
||||
func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0)
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
package quic
|
||||
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI ReceiveStreamI"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI SendStreamI"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager"
|
||||
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go"
|
||||
//go:generate sh -c "goimports -w mock*_test.go"
|
|
@ -0,0 +1,19 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Mockgen refuses to generate mocks private types.
|
||||
# This script copies the quic package to a temporary directory, and adds an public alias for the private type.
|
||||
# It then creates a mock for this public (alias) type.
|
||||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
mkdir -p $TEMP_DIR/src/github.com/lucas-clemente/quic-go/
|
||||
|
||||
# copy all .go files to a temporary directory
|
||||
# golang.org/x/crypto/curve25519/ uses Go compiler directives, which is confusing to mockgen
|
||||
rsync -r --exclude 'vendor/golang.org/x/crypto/curve25519/' --include='*.go' --include '*/' --exclude '*' $GOPATH/src/github.com/lucas-clemente/quic-go/ $TEMP_DIR/src/github.com/lucas-clemente/quic-go/
|
||||
echo "type $5 = $4" >> $TEMP_DIR/src/github.com/lucas-clemente/quic-go/interface.go
|
||||
|
||||
export GOPATH="$TEMP_DIR:$GOPATH"
|
||||
|
||||
mockgen -package $1 -self_package $1 -destination $2 $3 $5
|
||||
|
||||
rm -r "$TEMP_DIR"
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
|
@ -18,6 +19,12 @@ type packedPacket struct {
|
|||
encryptionLevel protocol.EncryptionLevel
|
||||
}
|
||||
|
||||
type streamFrameSource interface {
|
||||
HasCryptoStreamData() bool
|
||||
PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame
|
||||
PopStreamFrames(protocol.ByteCount) []*wire.StreamFrame
|
||||
}
|
||||
|
||||
type packetPacker struct {
|
||||
connectionID protocol.ConnectionID
|
||||
perspective protocol.Perspective
|
||||
|
@ -25,20 +32,23 @@ type packetPacker struct {
|
|||
cryptoSetup handshake.CryptoSetup
|
||||
|
||||
packetNumberGenerator *packetNumberGenerator
|
||||
streamFramer *streamFramer
|
||||
streams streamFrameSource
|
||||
|
||||
controlFrameMutex sync.Mutex
|
||||
controlFrames []wire.Frame
|
||||
|
||||
stopWaiting *wire.StopWaitingFrame
|
||||
ackFrame *wire.AckFrame
|
||||
leastUnacked protocol.PacketNumber
|
||||
omitConnectionID bool
|
||||
hasSentPacket bool // has the packetPacker already sent a packet
|
||||
makeNextPacketRetransmittable bool
|
||||
}
|
||||
|
||||
func newPacketPacker(connectionID protocol.ConnectionID,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
cryptoSetup handshake.CryptoSetup,
|
||||
streamFramer *streamFramer,
|
||||
streamFramer streamFrameSource,
|
||||
perspective protocol.Perspective,
|
||||
version protocol.VersionNumber,
|
||||
) *packetPacker {
|
||||
|
@ -47,7 +57,7 @@ func newPacketPacker(connectionID protocol.ConnectionID,
|
|||
connectionID: connectionID,
|
||||
perspective: perspective,
|
||||
version: version,
|
||||
streamFramer: streamFramer,
|
||||
streams: streamFramer,
|
||||
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
|
||||
}
|
||||
}
|
||||
|
@ -73,7 +83,7 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) {
|
|||
encLevel, sealer := p.cryptoSetup.GetSealer()
|
||||
header := p.getHeader(encLevel)
|
||||
frames := []wire.Frame{p.ackFrame}
|
||||
if p.stopWaiting != nil {
|
||||
if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC
|
||||
p.stopWaiting.PacketNumber = header.PacketNumber
|
||||
p.stopWaiting.PacketNumberLen = header.PacketNumberLen
|
||||
frames = append(frames, p.stopWaiting)
|
||||
|
@ -98,14 +108,20 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.stopWaiting == nil {
|
||||
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")
|
||||
}
|
||||
header := p.getHeader(packet.EncryptionLevel)
|
||||
p.stopWaiting.PacketNumber = header.PacketNumber
|
||||
p.stopWaiting.PacketNumberLen = header.PacketNumberLen
|
||||
frames := append([]wire.Frame{p.stopWaiting}, packet.Frames...)
|
||||
var frames []wire.Frame
|
||||
if !p.version.UsesIETFFrameFormat() { // for gQUIC: pack a STOP_WAITING first
|
||||
if p.stopWaiting == nil {
|
||||
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame")
|
||||
}
|
||||
swf := p.stopWaiting
|
||||
swf.PacketNumber = header.PacketNumber
|
||||
swf.PacketNumberLen = header.PacketNumberLen
|
||||
p.stopWaiting = nil
|
||||
frames = append([]wire.Frame{swf}, packet.Frames...)
|
||||
} else {
|
||||
frames = packet.Frames
|
||||
}
|
||||
raw, err := p.writeAndSealPacket(header, frames, sealer)
|
||||
return &packedPacket{
|
||||
header: header,
|
||||
|
@ -118,7 +134,7 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
|
|||
// PackPacket packs a new packet
|
||||
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
|
||||
func (p *packetPacker) PackPacket() (*packedPacket, error) {
|
||||
hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame()
|
||||
hasCryptoStreamFrame := p.streams.HasCryptoStreamData()
|
||||
// if this is the first packet to be send, make sure it contains stream data
|
||||
if !p.hasSentPacket && !hasCryptoStreamFrame {
|
||||
return nil, nil
|
||||
|
@ -153,6 +169,15 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
|
|||
if len(payloadFrames) == 1 && p.stopWaiting != nil {
|
||||
return nil, nil
|
||||
}
|
||||
// check if this packet only contains an ACK and / or STOP_WAITING
|
||||
if !ackhandler.HasRetransmittableFrames(payloadFrames) {
|
||||
if p.makeNextPacketRetransmittable {
|
||||
payloadFrames = append(payloadFrames, &wire.PingFrame{})
|
||||
p.makeNextPacketRetransmittable = false
|
||||
}
|
||||
} else { // this packet already contains a retransmittable frame. No need to send a PING
|
||||
p.makeNextPacketRetransmittable = false
|
||||
}
|
||||
p.stopWaiting = nil
|
||||
p.ackFrame = nil
|
||||
|
||||
|
@ -176,7 +201,9 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) {
|
|||
return nil, err
|
||||
}
|
||||
maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength
|
||||
frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)}
|
||||
sf := p.streams.PopCryptoStreamFrame(maxLen)
|
||||
sf.DataLenPresent = false
|
||||
frames := []wire.Frame{sf}
|
||||
raw, err := p.writeAndSealPacket(header, frames, sealer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -197,29 +224,20 @@ func (p *packetPacker) composeNextPacket(
|
|||
var payloadFrames []wire.Frame
|
||||
|
||||
// STOP_WAITING and ACK will always fit
|
||||
if p.stopWaiting != nil {
|
||||
payloadFrames = append(payloadFrames, p.stopWaiting)
|
||||
l, err := p.stopWaiting.MinLength(p.version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLength += l
|
||||
}
|
||||
if p.ackFrame != nil {
|
||||
if p.ackFrame != nil { // ACKs need to go first, so that the sentPacketHandler will recognize them
|
||||
payloadFrames = append(payloadFrames, p.ackFrame)
|
||||
l, err := p.ackFrame.MinLength(p.version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l := p.ackFrame.MinLength(p.version)
|
||||
payloadLength += l
|
||||
}
|
||||
if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC
|
||||
payloadFrames = append(payloadFrames, p.stopWaiting)
|
||||
payloadLength += p.stopWaiting.MinLength(p.version)
|
||||
}
|
||||
|
||||
p.controlFrameMutex.Lock()
|
||||
for len(p.controlFrames) > 0 {
|
||||
frame := p.controlFrames[len(p.controlFrames)-1]
|
||||
minLength, err := frame.MinLength(p.version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
minLength := frame.MinLength(p.version)
|
||||
if payloadLength+minLength > maxFrameSize {
|
||||
break
|
||||
}
|
||||
|
@ -227,6 +245,7 @@ func (p *packetPacker) composeNextPacket(
|
|||
payloadLength += minLength
|
||||
p.controlFrames = p.controlFrames[:len(p.controlFrames)-1]
|
||||
}
|
||||
p.controlFrameMutex.Unlock()
|
||||
|
||||
if payloadLength > maxFrameSize {
|
||||
return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize)
|
||||
|
@ -247,20 +266,14 @@ func (p *packetPacker) composeNextPacket(
|
|||
maxFrameSize += 2
|
||||
}
|
||||
|
||||
fs := p.streamFramer.PopStreamFrames(maxFrameSize - payloadLength)
|
||||
fs := p.streams.PopStreamFrames(maxFrameSize - payloadLength)
|
||||
if len(fs) != 0 {
|
||||
fs[len(fs)-1].DataLenPresent = false
|
||||
}
|
||||
|
||||
// TODO: Simplify
|
||||
for _, f := range fs {
|
||||
payloadFrames = append(payloadFrames, f)
|
||||
}
|
||||
|
||||
for b := p.streamFramer.PopBlockedFrame(); b != nil; b = p.streamFramer.PopBlockedFrame() {
|
||||
p.controlFrames = append(p.controlFrames, b)
|
||||
}
|
||||
|
||||
return payloadFrames, nil
|
||||
}
|
||||
|
||||
|
@ -271,7 +284,9 @@ func (p *packetPacker) QueueControlFrame(frame wire.Frame) {
|
|||
case *wire.AckFrame:
|
||||
p.ackFrame = f
|
||||
default:
|
||||
p.controlFrameMutex.Lock()
|
||||
p.controlFrames = append(p.controlFrames, f)
|
||||
p.controlFrameMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -377,3 +392,7 @@ func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) {
|
|||
func (p *packetPacker) SetOmitConnectionID() {
|
||||
p.omitConnectionID = true
|
||||
}
|
||||
|
||||
func (p *packetPacker) MakeNextPacketRetransmittable() {
|
||||
p.makeNextPacketRetransmittable = true
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"math"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
|
@ -49,29 +50,32 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
|
|||
}
|
||||
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
|
||||
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
|
||||
func (m *mockCryptoSetup) ConnectionState() ConnectionState { panic("not implemented") }
|
||||
|
||||
var _ = Describe("Packet packer", func() {
|
||||
var (
|
||||
packer *packetPacker
|
||||
publicHeaderLen protocol.ByteCount
|
||||
maxFrameSize protocol.ByteCount
|
||||
streamFramer *streamFramer
|
||||
cryptoStream *stream
|
||||
cryptoStream cryptoStreamI
|
||||
mockStreamFramer *MockStreamFrameSource
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
version := versionGQUICFrames
|
||||
cryptoStream = &stream{streamID: version.CryptoStreamID(), flowController: flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)}
|
||||
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames)
|
||||
streamFramer = newStreamFramer(cryptoStream, streamsMap, nil, versionGQUICFrames)
|
||||
mockSender := NewMockStreamSender(mockCtrl)
|
||||
mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
|
||||
cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version)
|
||||
mockStreamFramer = NewMockStreamFrameSource(mockCtrl)
|
||||
|
||||
packer = &packetPacker{
|
||||
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
|
||||
connectionID: 0x1337,
|
||||
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
|
||||
streamFramer: streamFramer,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
}
|
||||
packer = newPacketPacker(
|
||||
0x1337,
|
||||
1,
|
||||
&mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
|
||||
mockStreamFramer,
|
||||
protocol.PerspectiveServer,
|
||||
version,
|
||||
)
|
||||
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
|
||||
maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
|
||||
packer.hasSentPacket = true
|
||||
|
@ -79,33 +83,36 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("returns nil when no packet is queued", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("packs single packets", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f})
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).ToNot(BeNil())
|
||||
b := &bytes.Buffer{}
|
||||
f.Write(b, packer.version)
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames).To(Equal([]wire.Frame{f}))
|
||||
Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
|
||||
})
|
||||
|
||||
It("stores the encryption level a packet was sealed with", func() {
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure
|
||||
f := &wire.StreamFrame{
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{
|
||||
StreamID: 5,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
}})
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
|
||||
|
@ -213,7 +220,7 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
})
|
||||
|
||||
It("packs a ConnectionClose", func() {
|
||||
It("packs a CONNECTION_CLOSE", func() {
|
||||
ccf := wire.ConnectionCloseFrame{
|
||||
ErrorCode: 0x1337,
|
||||
ReasonPhrase: "foobar",
|
||||
|
@ -224,23 +231,21 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(p.frames[0]).To(Equal(&ccf))
|
||||
})
|
||||
|
||||
It("doesn't send any other frames when sending a ConnectionClose", func() {
|
||||
ccf := wire.ConnectionCloseFrame{
|
||||
It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() {
|
||||
// expect no mockStreamFramer.PopStreamFrames
|
||||
ccf := &wire.ConnectionCloseFrame{
|
||||
ErrorCode: 0x1337,
|
||||
ReasonPhrase: "foobar",
|
||||
}
|
||||
packer.controlFrames = []wire.Frame{&wire.MaxStreamDataFrame{StreamID: 37}}
|
||||
streamFramer.AddFrameForRetransmission(&wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte("foobar"),
|
||||
})
|
||||
p, err := packer.PackConnectionClose(&ccf)
|
||||
p, err := packer.PackConnectionClose(ccf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0]).To(Equal(&ccf))
|
||||
Expect(p.frames).To(Equal([]wire.Frame{ccf}))
|
||||
})
|
||||
|
||||
It("packs only control frames", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
packer.QueueControlFrame(&wire.RstStreamFrame{})
|
||||
packer.QueueControlFrame(&wire.MaxDataFrame{})
|
||||
p, err := packer.PackPacket()
|
||||
|
@ -251,6 +256,8 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("increases the packet number", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
|
||||
packer.QueueControlFrame(&wire.RstStreamFrame{})
|
||||
p1, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -263,6 +270,8 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("packs a STOP_WAITING frame first", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
packer.packetNumberGenerator.next = 15
|
||||
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
|
||||
packer.QueueControlFrame(&wire.RstStreamFrame{})
|
||||
|
@ -275,6 +284,8 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
|
||||
packer.packetNumberGenerator.next = packetNumber
|
||||
swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
|
||||
|
@ -286,6 +297,8 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("does not pack a packet containing only a STOP_WAITING frame", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
|
||||
packer.QueueControlFrame(swf)
|
||||
p, err := packer.PackPacket()
|
||||
|
@ -294,6 +307,8 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("packs a packet if it has queued control frames, but no new control frames", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -301,6 +316,7 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
packer.hasSentPacket = false
|
||||
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
|
||||
p, err := packer.PackPacket()
|
||||
|
@ -329,8 +345,7 @@ var _ = Describe("Packet packer", func() {
|
|||
|
||||
It("packs a lot of control frames into 2 packets if they don't fit into one", func() {
|
||||
blockedFrame := &wire.BlockedFrame{}
|
||||
minLength, _ := blockedFrame.MinLength(packer.version)
|
||||
maxFramesPerPacket := int(maxFrameSize) / int(minLength)
|
||||
maxFramesPerPacket := int(maxFrameSize) / int(blockedFrame.MinLength(packer.version))
|
||||
var controlFrames []wire.Frame
|
||||
for i := 0; i < maxFramesPerPacket+10; i++ {
|
||||
controlFrames = append(controlFrames, blockedFrame)
|
||||
|
@ -345,16 +360,17 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("only increases the packet number when there is an actual packet to send", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
packer.packetNumberGenerator.nextToSkip = 1000
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1)))
|
||||
f := &wire.StreamFrame{
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{
|
||||
StreamID: 5,
|
||||
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
}})
|
||||
p, err = packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).ToNot(BeNil())
|
||||
|
@ -362,320 +378,207 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2)))
|
||||
})
|
||||
|
||||
Context("STREAM Frame handling", func() {
|
||||
It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() {
|
||||
f := &wire.StreamFrame{
|
||||
Offset: 1,
|
||||
StreamID: 5,
|
||||
DataLenPresent: false,
|
||||
}
|
||||
minLength, _ := f.MinLength(packer.version)
|
||||
maxStreamFrameDataLen := maxFrameSize - minLength
|
||||
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen))
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
|
||||
It("adds a PING frame when it's supposed to send a retransmittable packet", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
|
||||
packer.QueueControlFrame(&wire.AckFrame{})
|
||||
packer.QueueControlFrame(&wire.StopWaitingFrame{})
|
||||
packer.MakeNextPacketRetransmittable()
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(payloadFrames).To(HaveLen(1))
|
||||
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(p.frames).To(HaveLen(3))
|
||||
Expect(p.frames).To(ContainElement(&wire.PingFrame{}))
|
||||
// make sure the next packet doesn't contain another PING
|
||||
packer.QueueControlFrame(&wire.AckFrame{})
|
||||
p, err = packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(payloadFrames).To(BeEmpty())
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() {
|
||||
packer.version = versionIETFFrames
|
||||
streamFramer.version = versionIETFFrames
|
||||
It("waits until there's something to send before adding a PING frame", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
|
||||
packer.MakeNextPacketRetransmittable()
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).To(BeNil())
|
||||
packer.QueueControlFrame(&wire.AckFrame{})
|
||||
p, err = packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(HaveLen(2))
|
||||
Expect(p.frames).To(ContainElement(&wire.PingFrame{}))
|
||||
})
|
||||
|
||||
It("doesn't send a PING if it already sent another retransmittable frame", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
|
||||
packer.MakeNextPacketRetransmittable()
|
||||
packer.QueueControlFrame(&wire.MaxDataFrame{})
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
packer.QueueControlFrame(&wire.AckFrame{})
|
||||
p, err = packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
})
|
||||
|
||||
Context("STREAM frame handling", func() {
|
||||
It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame {
|
||||
f := &wire.StreamFrame{
|
||||
Offset: 1,
|
||||
StreamID: 5,
|
||||
DataLenPresent: true,
|
||||
}
|
||||
minLength, _ := f.MinLength(packer.version)
|
||||
// for IETF draft style STREAM frames, we don't know the size of the DataLen, because it is a variable length integer
|
||||
// in the general case, we therefore use a STREAM frame that is 1 byte smaller than the maximum size
|
||||
maxStreamFrameDataLen := maxFrameSize - minLength - 1
|
||||
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen))
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(payloadFrames).To(HaveLen(1))
|
||||
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(payloadFrames).To(BeEmpty())
|
||||
f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version)))
|
||||
return []*wire.StreamFrame{f}
|
||||
})
|
||||
|
||||
It("correctly handles a STREAM frame with one byte less than maximum size", func() {
|
||||
maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - 1
|
||||
f1 := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Offset: 1,
|
||||
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)),
|
||||
}
|
||||
f2 := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Offset: 1,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f1)
|
||||
streamFramer.AddFrameForRetransmission(f2)
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1)))
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
p, err = packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
})
|
||||
|
||||
It("packs multiple small STREAM frames into single packet", func() {
|
||||
f1 := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
|
||||
}
|
||||
f2 := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte{0xBE, 0xEF, 0x13, 0x37},
|
||||
}
|
||||
f3 := &wire.StreamFrame{
|
||||
StreamID: 3,
|
||||
Data: []byte{0xCA, 0xFE},
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f1)
|
||||
streamFramer.AddFrameForRetransmission(f2)
|
||||
streamFramer.AddFrameForRetransmission(f3)
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := &bytes.Buffer{}
|
||||
f1.Write(b, 0)
|
||||
f2.Write(b, 0)
|
||||
f3.Write(b, 0)
|
||||
Expect(p.frames).To(HaveLen(3))
|
||||
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
|
||||
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
|
||||
Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
Expect(p.raw).To(ContainSubstring(string(f1.Data)))
|
||||
Expect(p.raw).To(ContainSubstring(string(f2.Data)))
|
||||
Expect(p.raw).To(ContainSubstring(string(f3.Data)))
|
||||
})
|
||||
|
||||
It("splits one STREAM frame larger than maximum size", func() {
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 7,
|
||||
Offset: 1,
|
||||
}
|
||||
minLength, _ := f.MinLength(packer.version)
|
||||
maxStreamFrameDataLen := maxFrameSize - minLength
|
||||
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200)
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(payloadFrames).To(HaveLen(1))
|
||||
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(int(maxStreamFrameDataLen)))
|
||||
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(payloadFrames).To(HaveLen(1))
|
||||
Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(200))
|
||||
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(payloadFrames).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("packs 2 STREAM frames that are too big for one packet correctly", func() {
|
||||
maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2)
|
||||
f1 := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100),
|
||||
Offset: 1,
|
||||
}
|
||||
f2 := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100),
|
||||
Offset: 1,
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f1)
|
||||
streamFramer.AddFrameForRetransmission(f2)
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
|
||||
p, err = packer.PackPacket()
|
||||
Expect(p.frames).To(HaveLen(2))
|
||||
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
|
||||
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
|
||||
p, err = packer.PackPacket()
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).ToNot(BeNil())
|
||||
p, err = packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).To(BeNil())
|
||||
})
|
||||
|
||||
It("packs a packet that has the maximum packet size when given a large enough STREAM frame", func() {
|
||||
It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() {
|
||||
packer.version = versionIETFFrames
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame {
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Offset: 1,
|
||||
StreamID: 5,
|
||||
DataLenPresent: true,
|
||||
}
|
||||
minLength, _ := f.MinLength(packer.version)
|
||||
f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version)))
|
||||
return []*wire.StreamFrame{f}
|
||||
})
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
|
||||
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
p, err = packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).To(BeNil())
|
||||
})
|
||||
|
||||
It("splits a STREAM frame larger than the maximum size", func() {
|
||||
f := &wire.StreamFrame{
|
||||
It("packs multiple small STREAM frames into single packet", func() {
|
||||
f1 := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Offset: 1,
|
||||
Data: []byte("frame 1"),
|
||||
DataLenPresent: true,
|
||||
}
|
||||
minLength, _ := f.MinLength(packer.version)
|
||||
f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header
|
||||
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
|
||||
f2 := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte("frame 2"),
|
||||
DataLenPresent: true,
|
||||
}
|
||||
f3 := &wire.StreamFrame{
|
||||
StreamID: 3,
|
||||
Data: []byte("frame 3"),
|
||||
DataLenPresent: true,
|
||||
}
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f1, f2, f3})
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(payloadFrames).To(HaveLen(1))
|
||||
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(payloadFrames).To(HaveLen(1))
|
||||
Expect(p.frames).To(HaveLen(3))
|
||||
Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1")))
|
||||
Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2")))
|
||||
Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3")))
|
||||
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
|
||||
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
|
||||
Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
})
|
||||
|
||||
It("refuses to send unencrypted stream data on a data stream", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
// don't expect a call to mockStreamFramer.PopStreamFrames
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 3,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(p).To(BeNil())
|
||||
})
|
||||
|
||||
It("sends non forward-secure data as the client", func() {
|
||||
packer.perspective = protocol.PerspectiveClient
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f})
|
||||
packer.perspective = protocol.PerspectiveClient
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
|
||||
Expect(p.frames[0]).To(Equal(f))
|
||||
Expect(p.frames).To(Equal([]wire.Frame{f}))
|
||||
})
|
||||
|
||||
It("does not send non forward-secure data as the server", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
// don't expect a call to mockStreamFramer.PopStreamFrames
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).To(BeNil())
|
||||
})
|
||||
|
||||
It("sends unencrypted stream data on the crypto stream", func() {
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
|
||||
cryptoStream.dataForWriting = []byte("foobar")
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: packer.version.CryptoStreamID(),
|
||||
Data: []byte("foobar"),
|
||||
}))
|
||||
}
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
|
||||
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(Equal([]wire.Frame{f}))
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
|
||||
})
|
||||
|
||||
It("sends encrypted stream data on the crypto stream", func() {
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure
|
||||
cryptoStream.dataForWriting = []byte("foobar")
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: packer.version.CryptoStreamID(),
|
||||
Data: []byte("foobar"),
|
||||
}))
|
||||
}
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
|
||||
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(Equal([]wire.Frame{f}))
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
|
||||
})
|
||||
|
||||
It("does not pack stream frames if not allowed", func() {
|
||||
It("does not pack STREAM frames if not allowed", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
// don't expect a call to mockStreamFramer.PopStreamFrames
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
|
||||
packer.QueueControlFrame(&wire.AckFrame{})
|
||||
streamFramer.AddFrameForRetransmission(&wire.StreamFrame{StreamID: 3, Data: []byte("foobar")})
|
||||
ack := &wire.AckFrame{LargestAcked: 10}
|
||||
packer.QueueControlFrame(ack)
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(func() { _ = p.frames[0].(*wire.AckFrame) }).NotTo(Panic())
|
||||
Expect(p.frames).To(Equal([]wire.Frame{ack}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("BLOCKED frames", func() {
|
||||
It("queues a BLOCKED frame", func() {
|
||||
length := 100
|
||||
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: bytes.Repeat([]byte{'f'}, length),
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
_, err := packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packer.controlFrames[0]).To(Equal(&wire.StreamBlockedFrame{StreamID: 5}))
|
||||
})
|
||||
|
||||
It("removes the dataLen attribute from the last StreamFrame, even if it queued a BLOCKED frame", func() {
|
||||
length := 100
|
||||
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: bytes.Repeat([]byte{'f'}, length),
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
p, err := packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).To(HaveLen(1))
|
||||
Expect(p[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
|
||||
})
|
||||
|
||||
It("packs a connection-level BlockedFrame", func() {
|
||||
streamFramer.blockedFrameQueue = []wire.Frame{&wire.BlockedFrame{}}
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
streamFramer.AddFrameForRetransmission(f)
|
||||
_, err := packer.composeNextPacket(maxFrameSize, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packer.controlFrames[0]).To(Equal(&wire.BlockedFrame{}))
|
||||
})
|
||||
})
|
||||
|
||||
It("returns nil if we only have a single STOP_WAITING", func() {
|
||||
packer.QueueControlFrame(&wire.StopWaitingFrame{})
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(p).To(BeNil())
|
||||
})
|
||||
|
||||
It("packs a single ACK", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
ack := &wire.AckFrame{LargestAcked: 42}
|
||||
packer.QueueControlFrame(ack)
|
||||
p, err := packer.PackPacket()
|
||||
|
@ -685,6 +588,8 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("does not return nil if we only have a single ACK but request it to be sent", func() {
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData()
|
||||
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
|
||||
ack := &wire.AckFrame{}
|
||||
packer.QueueControlFrame(ack)
|
||||
p, err := packer.PackPacket()
|
||||
|
@ -692,15 +597,6 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(p).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("queues a control frame to be sent in the next packet", func() {
|
||||
msd := &wire.MaxStreamDataFrame{StreamID: 5}
|
||||
packer.QueueControlFrame(msd)
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0]).To(Equal(msd))
|
||||
})
|
||||
|
||||
Context("retransmitting of handshake packets", func() {
|
||||
swf := &wire.StopWaitingFrame{LeastUnacked: 1}
|
||||
sf := &wire.StreamFrame{
|
||||
|
@ -719,8 +615,19 @@ var _ = Describe("Packet packer", func() {
|
|||
}
|
||||
p, err := packer.PackHandshakeRetransmission(packet)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(ContainElement(sf))
|
||||
Expect(p.frames).To(ContainElement(swf))
|
||||
Expect(p.frames).To(Equal([]wire.Frame{swf, sf}))
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
|
||||
})
|
||||
|
||||
It("doesn't add a STOP_WAITING frame for IETF QUIC", func() {
|
||||
packer.version = versionIETFFrames
|
||||
packet := &ackhandler.Packet{
|
||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
||||
Frames: []wire.Frame{sf},
|
||||
}
|
||||
p, err := packer.PackHandshakeRetransmission(packet)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(Equal([]wire.Frame{sf}))
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
|
||||
})
|
||||
|
||||
|
@ -733,8 +640,7 @@ var _ = Describe("Packet packer", func() {
|
|||
}
|
||||
p, err := packer.PackHandshakeRetransmission(packet)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(ContainElement(sf))
|
||||
Expect(p.frames).To(ContainElement(swf))
|
||||
Expect(p.frames).To(Equal([]wire.Frame{swf, sf}))
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
|
||||
// a packet sent by the server with initial encryption contains the SHLO
|
||||
// it needs to have a diversification nonce
|
||||
|
@ -768,11 +674,16 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
|
||||
It("pads Initial packets to the required minimum packet size", func() {
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: packer.version.CryptoStreamID(),
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
|
||||
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
|
||||
packer.version = protocol.VersionTLS
|
||||
packer.hasSentPacket = false
|
||||
packer.perspective = protocol.PerspectiveClient
|
||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
|
||||
cryptoStream.dataForWriting = []byte("foobar")
|
||||
packet, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize))
|
||||
|
@ -795,7 +706,7 @@ var _ = Describe("Packet packer", func() {
|
|||
_, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{
|
||||
EncryptionLevel: protocol.EncryptionSecure,
|
||||
})
|
||||
Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame"))
|
||||
Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame"))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -807,7 +718,7 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(p.frames).To(Equal([]wire.Frame{&wire.AckFrame{DelayTime: math.MaxInt64}}))
|
||||
})
|
||||
|
||||
It("packs ACK packets with SWFs", func() {
|
||||
It("packs ACK packets with STOP_WAITING frames", func() {
|
||||
packer.QueueControlFrame(&wire.AckFrame{})
|
||||
packer.QueueControlFrame(&wire.StopWaitingFrame{})
|
||||
p, err := packer.PackAckPacket()
|
||||
|
|
|
@ -107,21 +107,32 @@ func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wir
|
|||
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
|
||||
}
|
||||
case 0x6:
|
||||
// TODO(#964): remove STOP_WAITING frames
|
||||
// TODO(#878): implement the MAX_STREAM_ID frame
|
||||
frame, err = wire.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version)
|
||||
frame, err = wire.ParseMaxStreamIDFrame(r, u.version)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidStopWaitingData, err.Error())
|
||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
case 0x7:
|
||||
frame, err = wire.ParsePingFrame(r, u.version)
|
||||
case 0x8:
|
||||
frame, err = wire.ParseBlockedFrame(r, u.version)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
|
||||
}
|
||||
case 0x9:
|
||||
frame, err = wire.ParseStreamBlockedFrame(r, u.version)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
|
||||
}
|
||||
case 0xa:
|
||||
frame, err = wire.ParseStreamIDBlockedFrame(r, u.version)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
case 0xc:
|
||||
frame, err = wire.ParseStopSendingFrame(r, u.version)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
case 0xe:
|
||||
frame, err = wire.ParseAckFrame(r, u.version)
|
||||
if err != nil {
|
||||
|
|
|
@ -102,7 +102,7 @@ var _ = Describe("Packet unpacker", func() {
|
|||
f := &wire.RstStreamFrame{
|
||||
StreamID: 0xdeadbeef,
|
||||
ByteOffset: 0xdecafbad11223344,
|
||||
ErrorCode: 0x13371234,
|
||||
ErrorCode: 0x1337,
|
||||
}
|
||||
err := f.Write(buf, versionGQUICFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -342,8 +342,19 @@ var _ = Describe("Packet unpacker", func() {
|
|||
Expect(packet.frames).To(Equal([]wire.Frame{f}))
|
||||
})
|
||||
|
||||
It("unpacks MAX_STREAM_ID frames", func() {
|
||||
f := &wire.MaxStreamIDFrame{StreamID: 0x1337}
|
||||
buf := &bytes.Buffer{}
|
||||
err := f.Write(buf, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
setData(buf.Bytes())
|
||||
packet, err := unpacker.Unpack(hdrBin, hdr, data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packet.frames).To(Equal([]wire.Frame{f}))
|
||||
})
|
||||
|
||||
It("unpacks connection-level BLOCKED frames", func() {
|
||||
f := &wire.BlockedFrame{}
|
||||
f := &wire.BlockedFrame{Offset: 0x1234}
|
||||
buf := &bytes.Buffer{}
|
||||
err := f.Write(buf, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -354,7 +365,32 @@ var _ = Describe("Packet unpacker", func() {
|
|||
})
|
||||
|
||||
It("unpacks stream-level BLOCKED frames", func() {
|
||||
f := &wire.StreamBlockedFrame{StreamID: 0xdeadbeef}
|
||||
f := &wire.StreamBlockedFrame{
|
||||
StreamID: 0xdeadbeef,
|
||||
Offset: 0xdead,
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
err := f.Write(buf, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
setData(buf.Bytes())
|
||||
packet, err := unpacker.Unpack(hdrBin, hdr, data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packet.frames).To(Equal([]wire.Frame{f}))
|
||||
})
|
||||
|
||||
It("unpacks STREAM_ID_BLOCKED frames", func() {
|
||||
f := &wire.StreamIDBlockedFrame{StreamID: 0x1234567}
|
||||
buf := &bytes.Buffer{}
|
||||
err := f.Write(buf, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
setData(buf.Bytes())
|
||||
packet, err := unpacker.Unpack(hdrBin, hdr, data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packet.frames).To(Equal([]wire.Frame{f}))
|
||||
})
|
||||
|
||||
It("unpacks STOP_SENDING frames", func() {
|
||||
f := &wire.StopSendingFrame{StreamID: 0x42}
|
||||
buf := &bytes.Buffer{}
|
||||
err := f.Write(buf, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -392,9 +428,13 @@ var _ = Describe("Packet unpacker", func() {
|
|||
0x02: qerr.InvalidConnectionCloseData,
|
||||
0x04: qerr.InvalidWindowUpdateData,
|
||||
0x05: qerr.InvalidWindowUpdateData,
|
||||
0x06: qerr.InvalidFrameData,
|
||||
0x08: qerr.InvalidBlockedData,
|
||||
0x09: qerr.InvalidBlockedData,
|
||||
0x0a: qerr.InvalidFrameData,
|
||||
0x0c: qerr.InvalidFrameData,
|
||||
0x0e: qerr.InvalidAckData,
|
||||
0x10: qerr.InvalidStreamData,
|
||||
0xe: qerr.InvalidAckData,
|
||||
} {
|
||||
setData([]byte{b})
|
||||
_, err := unpacker.Unpack(hdrBin, hdr, data)
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT
|
||||
// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT.
|
||||
|
||||
package qerr
|
||||
|
||||
import "fmt"
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
_ErrorCode_name_0 = "InternalErrorStreamDataAfterTerminationInvalidPacketHeaderInvalidFrameDataInvalidFecDataInvalidRstStreamDataInvalidConnectionCloseDataInvalidGoawayDataInvalidAckDataInvalidVersionNegotiationPacketInvalidPublicRstPacketDecryptionFailureEncryptionFailurePacketTooLarge"
|
||||
|
@ -19,7 +19,6 @@ var (
|
|||
_ErrorCode_index_2 = [...]uint16{0, 15, 37, 57, 75, 96, 112, 127, 147, 167, 191, 226, 250, 279, 309, 340, 366, 385, 410, 425, 445, 457, 475, 505, 530, 547}
|
||||
_ErrorCode_index_3 = [...]uint16{0, 14, 29, 50, 65, 90, 119, 158, 184, 208, 231, 249, 279, 301, 322, 340, 366, 390, 425}
|
||||
_ErrorCode_index_4 = [...]uint16{0, 16, 45, 78, 97, 114, 144, 169, 192, 215, 238, 256, 276, 292, 308, 346, 379, 410, 448, 459, 477, 498, 532}
|
||||
_ErrorCode_index_5 = [...]uint8{0, 34}
|
||||
)
|
||||
|
||||
func (i ErrorCode) String() string {
|
||||
|
@ -42,6 +41,6 @@ func (i ErrorCode) String() string {
|
|||
case i == 97:
|
||||
return _ErrorCode_name_5
|
||||
default:
|
||||
return fmt.Sprintf("ErrorCode(%d)", i)
|
||||
return "ErrorCode(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,284 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type receiveStreamI interface {
|
||||
ReceiveStream
|
||||
|
||||
handleStreamFrame(*wire.StreamFrame) error
|
||||
handleRstStreamFrame(*wire.RstStreamFrame) error
|
||||
closeForShutdown(error)
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
}
|
||||
|
||||
type receiveStream struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
streamID protocol.StreamID
|
||||
|
||||
sender streamSender
|
||||
|
||||
frameQueue *streamFrameSorter
|
||||
readPosInFrame int
|
||||
readOffset protocol.ByteCount
|
||||
|
||||
closeForShutdownErr error
|
||||
cancelReadErr error
|
||||
resetRemotelyErr StreamError
|
||||
|
||||
closedForShutdown bool // set when CloseForShutdown() is called
|
||||
finRead bool // set once we read a frame with a FinBit
|
||||
canceledRead bool // set when CancelRead() is called
|
||||
resetRemotely bool // set when HandleRstStreamFrame() is called
|
||||
|
||||
readChan chan struct{}
|
||||
readDeadline time.Time
|
||||
|
||||
flowController flowcontrol.StreamFlowController
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
var _ ReceiveStream = &receiveStream{}
|
||||
var _ receiveStreamI = &receiveStream{}
|
||||
|
||||
func newReceiveStream(
|
||||
streamID protocol.StreamID,
|
||||
sender streamSender,
|
||||
flowController flowcontrol.StreamFlowController,
|
||||
) *receiveStream {
|
||||
return &receiveStream{
|
||||
streamID: streamID,
|
||||
sender: sender,
|
||||
flowController: flowController,
|
||||
frameQueue: newStreamFrameSorter(),
|
||||
readChan: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *receiveStream) StreamID() protocol.StreamID {
|
||||
return s.streamID
|
||||
}
|
||||
|
||||
// Read implements io.Reader. It is not thread safe!
|
||||
func (s *receiveStream) Read(p []byte) (int, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.finRead {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if s.canceledRead {
|
||||
return 0, s.cancelReadErr
|
||||
}
|
||||
if s.resetRemotely {
|
||||
return 0, s.resetRemotelyErr
|
||||
}
|
||||
if s.closedForShutdown {
|
||||
return 0, s.closeForShutdownErr
|
||||
}
|
||||
|
||||
bytesRead := 0
|
||||
for bytesRead < len(p) {
|
||||
frame := s.frameQueue.Head()
|
||||
if frame == nil && bytesRead > 0 {
|
||||
return bytesRead, s.closeForShutdownErr
|
||||
}
|
||||
|
||||
for {
|
||||
// Stop waiting on errors
|
||||
if s.closedForShutdown {
|
||||
return bytesRead, s.closeForShutdownErr
|
||||
}
|
||||
if s.canceledRead {
|
||||
return bytesRead, s.cancelReadErr
|
||||
}
|
||||
if s.resetRemotely {
|
||||
return bytesRead, s.resetRemotelyErr
|
||||
}
|
||||
|
||||
deadline := s.readDeadline
|
||||
if !deadline.IsZero() && !time.Now().Before(deadline) {
|
||||
return bytesRead, errDeadline
|
||||
}
|
||||
|
||||
if frame != nil {
|
||||
s.readPosInFrame = int(s.readOffset - frame.Offset)
|
||||
break
|
||||
}
|
||||
|
||||
s.mutex.Unlock()
|
||||
if deadline.IsZero() {
|
||||
<-s.readChan
|
||||
} else {
|
||||
select {
|
||||
case <-s.readChan:
|
||||
case <-time.After(deadline.Sub(time.Now())):
|
||||
}
|
||||
}
|
||||
s.mutex.Lock()
|
||||
frame = s.frameQueue.Head()
|
||||
}
|
||||
|
||||
if bytesRead > len(p) {
|
||||
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
|
||||
}
|
||||
if s.readPosInFrame > int(frame.DataLen()) {
|
||||
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen())
|
||||
}
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
copy(p[bytesRead:], frame.Data[s.readPosInFrame:])
|
||||
m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame)
|
||||
s.readPosInFrame += m
|
||||
bytesRead += m
|
||||
s.readOffset += protocol.ByteCount(m)
|
||||
|
||||
s.mutex.Lock()
|
||||
// when a RST_STREAM was received, the was already informed about the final byteOffset for this stream
|
||||
if !s.resetRemotely {
|
||||
s.flowController.AddBytesRead(protocol.ByteCount(m))
|
||||
}
|
||||
// this call triggers the flow controller to increase the flow control window, if necessary
|
||||
if s.flowController.HasWindowUpdate() {
|
||||
s.sender.onHasWindowUpdate(s.streamID)
|
||||
}
|
||||
|
||||
if s.readPosInFrame >= int(frame.DataLen()) {
|
||||
s.frameQueue.Pop()
|
||||
s.finRead = frame.FinBit
|
||||
if frame.FinBit {
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
return bytesRead, io.EOF
|
||||
}
|
||||
}
|
||||
}
|
||||
return bytesRead, nil
|
||||
}
|
||||
|
||||
func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.finRead {
|
||||
return nil
|
||||
}
|
||||
if s.canceledRead {
|
||||
return nil
|
||||
}
|
||||
s.canceledRead = true
|
||||
s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode)
|
||||
s.signalRead()
|
||||
if s.version.UsesIETFFrameFormat() {
|
||||
s.sender.queueControlFrame(&wire.StopSendingFrame{
|
||||
StreamID: s.streamID,
|
||||
ErrorCode: errorCode,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
|
||||
maxOffset := frame.Offset + frame.DataLen()
|
||||
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData {
|
||||
return err
|
||||
}
|
||||
s.signalRead()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.closedForShutdown {
|
||||
return nil
|
||||
}
|
||||
if err := s.flowController.UpdateHighestReceived(frame.ByteOffset, true); err != nil {
|
||||
return err
|
||||
}
|
||||
// In gQUIC, error code 0 has a special meaning.
|
||||
// The peer will reliably continue transmitting, but is not interested in reading from the stream.
|
||||
// We should therefore just continue reading from the stream, until we encounter the FIN bit.
|
||||
if !s.version.UsesIETFFrameFormat() && frame.ErrorCode == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ignore duplicate RST_STREAM frames for this stream (after checking their final offset)
|
||||
if s.resetRemotely {
|
||||
return nil
|
||||
}
|
||||
s.resetRemotely = true
|
||||
s.resetRemotelyErr = streamCanceledError{
|
||||
errorCode: frame.ErrorCode,
|
||||
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
|
||||
}
|
||||
s.signalRead()
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *receiveStream) CloseRemote(offset protocol.ByteCount) {
|
||||
s.handleStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset})
|
||||
}
|
||||
|
||||
func (s *receiveStream) onClose(offset protocol.ByteCount) {
|
||||
if s.canceledRead && !s.version.UsesIETFFrameFormat() {
|
||||
s.sender.queueControlFrame(&wire.RstStreamFrame{
|
||||
StreamID: s.streamID,
|
||||
ByteOffset: offset,
|
||||
ErrorCode: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *receiveStream) SetReadDeadline(t time.Time) error {
|
||||
s.mutex.Lock()
|
||||
oldDeadline := s.readDeadline
|
||||
s.readDeadline = t
|
||||
s.mutex.Unlock()
|
||||
// if the new deadline is before the currently set deadline, wake up Read()
|
||||
if t.Before(oldDeadline) {
|
||||
s.signalRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseForShutdown closes a stream abruptly.
|
||||
// It makes Read unblock (and return the error) immediately.
|
||||
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
|
||||
func (s *receiveStream) closeForShutdown(err error) {
|
||||
s.mutex.Lock()
|
||||
s.closedForShutdown = true
|
||||
s.closeForShutdownErr = err
|
||||
s.mutex.Unlock()
|
||||
s.signalRead()
|
||||
}
|
||||
|
||||
func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
|
||||
return s.flowController.GetWindowUpdate()
|
||||
}
|
||||
|
||||
// signalRead performs a non-blocking send on the readChan
|
||||
func (s *receiveStream) signalRead() {
|
||||
select {
|
||||
case s.readChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
|
@ -0,0 +1,648 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gbytes"
|
||||
)
|
||||
|
||||
var _ = Describe("Receive Stream", func() {
|
||||
const streamID protocol.StreamID = 1337
|
||||
|
||||
var (
|
||||
str *receiveStream
|
||||
strWithTimeout io.Reader // str wrapped with gbytes.TimeoutReader
|
||||
mockFC *mocks.MockStreamFlowController
|
||||
mockSender *MockStreamSender
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
mockSender = NewMockStreamSender(mockCtrl)
|
||||
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
|
||||
str = newReceiveStream(streamID, mockSender, mockFC)
|
||||
|
||||
timeout := scaleDuration(250 * time.Millisecond)
|
||||
strWithTimeout = gbytes.TimeoutReader(str, timeout)
|
||||
})
|
||||
|
||||
It("gets stream id", func() {
|
||||
Expect(str.StreamID()).To(Equal(protocol.StreamID(1337)))
|
||||
})
|
||||
|
||||
Context("reading", func() {
|
||||
It("reads a single STREAM frame", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
frame := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
}
|
||||
err := str.handleStreamFrame(&frame)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
|
||||
})
|
||||
|
||||
It("reads a single STREAM frame in multiple goes", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
frame := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
}
|
||||
err := str.handleStreamFrame(&frame)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 2)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(2))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD}))
|
||||
n, err = strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(2))
|
||||
Expect(b).To(Equal([]byte{0xBE, 0xEF}))
|
||||
})
|
||||
|
||||
It("reads all data available", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
frame1 := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD},
|
||||
}
|
||||
frame2 := wire.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte{0xBE, 0xEF},
|
||||
}
|
||||
err := str.handleStreamFrame(&frame1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleStreamFrame(&frame2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 6)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00}))
|
||||
})
|
||||
|
||||
It("assembles multiple STREAM frames", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
frame1 := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD},
|
||||
}
|
||||
frame2 := wire.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte{0xBE, 0xEF},
|
||||
}
|
||||
err := str.handleStreamFrame(&frame1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleStreamFrame(&frame2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
|
||||
})
|
||||
|
||||
It("waits until data is available", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err := str.handleStreamFrame(&frame)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
b := make([]byte, 2)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(2))
|
||||
})
|
||||
|
||||
It("handles STREAM frames in wrong order", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
frame1 := wire.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte{0xBE, 0xEF},
|
||||
}
|
||||
frame2 := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD},
|
||||
}
|
||||
err := str.handleStreamFrame(&frame1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleStreamFrame(&frame2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
|
||||
})
|
||||
|
||||
It("ignores duplicate STREAM frames", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
frame1 := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD},
|
||||
}
|
||||
frame2 := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0x13, 0x37},
|
||||
}
|
||||
frame3 := wire.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte{0xBE, 0xEF},
|
||||
}
|
||||
err := str.handleStreamFrame(&frame1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleStreamFrame(&frame2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleStreamFrame(&frame3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
|
||||
})
|
||||
|
||||
It("doesn't rejects a STREAM frames with an overlapping data range", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
frame1 := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte("foob"),
|
||||
}
|
||||
frame2 := wire.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte("obar"),
|
||||
}
|
||||
err := str.handleStreamFrame(&frame1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleStreamFrame(&frame2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 6)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("passes on errors from the streamFrameSorter", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), false)
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{StreamID: streamID}) // STREAM frame without data
|
||||
Expect(err).To(MatchError(errEmptyStreamData))
|
||||
})
|
||||
|
||||
It("calls the onHasWindowUpdate callback, when the a MAX_STREAM_DATA should be sent", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
|
||||
mockFC.EXPECT().HasWindowUpdate().Return(true)
|
||||
mockSender.EXPECT().onHasWindowUpdate(streamID)
|
||||
frame1 := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
err := str.handleStreamFrame(&frame1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 6)
|
||||
_, err = strWithTimeout.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("deadlines", func() {
|
||||
It("the deadline error has the right net.Error properties", func() {
|
||||
Expect(errDeadline.Temporary()).To(BeTrue())
|
||||
Expect(errDeadline.Timeout()).To(BeTrue())
|
||||
Expect(errDeadline).To(MatchError("deadline exceeded"))
|
||||
})
|
||||
|
||||
It("returns an error when Read is called after the deadline", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes()
|
||||
f := &wire.StreamFrame{Data: []byte("foobar")}
|
||||
err := str.handleStreamFrame(f)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.SetReadDeadline(time.Now().Add(-time.Second))
|
||||
b := make([]byte, 6)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
|
||||
It("unblocks after the deadline", func() {
|
||||
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
str.SetReadDeadline(deadline)
|
||||
b := make([]byte, 6)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(10*time.Millisecond)))
|
||||
})
|
||||
|
||||
It("doesn't unblock if the deadline is changed before the first one expires", func() {
|
||||
deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond))
|
||||
str.SetReadDeadline(deadline1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond))
|
||||
str.SetReadDeadline(deadline2)
|
||||
// make sure that this was actually execute before the deadline expires
|
||||
Expect(time.Now()).To(BeTemporally("<", deadline1))
|
||||
}()
|
||||
runtime.Gosched()
|
||||
b := make([]byte, 10)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond)))
|
||||
})
|
||||
|
||||
It("unblocks earlier, when a new deadline is set", func() {
|
||||
deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond))
|
||||
deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
time.Sleep(scaleDuration(10 * time.Millisecond))
|
||||
str.SetReadDeadline(deadline2)
|
||||
// make sure that this was actually execute before the deadline expires
|
||||
Expect(time.Now()).To(BeTemporally("<", deadline2))
|
||||
}()
|
||||
str.SetReadDeadline(deadline1)
|
||||
runtime.Gosched()
|
||||
b := make([]byte, 10)
|
||||
_, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
Context("with FIN bit", func() {
|
||||
It("returns EOFs", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
str.handleStreamFrame(&wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
FinBit: true,
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
|
||||
n, err = strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
|
||||
It("handles out-of-order frames", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
frame1 := wire.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte{0xBE, 0xEF},
|
||||
FinBit: true,
|
||||
}
|
||||
frame2 := wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD},
|
||||
}
|
||||
err := str.handleStreamFrame(&frame1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleStreamFrame(&frame2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
|
||||
n, err = strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
|
||||
It("returns EOFs with partial read", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xde, 0xad},
|
||||
FinBit: true,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(n).To(Equal(2))
|
||||
Expect(b[:n]).To(Equal([]byte{0xde, 0xad}))
|
||||
})
|
||||
|
||||
It("handles immediate FINs", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{
|
||||
Offset: 0,
|
||||
FinBit: true,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
})
|
||||
|
||||
It("closes when CloseRemote is called", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
str.CloseRemote(0)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 8)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
})
|
||||
|
||||
Context("closing for shutdown", func() {
|
||||
testErr := errors.New("test error")
|
||||
|
||||
It("immediately returns all reads", func() {
|
||||
done := make(chan struct{})
|
||||
b := make([]byte, 4)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
str.closeForShutdown(testErr)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors for all following reads", func() {
|
||||
str.closeForShutdown(testErr)
|
||||
b := make([]byte, 1)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("stream cancelations", func() {
|
||||
Context("canceling read", func() {
|
||||
It("unblocks Read", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := strWithTimeout.Read([]byte{0})
|
||||
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234"))
|
||||
close(done)
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
err := str.CancelRead(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't allow further calls to Read", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
err := str.CancelRead(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = strWithTimeout.Read([]byte{0})
|
||||
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234"))
|
||||
})
|
||||
|
||||
It("does nothing when CancelRead is called twice", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
err := str.CancelRead(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.CancelRead(2345)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = strWithTimeout.Read([]byte{0})
|
||||
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234"))
|
||||
})
|
||||
|
||||
It("doesn't send a RST_STREAM frame, if the FIN was already read", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
// no calls to mockSender.queueControlFrame
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{
|
||||
StreamID: streamID,
|
||||
Data: []byte("foobar"),
|
||||
FinBit: true,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
_, err = strWithTimeout.Read(make([]byte, 100))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
err = str.CancelRead(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("queues a STOP_SENDING frame, for IETF QUIC", func() {
|
||||
mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 1234,
|
||||
})
|
||||
err := str.CancelRead(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("doesn't queue a STOP_SENDING frame, for gQUIC", func() {
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
Context("receiving RST_STREAM frames", func() {
|
||||
rst := &wire.RstStreamFrame{
|
||||
StreamID: streamID,
|
||||
ByteOffset: 42,
|
||||
ErrorCode: 1234,
|
||||
}
|
||||
|
||||
It("unblocks Read", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := strWithTimeout.Read([]byte{0})
|
||||
Expect(err).To(MatchError("Stream 1337 was reset with error code 1234"))
|
||||
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
|
||||
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
|
||||
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234)))
|
||||
close(done)
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleRstStreamFrame(rst)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't allow further calls to Read", func() {
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = strWithTimeout.Read([]byte{0})
|
||||
Expect(err).To(MatchError("Stream 1337 was reset with error code 1234"))
|
||||
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
|
||||
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
|
||||
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234)))
|
||||
})
|
||||
|
||||
It("errors when receiving a RST_STREAM with an inconsistent offset", func() {
|
||||
testErr := errors.New("already received a different final offset before")
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Return(testErr)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("ignores duplicate RST_STREAM frames", func() {
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("doesn't do anyting when it was closed for shutdown", func() {
|
||||
str.closeForShutdown(nil)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("for gQUIC", func() {
|
||||
BeforeEach(func() {
|
||||
str.version = versionGQUICFrames
|
||||
})
|
||||
|
||||
It("unblocks Read when receiving a RST_STREAM frame with non-zero error code", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
|
||||
readReturned := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := strWithTimeout.Read([]byte{0})
|
||||
Expect(err).To(MatchError("Stream 1337 was reset with error code 1234"))
|
||||
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
|
||||
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
|
||||
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234)))
|
||||
close(readReturned)
|
||||
}()
|
||||
Consistently(readReturned).ShouldNot(BeClosed())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(readReturned).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2)
|
||||
gomock.InOrder(
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)),
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
|
||||
mockSender.EXPECT().onStreamCompleted(streamID),
|
||||
)
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
readReturned := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
n, err := strWithTimeout.Read(make([]byte, 4))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
n, err = strWithTimeout.Read(make([]byte, 4))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(n).To(Equal(2))
|
||||
close(readReturned)
|
||||
}()
|
||||
Consistently(readReturned).ShouldNot(BeClosed())
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{
|
||||
StreamID: streamID,
|
||||
Data: []byte("foobar"),
|
||||
FinBit: true,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleRstStreamFrame(&wire.RstStreamFrame{
|
||||
StreamID: streamID,
|
||||
ByteOffset: 6,
|
||||
ErrorCode: 0,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(readReturned).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("flow control", func() {
|
||||
It("errors when a STREAM frame causes a flow control violation", func() {
|
||||
testErr := errors.New("flow control violation")
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr)
|
||||
frame := wire.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
err := str.handleStreamFrame(&frame)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("gets a window update", func() {
|
||||
mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x100))
|
||||
Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100)))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,313 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type sendStreamI interface {
|
||||
SendStream
|
||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
closeForShutdown(error)
|
||||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||
}
|
||||
|
||||
type sendStream struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
streamID protocol.StreamID
|
||||
sender streamSender
|
||||
|
||||
writeOffset protocol.ByteCount
|
||||
|
||||
cancelWriteErr error
|
||||
closeForShutdownErr error
|
||||
|
||||
closedForShutdown bool // set when CloseForShutdown() is called
|
||||
finishedWriting bool // set once Close() is called
|
||||
canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
|
||||
finSent bool // set when a STREAM_FRAME with FIN bit has b
|
||||
|
||||
dataForWriting []byte
|
||||
writeChan chan struct{}
|
||||
writeDeadline time.Time
|
||||
|
||||
flowController flowcontrol.StreamFlowController
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
var _ SendStream = &sendStream{}
|
||||
var _ sendStreamI = &sendStream{}
|
||||
|
||||
func newSendStream(
|
||||
streamID protocol.StreamID,
|
||||
sender streamSender,
|
||||
flowController flowcontrol.StreamFlowController,
|
||||
version protocol.VersionNumber,
|
||||
) *sendStream {
|
||||
s := &sendStream{
|
||||
streamID: streamID,
|
||||
sender: sender,
|
||||
flowController: flowController,
|
||||
writeChan: make(chan struct{}, 1),
|
||||
version: version,
|
||||
}
|
||||
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *sendStream) StreamID() protocol.StreamID {
|
||||
return s.streamID // same for receiveStream and sendStream
|
||||
}
|
||||
|
||||
func (s *sendStream) Write(p []byte) (int, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.finishedWriting {
|
||||
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
|
||||
}
|
||||
if s.canceledWrite {
|
||||
return 0, s.cancelWriteErr
|
||||
}
|
||||
if s.closeForShutdownErr != nil {
|
||||
return 0, s.closeForShutdownErr
|
||||
}
|
||||
if !s.writeDeadline.IsZero() && !time.Now().Before(s.writeDeadline) {
|
||||
return 0, errDeadline
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
s.dataForWriting = make([]byte, len(p))
|
||||
copy(s.dataForWriting, p)
|
||||
s.sender.onHasStreamData(s.streamID)
|
||||
|
||||
var bytesWritten int
|
||||
var err error
|
||||
for {
|
||||
bytesWritten = len(p) - len(s.dataForWriting)
|
||||
deadline := s.writeDeadline
|
||||
if !deadline.IsZero() && !time.Now().Before(deadline) {
|
||||
s.dataForWriting = nil
|
||||
err = errDeadline
|
||||
break
|
||||
}
|
||||
if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
|
||||
break
|
||||
}
|
||||
|
||||
s.mutex.Unlock()
|
||||
if deadline.IsZero() {
|
||||
<-s.writeChan
|
||||
} else {
|
||||
select {
|
||||
case <-s.writeChan:
|
||||
case <-time.After(deadline.Sub(time.Now())):
|
||||
}
|
||||
}
|
||||
s.mutex.Lock()
|
||||
}
|
||||
|
||||
if s.closeForShutdownErr != nil {
|
||||
err = s.closeForShutdownErr
|
||||
} else if s.cancelWriteErr != nil {
|
||||
err = s.cancelWriteErr
|
||||
}
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
|
||||
// maxBytes is the maximum length this frame (including frame header) will have.
|
||||
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.closeForShutdownErr != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
frame := &wire.StreamFrame{
|
||||
StreamID: s.streamID,
|
||||
Offset: s.writeOffset,
|
||||
DataLenPresent: true,
|
||||
}
|
||||
frameLen := frame.MinLength(s.version)
|
||||
if frameLen >= maxBytes { // a STREAM frame must have at least one byte of data
|
||||
return nil, s.dataForWriting != nil
|
||||
}
|
||||
frame.Data, frame.FinBit = s.getDataForWriting(maxBytes - frameLen)
|
||||
if len(frame.Data) == 0 && !frame.FinBit {
|
||||
// this can happen if:
|
||||
// - popStreamFrame is called but there's no data for writing
|
||||
// - there's data for writing, but the stream is stream-level flow control blocked
|
||||
// - there's data for writing, but the stream is connection-level flow control blocked
|
||||
if s.dataForWriting == nil {
|
||||
return nil, false
|
||||
}
|
||||
isBlocked, _ := s.flowController.IsBlocked()
|
||||
return nil, !isBlocked
|
||||
}
|
||||
if frame.FinBit {
|
||||
s.finSent = true
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
} else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream
|
||||
if isBlocked, offset := s.flowController.IsBlocked(); isBlocked {
|
||||
s.sender.queueControlFrame(&wire.StreamBlockedFrame{
|
||||
StreamID: s.streamID,
|
||||
Offset: offset,
|
||||
})
|
||||
return frame, false
|
||||
}
|
||||
}
|
||||
return frame, s.dataForWriting != nil
|
||||
}
|
||||
|
||||
func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
|
||||
if s.dataForWriting == nil {
|
||||
return nil, s.finishedWriting && !s.finSent
|
||||
}
|
||||
|
||||
// TODO(#657): Flow control for the crypto stream
|
||||
if s.streamID != s.version.CryptoStreamID() {
|
||||
maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize())
|
||||
}
|
||||
if maxBytes == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var ret []byte
|
||||
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
|
||||
ret = s.dataForWriting[:maxBytes]
|
||||
s.dataForWriting = s.dataForWriting[maxBytes:]
|
||||
} else {
|
||||
ret = s.dataForWriting
|
||||
s.dataForWriting = nil
|
||||
s.signalWrite()
|
||||
}
|
||||
s.writeOffset += protocol.ByteCount(len(ret))
|
||||
s.flowController.AddBytesSent(protocol.ByteCount(len(ret)))
|
||||
return ret, s.finishedWriting && s.dataForWriting == nil && !s.finSent
|
||||
}
|
||||
|
||||
func (s *sendStream) Close() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.canceledWrite {
|
||||
return fmt.Errorf("Close called for canceled stream %d", s.streamID)
|
||||
}
|
||||
s.finishedWriting = true
|
||||
s.sender.onHasStreamData(s.streamID) // need to send the FIN
|
||||
s.ctxCancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
return s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode))
|
||||
}
|
||||
|
||||
// must be called after locking the mutex
|
||||
func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) error {
|
||||
if s.canceledWrite {
|
||||
return nil
|
||||
}
|
||||
if s.finishedWriting {
|
||||
return fmt.Errorf("CancelWrite for closed stream %d", s.streamID)
|
||||
}
|
||||
s.canceledWrite = true
|
||||
s.cancelWriteErr = writeErr
|
||||
s.signalWrite()
|
||||
s.sender.queueControlFrame(&wire.RstStreamFrame{
|
||||
StreamID: s.streamID,
|
||||
ByteOffset: s.writeOffset,
|
||||
ErrorCode: errorCode,
|
||||
})
|
||||
// TODO(#991): cancel retransmissions for this stream
|
||||
s.ctxCancel()
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
s.handleStopSendingFrameImpl(frame)
|
||||
}
|
||||
|
||||
func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) {
|
||||
s.flowController.UpdateSendWindow(frame.ByteOffset)
|
||||
s.mutex.Lock()
|
||||
if s.dataForWriting != nil {
|
||||
s.sender.onHasStreamData(s.streamID)
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
// must be called after locking the mutex
|
||||
func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) {
|
||||
writeErr := streamCanceledError{
|
||||
errorCode: frame.ErrorCode,
|
||||
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
|
||||
}
|
||||
errorCode := errorCodeStopping
|
||||
if !s.version.UsesIETFFrameFormat() {
|
||||
errorCode = errorCodeStoppingGQUIC
|
||||
}
|
||||
s.cancelWriteImpl(errorCode, writeErr)
|
||||
}
|
||||
|
||||
func (s *sendStream) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
func (s *sendStream) SetWriteDeadline(t time.Time) error {
|
||||
s.mutex.Lock()
|
||||
oldDeadline := s.writeDeadline
|
||||
s.writeDeadline = t
|
||||
s.mutex.Unlock()
|
||||
if t.Before(oldDeadline) {
|
||||
s.signalWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseForShutdown closes a stream abruptly.
|
||||
// It makes Write unblock (and return the error) immediately.
|
||||
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
|
||||
func (s *sendStream) closeForShutdown(err error) {
|
||||
s.mutex.Lock()
|
||||
s.closedForShutdown = true
|
||||
s.closeForShutdownErr = err
|
||||
s.mutex.Unlock()
|
||||
s.signalWrite()
|
||||
s.ctxCancel()
|
||||
}
|
||||
|
||||
func (s *sendStream) getWriteOffset() protocol.ByteCount {
|
||||
return s.writeOffset
|
||||
}
|
||||
|
||||
// signalWrite performs a non-blocking send on the writeChan
|
||||
func (s *sendStream) signalWrite() {
|
||||
select {
|
||||
case s.writeChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
|
@ -0,0 +1,636 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gbytes"
|
||||
)
|
||||
|
||||
var _ = Describe("Send Stream", func() {
|
||||
const streamID protocol.StreamID = 1337
|
||||
|
||||
var (
|
||||
str *sendStream
|
||||
strWithTimeout io.Writer // str wrapped with gbytes.TimeoutWriter
|
||||
mockFC *mocks.MockStreamFlowController
|
||||
mockSender *MockStreamSender
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
mockSender = NewMockStreamSender(mockCtrl)
|
||||
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
|
||||
str = newSendStream(streamID, mockSender, mockFC, protocol.VersionWhatever)
|
||||
|
||||
timeout := scaleDuration(250 * time.Millisecond)
|
||||
strWithTimeout = gbytes.TimeoutWriter(str, timeout)
|
||||
})
|
||||
|
||||
waitForWrite := func() {
|
||||
EventuallyWithOffset(0, func() []byte {
|
||||
str.mutex.Lock()
|
||||
data := str.dataForWriting
|
||||
str.mutex.Unlock()
|
||||
return data
|
||||
}).ShouldNot(BeEmpty())
|
||||
}
|
||||
|
||||
It("gets stream id", func() {
|
||||
Expect(str.StreamID()).To(Equal(protocol.StreamID(1337)))
|
||||
})
|
||||
|
||||
Context("writing", func() {
|
||||
It("writes and gets all data at once", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
|
||||
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
|
||||
mockFC.EXPECT().IsBlocked()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
n, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
f, _ := str.popStreamFrame(1000)
|
||||
Expect(f.Data).To(Equal([]byte("foobar")))
|
||||
Expect(f.FinBit).To(BeFalse())
|
||||
Expect(f.Offset).To(BeZero())
|
||||
Expect(f.DataLenPresent).To(BeTrue())
|
||||
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
|
||||
Expect(str.dataForWriting).To(BeNil())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("writes and gets data in two turns", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
frameHeaderLen := protocol.ByteCount(4)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any() /* protocol.ByteCount(3)*/).Times(2)
|
||||
mockFC.EXPECT().IsBlocked().Times(2)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
n, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
f, _ := str.popStreamFrame(3 + frameHeaderLen)
|
||||
Expect(f.Data).To(Equal([]byte("foo")))
|
||||
Expect(f.FinBit).To(BeFalse())
|
||||
Expect(f.Offset).To(BeZero())
|
||||
Expect(f.DataLenPresent).To(BeTrue())
|
||||
f, _ = str.popStreamFrame(100)
|
||||
Expect(f.Data).To(Equal([]byte("bar")))
|
||||
Expect(f.FinBit).To(BeFalse())
|
||||
Expect(f.Offset).To(Equal(protocol.ByteCount(3)))
|
||||
Expect(f.DataLenPresent).To(BeTrue())
|
||||
Expect(str.popStreamFrame(1000)).To(BeNil())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("popStreamFrame returns nil if no data is available", func() {
|
||||
frame, hasMoreData := str.popStreamFrame(1000)
|
||||
Expect(frame).To(BeNil())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
})
|
||||
|
||||
It("says if it has more data for writing", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2)
|
||||
mockFC.EXPECT().IsBlocked().Times(2)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(100))
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
frame, hasMoreData := str.popStreamFrame(50)
|
||||
Expect(hasMoreData).To(BeTrue())
|
||||
frame, hasMoreData = str.popStreamFrame(1000)
|
||||
Expect(frame).ToNot(BeNil())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
frame, _ = str.popStreamFrame(1000)
|
||||
Expect(frame).To(BeNil())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("copies the slice while writing", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
frameHeaderSize := protocol.ByteCount(4)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
|
||||
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1))
|
||||
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2))
|
||||
mockFC.EXPECT().IsBlocked().Times(2)
|
||||
s := []byte("foo")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
n, err := strWithTimeout.Write(s)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
frame, _ := str.popStreamFrame(frameHeaderSize + 1)
|
||||
Expect(frame.Data).To(Equal([]byte("f")))
|
||||
s[1] = 'e'
|
||||
f, _ := str.popStreamFrame(100)
|
||||
Expect(f).ToNot(BeNil())
|
||||
Expect(f.Data).To(Equal([]byte("oo")))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns when given a nil input", func() {
|
||||
n, err := strWithTimeout.Write(nil)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns when given an empty slice", func() {
|
||||
n, err := strWithTimeout.Write([]byte(""))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("cancels the context when Close is called", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
Expect(str.Context().Done()).ToNot(BeClosed())
|
||||
str.Close()
|
||||
Expect(str.Context().Done()).To(BeClosed())
|
||||
})
|
||||
|
||||
Context("flow control blocking", func() {
|
||||
It("returns nil when it is blocked", func() {
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0))
|
||||
mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10))
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
f, hasMoreData := str.popStreamFrame(1000)
|
||||
Expect(f).To(BeNil())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
// make the Write go routine return
|
||||
str.closeForShutdown(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("queues a BLOCKED frame if the stream is flow control blocked", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().queueControlFrame(&wire.StreamBlockedFrame{
|
||||
StreamID: streamID,
|
||||
Offset: 10,
|
||||
})
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
|
||||
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
|
||||
// don't use offset 6 here, to make sure the BLOCKED frame contains the number returned by the flow controller
|
||||
mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10))
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
f, hasMoreData := str.popStreamFrame(1000)
|
||||
Expect(f).ToNot(BeNil())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("says that it doesn't have any more data, when it is flow control blocked", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10))
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write(bytes.Repeat([]byte{0}, 100))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
f, hasMoreData := str.popStreamFrame(50)
|
||||
Expect(f).ToNot(BeNil())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
// make the Write go routine return
|
||||
str.closeForShutdown(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
|
||||
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
|
||||
// don't EXPECT a call to mockFC.IsBlocked
|
||||
// don't EXPECT a call to mockSender.queueControlFrame
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
Expect(str.Close()).To(Succeed())
|
||||
f, hasMoreData := str.popStreamFrame(1000)
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
Expect(f).ToNot(BeNil())
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("deadlines", func() {
|
||||
It("returns an error when Write is called after the deadline", func() {
|
||||
str.SetWriteDeadline(time.Now().Add(-time.Second))
|
||||
n, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
|
||||
It("unblocks after the deadline", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
str.SetWriteDeadline(deadline)
|
||||
n, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond)))
|
||||
})
|
||||
|
||||
It("returns the number of bytes written, when the deadline expires", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes()
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
mockFC.EXPECT().IsBlocked()
|
||||
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
str.SetWriteDeadline(deadline)
|
||||
var n int
|
||||
writeReturned := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond)))
|
||||
close(writeReturned)
|
||||
}()
|
||||
waitForWrite()
|
||||
frame, hasMoreData := str.popStreamFrame(50)
|
||||
Expect(frame).ToNot(BeNil())
|
||||
Expect(hasMoreData).To(BeTrue())
|
||||
Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed())
|
||||
Expect(n).To(BeEquivalentTo(frame.DataLen()))
|
||||
})
|
||||
|
||||
It("doesn't pop any data after the deadline expired", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes()
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
mockFC.EXPECT().IsBlocked()
|
||||
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
str.SetWriteDeadline(deadline)
|
||||
writeReturned := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
close(writeReturned)
|
||||
}()
|
||||
waitForWrite()
|
||||
frame, hasMoreData := str.popStreamFrame(50)
|
||||
Expect(frame).ToNot(BeNil())
|
||||
Expect(hasMoreData).To(BeTrue())
|
||||
Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed())
|
||||
frame, hasMoreData = str.popStreamFrame(50)
|
||||
Expect(frame).To(BeNil())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
})
|
||||
|
||||
It("doesn't unblock if the deadline is changed before the first one expires", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond))
|
||||
str.SetWriteDeadline(deadline1)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond))
|
||||
str.SetWriteDeadline(deadline2)
|
||||
// make sure that this was actually execute before the deadline expires
|
||||
Expect(time.Now()).To(BeTemporally("<", deadline1))
|
||||
close(done)
|
||||
}()
|
||||
runtime.Gosched()
|
||||
n, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond)))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("unblocks earlier, when a new deadline is set", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond))
|
||||
deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond))
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
time.Sleep(scaleDuration(10 * time.Millisecond))
|
||||
str.SetWriteDeadline(deadline2)
|
||||
// make sure that this was actually execute before the deadline expires
|
||||
Expect(time.Now()).To(BeTemporally("<", deadline2))
|
||||
close(done)
|
||||
}()
|
||||
str.SetWriteDeadline(deadline1)
|
||||
runtime.Gosched()
|
||||
_, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError(errDeadline))
|
||||
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond)))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
It("doesn't allow writes after it has been closed", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
str.Close()
|
||||
_, err := strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError("write on closed stream 1337"))
|
||||
})
|
||||
|
||||
It("allows FIN", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.Close()
|
||||
f, hasMoreData := str.popStreamFrame(1000)
|
||||
Expect(f).ToNot(BeNil())
|
||||
Expect(f.Data).To(BeEmpty())
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
})
|
||||
|
||||
It("doesn't send a FIN when there's still data", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
frameHeaderLen := protocol.ByteCount(4)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2)
|
||||
mockFC.EXPECT().IsBlocked()
|
||||
str.dataForWriting = []byte("foobar")
|
||||
Expect(str.Close()).To(Succeed())
|
||||
f, _ := str.popStreamFrame(3 + frameHeaderLen)
|
||||
Expect(f).ToNot(BeNil())
|
||||
Expect(f.Data).To(Equal([]byte("foo")))
|
||||
Expect(f.FinBit).To(BeFalse())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
f, _ = str.popStreamFrame(100)
|
||||
Expect(f.Data).To(Equal([]byte("bar")))
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't allow FIN after it is closed for shutdown", func() {
|
||||
str.closeForShutdown(errors.New("test"))
|
||||
f, hasMoreData := str.popStreamFrame(1000)
|
||||
Expect(f).To(BeNil())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
})
|
||||
|
||||
It("doesn't allow FIN twice", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.Close()
|
||||
f, _ := str.popStreamFrame(1000)
|
||||
Expect(f).ToNot(BeNil())
|
||||
Expect(f.Data).To(BeEmpty())
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
f, hasMoreData := str.popStreamFrame(1000)
|
||||
Expect(f).To(BeNil())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("closing for shutdown", func() {
|
||||
testErr := errors.New("test")
|
||||
|
||||
It("returns errors when the stream is cancelled", func() {
|
||||
str.closeForShutdown(testErr)
|
||||
n, err := strWithTimeout.Write([]byte("foo"))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("doesn't get data for writing if an error occurred", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
mockFC.EXPECT().IsBlocked()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 500))
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
frame, hasMoreData := str.popStreamFrame(50) // get a STREAM frame containing some data, but not all
|
||||
Expect(frame).ToNot(BeNil())
|
||||
Expect(hasMoreData).To(BeTrue())
|
||||
str.closeForShutdown(testErr)
|
||||
frame, hasMoreData = str.popStreamFrame(1000)
|
||||
Expect(frame).To(BeNil())
|
||||
Expect(hasMoreData).To(BeFalse())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("cancels the context", func() {
|
||||
Expect(str.Context().Done()).ToNot(BeClosed())
|
||||
str.closeForShutdown(testErr)
|
||||
Expect(str.Context().Done()).To(BeClosed())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("handling MAX_STREAM_DATA frames", func() {
|
||||
It("informs the flow controller", func() {
|
||||
mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337))
|
||||
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
|
||||
StreamID: streamID,
|
||||
ByteOffset: 0x1337,
|
||||
})
|
||||
})
|
||||
|
||||
It("says when it has data for sending", func() {
|
||||
mockFC.EXPECT().UpdateSendWindow(gomock.Any())
|
||||
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for Write, once for the MAX_STREAM_DATA frame
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
|
||||
StreamID: streamID,
|
||||
ByteOffset: 42,
|
||||
})
|
||||
// make sure the Write go routine returns
|
||||
str.closeForShutdown(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("stream cancelations", func() {
|
||||
Context("canceling writing", func() {
|
||||
It("queues a RST_STREAM frame", func() {
|
||||
mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{
|
||||
StreamID: streamID,
|
||||
ByteOffset: 1234,
|
||||
ErrorCode: 9876,
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.writeOffset = 1234
|
||||
err := str.CancelWrite(9876)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("unblocks Write", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
mockFC.EXPECT().IsBlocked()
|
||||
writeReturned := make(chan struct{})
|
||||
var n int
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
|
||||
Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234"))
|
||||
close(writeReturned)
|
||||
}()
|
||||
waitForWrite()
|
||||
frame, _ := str.popStreamFrame(50)
|
||||
Expect(frame).ToNot(BeNil())
|
||||
err := str.CancelWrite(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(writeReturned).Should(BeClosed())
|
||||
Expect(n).To(BeEquivalentTo(frame.DataLen()))
|
||||
})
|
||||
|
||||
It("cancels the context", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
Expect(str.Context().Done()).ToNot(BeClosed())
|
||||
str.CancelWrite(1234)
|
||||
Expect(str.Context().Done()).To(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't allow further calls to Write", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
err := str.CancelWrite(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = strWithTimeout.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234"))
|
||||
})
|
||||
|
||||
It("only cancels once", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
err := str.CancelWrite(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.CancelWrite(4321)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("doesn't cancel when the stream was already closed", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
err := str.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.CancelWrite(123)
|
||||
Expect(err).To(MatchError("CancelWrite for closed stream 1337"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("receiving STOP_SENDING frames", func() {
|
||||
It("queues a RST_STREAM frames with error code Stopping", func() {
|
||||
mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: errorCodeStopping,
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 101,
|
||||
})
|
||||
})
|
||||
|
||||
It("unblocks Write", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError("Stream 1337 was reset with error code 123"))
|
||||
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
|
||||
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
|
||||
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123)))
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 123,
|
||||
})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't allow further calls to Write", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 123,
|
||||
})
|
||||
_, err := str.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError("Stream 1337 was reset with error code 123"))
|
||||
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
|
||||
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
|
||||
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123)))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -19,8 +19,8 @@ import (
|
|||
// packetHandler handles packets
|
||||
type packetHandler interface {
|
||||
Session
|
||||
getCryptoStream() cryptoStream
|
||||
handshakeStatus() <-chan handshakeEvent
|
||||
getCryptoStream() cryptoStreamI
|
||||
handshakeStatus() <-chan error
|
||||
handlePacket(*receivedPacket)
|
||||
GetVersion() protocol.VersionNumber
|
||||
run() error
|
||||
|
@ -40,15 +40,17 @@ type server struct {
|
|||
certChain crypto.CertChain
|
||||
scfg *handshake.ServerConfig
|
||||
|
||||
sessions map[protocol.ConnectionID]packetHandler
|
||||
sessionsMutex sync.RWMutex
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
sessions map[protocol.ConnectionID]packetHandler
|
||||
closed bool
|
||||
|
||||
serverError error
|
||||
sessionQueue chan Session
|
||||
errorChan chan struct{}
|
||||
|
||||
// set as members, so they can be set in the tests
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error)
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
}
|
||||
|
||||
var _ Listener = &server{}
|
||||
|
@ -240,6 +242,12 @@ func (s *server) Accept() (Session, error) {
|
|||
// Close the server
|
||||
func (s *server) Close() error {
|
||||
s.sessionsMutex.Lock()
|
||||
if s.closed {
|
||||
s.sessionsMutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, session := range s.sessions {
|
||||
if session != nil {
|
||||
|
@ -254,10 +262,9 @@ func (s *server) Close() error {
|
|||
s.sessionsMutex.Unlock()
|
||||
wg.Wait()
|
||||
|
||||
if s.conn == nil {
|
||||
return nil
|
||||
}
|
||||
return s.conn.Close()
|
||||
err := s.conn.Close()
|
||||
<-s.errorChan // wait for serve() to return
|
||||
return err
|
||||
}
|
||||
|
||||
// Addr returns the server's network address
|
||||
|
@ -384,15 +391,9 @@ func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.C
|
|||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
ev := <-session.handshakeStatus()
|
||||
if ev.err != nil {
|
||||
if err := <-session.handshakeStatus(); err != nil {
|
||||
return
|
||||
}
|
||||
if ev.encLevel == protocol.EncryptionForwardSecure {
|
||||
break
|
||||
}
|
||||
}
|
||||
s.sessionQueue <- session
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -28,8 +28,7 @@ type mockSession struct {
|
|||
closeReason error
|
||||
closedRemote bool
|
||||
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
|
||||
handshakeChan chan handshakeEvent
|
||||
handshakeComplete chan error // for WaitUntilHandshakeComplete
|
||||
handshakeChan chan error
|
||||
}
|
||||
|
||||
func (s *mockSession) handlePacket(*receivedPacket) {
|
||||
|
@ -40,9 +39,6 @@ func (s *mockSession) run() error {
|
|||
<-s.stopRunLoop
|
||||
return s.closeReason
|
||||
}
|
||||
func (s *mockSession) WaitUntilHandshakeComplete() error {
|
||||
return <-s.handshakeComplete
|
||||
}
|
||||
func (s *mockSession) Close(e error) error {
|
||||
if s.closed {
|
||||
return nil
|
||||
|
@ -59,19 +55,19 @@ func (s *mockSession) closeRemote(e error) {
|
|||
close(s.stopRunLoop)
|
||||
}
|
||||
func (s *mockSession) OpenStream() (Stream, error) {
|
||||
return &stream{streamID: 1337}, nil
|
||||
return &stream{}, nil
|
||||
}
|
||||
func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") }
|
||||
func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") }
|
||||
func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") }
|
||||
func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") }
|
||||
func (*mockSession) Context() context.Context { panic("not implemented") }
|
||||
func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") }
|
||||
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
|
||||
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan }
|
||||
func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") }
|
||||
func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
|
||||
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
|
||||
|
||||
var _ Session = &mockSession{}
|
||||
var _ NonFWSession = &mockSession{}
|
||||
|
||||
func newMockSession(
|
||||
_ connection,
|
||||
|
@ -83,8 +79,7 @@ func newMockSession(
|
|||
) (packetHandler, error) {
|
||||
s := mockSession{
|
||||
connectionID: connectionID,
|
||||
handshakeChan: make(chan handshakeEvent),
|
||||
handshakeComplete: make(chan error),
|
||||
handshakeChan: make(chan error),
|
||||
stopRunLoop: make(chan struct{}),
|
||||
}
|
||||
return &s, nil
|
||||
|
@ -155,9 +150,8 @@ var _ = Describe("Server", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[connID].(*mockSession)
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
Consistently(func() Session { return acceptedSess }).Should(BeNil())
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionForwardSecure}
|
||||
close(sess.handshakeChan)
|
||||
Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
|
||||
close(done)
|
||||
}, 0.5)
|
||||
|
@ -173,7 +167,7 @@ var _ = Describe("Server", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[connID].(*mockSession)
|
||||
sess.handshakeChan <- handshakeEvent{err: errors.New("handshake failed")}
|
||||
sess.handshakeChan <- errors.New("handshake failed")
|
||||
Consistently(func() bool { return accepted }).Should(BeFalse())
|
||||
close(done)
|
||||
})
|
||||
|
@ -222,6 +216,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("closes sessions and the connection when Close is called", func() {
|
||||
go serv.serve()
|
||||
session, _ := newMockSession(nil, 0, 0, nil, nil, nil)
|
||||
serv.sessions[1] = session
|
||||
err := serv.Close()
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type nullAEAD struct {
|
||||
|
@ -98,6 +99,26 @@ func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.V
|
|||
return tls, extHandler.GetPeerParams(), nil
|
||||
}
|
||||
|
||||
func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error {
|
||||
ccf := &wire.ConnectionCloseFrame{
|
||||
ErrorCode: qerr.HandshakeFailed,
|
||||
ReasonPhrase: closeErr.Error(),
|
||||
}
|
||||
replyHdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
ConnectionID: clientHdr.ConnectionID, // echo the client's connection ID
|
||||
PacketNumber: 1, // random packet number
|
||||
Version: clientHdr.Version,
|
||||
}
|
||||
data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.conn.WriteTo(data, remoteAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) {
|
||||
if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize {
|
||||
return nil, errors.New("dropping too small Initial packet")
|
||||
|
@ -110,19 +131,30 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
|
|||
}
|
||||
|
||||
// unpack packet and check stream frame contents
|
||||
version := hdr.Version
|
||||
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version)
|
||||
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, hdr.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame, err := unpackInitialPacket(aead, hdr, data, version)
|
||||
frame, err := unpackInitialPacket(aead, hdr, data, hdr.Version)
|
||||
if err != nil {
|
||||
utils.Debugf("Error unpacking initial packet: %s", err)
|
||||
return nil, nil
|
||||
}
|
||||
sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
|
||||
if err != nil {
|
||||
if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil {
|
||||
utils.Debugf("Error sending CONNECTION_CLOSE: ", ccerr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) {
|
||||
version := hdr.Version
|
||||
bc := handshake.NewCryptoStreamConn(remoteAddr)
|
||||
bc.AddDataForReading(frame.Data)
|
||||
tls, paramsChan, err := s.newMintConn(bc, hdr.Version)
|
||||
tls, paramsChan, err := s.newMintConn(bc, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -176,7 +208,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
|
|||
return nil, err
|
||||
}
|
||||
cs := sess.getCryptoStream()
|
||||
cs.SetReadOffset(frame.DataLen())
|
||||
cs.setReadOffset(frame.DataLen())
|
||||
bc.SetStream(cs)
|
||||
return sess, nil
|
||||
}
|
||||
|
|
|
@ -4,15 +4,16 @@ import (
|
|||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -65,6 +66,18 @@ var _ = Describe("Stateless TLS handling", func() {
|
|||
return hdr, data
|
||||
}
|
||||
|
||||
unpackPacket := func(data []byte) (*wire.Header, []byte) {
|
||||
r := bytes.NewReader(conn.dataWritten.Bytes())
|
||||
hdr, err := wire.ParseHeaderSentByServer(r, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr.Raw = data[:len(data)-r.Len()]
|
||||
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.ConnectionID, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
payload, err := aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return hdr, payload
|
||||
}
|
||||
|
||||
It("sends a version negotiation packet if it doesn't support the version", func() {
|
||||
server.HandleInitial(nil, &wire.Header{Version: 0x1337}, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize))
|
||||
Expect(conn.dataWritten.Len()).ToNot(BeZero())
|
||||
|
@ -124,4 +137,20 @@ var _ = Describe("Stateless TLS handling", func() {
|
|||
Eventually(sessionChan).Should(Receive())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends a CONNECTION_CLOSE, if mint returns an error", func() {
|
||||
mintTLS.EXPECT().Handshake().Return(mint.AlertAccessDenied)
|
||||
extHandler.EXPECT().GetPeerParams()
|
||||
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
|
||||
server.HandleInitial(nil, hdr, data)
|
||||
// the Handshake packet is written by the session
|
||||
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
|
||||
// unpack the packet to check that it actually contains a CONNECTION_CLOSE
|
||||
hdr, data = unpackPacket(conn.dataWritten.Bytes())
|
||||
Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake))
|
||||
ccf, err := wire.ParseConnectionCloseFrame(bytes.NewReader(data), protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ccf.ErrorCode).To(Equal(qerr.HandshakeFailed))
|
||||
Expect(ccf.ReasonPhrase).To(Equal(mint.AlertAccessDenied.String()))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -24,6 +23,23 @@ type unpacker interface {
|
|||
Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error)
|
||||
}
|
||||
|
||||
type streamGetter interface {
|
||||
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
|
||||
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
|
||||
}
|
||||
|
||||
type streamManager interface {
|
||||
GetOrOpenStream(protocol.StreamID) (streamI, error)
|
||||
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
|
||||
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
|
||||
OpenStream() (Stream, error)
|
||||
OpenStreamSync() (Stream, error)
|
||||
AcceptStream() (Stream, error)
|
||||
DeleteStream(protocol.StreamID) error
|
||||
UpdateLimits(*handshake.TransportParameters)
|
||||
CloseWithError(error)
|
||||
}
|
||||
|
||||
type receivedPacket struct {
|
||||
remoteAddr net.Addr
|
||||
header *wire.Header
|
||||
|
@ -36,11 +52,6 @@ var (
|
|||
newCryptoSetupClient = handshake.NewCryptoSetupClient
|
||||
)
|
||||
|
||||
type handshakeEvent struct {
|
||||
encLevel protocol.EncryptionLevel
|
||||
err error
|
||||
}
|
||||
|
||||
type closeError struct {
|
||||
err error
|
||||
remote bool
|
||||
|
@ -55,15 +66,15 @@ type session struct {
|
|||
|
||||
conn connection
|
||||
|
||||
streamsMap *streamsMap
|
||||
cryptoStream cryptoStream
|
||||
streamsMap streamManager
|
||||
cryptoStream cryptoStreamI
|
||||
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
sentPacketHandler ackhandler.SentPacketHandler
|
||||
receivedPacketHandler ackhandler.ReceivedPacketHandler
|
||||
streamFramer *streamFramer
|
||||
|
||||
windowUpdateQueue *windowUpdateQueue
|
||||
connFlowController flowcontrol.ConnectionFlowController
|
||||
|
||||
unpacker unpacker
|
||||
|
@ -87,17 +98,14 @@ type session struct {
|
|||
|
||||
// this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them
|
||||
paramsChan <-chan handshake.TransportParameters
|
||||
// this channel is passed to the CryptoSetup and receives the current encryption level
|
||||
// it is closed as soon as the handshake is complete
|
||||
aeadChanged <-chan protocol.EncryptionLevel
|
||||
// the handshakeEvent channel is passed to the CryptoSetup.
|
||||
// It receives when it makes sense to try decrypting undecryptable packets.
|
||||
handshakeEvent <-chan struct{}
|
||||
// handshakeChan is returned by handshakeStatus.
|
||||
// It receives any error that might occur during the handshake.
|
||||
// It is closed when the handshake is complete.
|
||||
handshakeChan chan error
|
||||
handshakeComplete bool
|
||||
// will be closed as soon as the handshake completes, and receive any error that might occur until then
|
||||
// it is used to block WaitUntilHandshakeComplete()
|
||||
handshakeCompleteChan chan error
|
||||
// handshakeChan receives handshake events and is closed as soon the handshake completes
|
||||
// the receiving end of this channel is passed to the creator of the session
|
||||
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error
|
||||
handshakeChan chan handshakeEvent
|
||||
|
||||
lastRcvdPacketNumber protocol.PacketNumber
|
||||
// Used to calculate the next packet number from the truncated wire
|
||||
|
@ -116,6 +124,7 @@ type session struct {
|
|||
}
|
||||
|
||||
var _ Session = &session{}
|
||||
var _ streamSender = &session{}
|
||||
|
||||
// newSession makes a new session
|
||||
func newSession(
|
||||
|
@ -127,14 +136,14 @@ func newSession(
|
|||
config *Config,
|
||||
) (packetHandler, error) {
|
||||
paramsChan := make(chan handshake.TransportParameters)
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
version: v,
|
||||
config: config,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
paramsChan: paramsChan,
|
||||
}
|
||||
s.preSetup()
|
||||
|
@ -154,7 +163,7 @@ func newSession(
|
|||
s.config.Versions,
|
||||
s.config.AcceptCookie,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -175,14 +184,14 @@ var newClientSession = func(
|
|||
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
|
||||
) (packetHandler, error) {
|
||||
paramsChan := make(chan handshake.TransportParameters)
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
config: config,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
paramsChan: paramsChan,
|
||||
}
|
||||
s.preSetup()
|
||||
|
@ -201,7 +210,7 @@ var newClientSession = func(
|
|||
tlsConf,
|
||||
transportParams,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
initialVersion,
|
||||
negotiatedVersions,
|
||||
)
|
||||
|
@ -223,21 +232,21 @@ func newTLSServerSession(
|
|||
peerParams *handshake.TransportParameters,
|
||||
v protocol.VersionNumber,
|
||||
) (packetHandler, error) {
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
config: config,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
version: v,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}
|
||||
s.preSetup()
|
||||
s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
|
||||
tls,
|
||||
cryptoStreamConn,
|
||||
nullAEAD,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
v,
|
||||
)
|
||||
if err := s.postSetup(initialPacketNumber); err != nil {
|
||||
|
@ -260,14 +269,14 @@ var newTLSClientSession = func(
|
|||
paramsChan <-chan handshake.TransportParameters,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
) (packetHandler, error) {
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
config: config,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
paramsChan: paramsChan,
|
||||
}
|
||||
s.preSetup()
|
||||
|
@ -276,7 +285,7 @@ var newTLSClientSession = func(
|
|||
s.cryptoStream,
|
||||
s.connectionID,
|
||||
hostname,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
tls,
|
||||
v,
|
||||
)
|
||||
|
@ -294,12 +303,11 @@ func (s *session) preSetup() {
|
|||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||
s.rttStats,
|
||||
)
|
||||
s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream)
|
||||
s.cryptoStream = s.newCryptoStream()
|
||||
}
|
||||
|
||||
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
|
||||
s.handshakeChan = make(chan handshakeEvent, 3)
|
||||
s.handshakeCompleteChan = make(chan error, 1)
|
||||
s.handshakeChan = make(chan error, 1)
|
||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||
s.closeChan = make(chan closeError, 1)
|
||||
s.sendingScheduled = make(chan struct{}, 1)
|
||||
|
@ -314,9 +322,12 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
|
|||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
|
||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
|
||||
|
||||
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
|
||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController, s.version)
|
||||
|
||||
if s.version.UsesTLS() {
|
||||
s.streamsMap = newStreamsMap(s.newStream, s.perspective)
|
||||
} else {
|
||||
s.streamsMap = newStreamsMapLegacy(s.newStream, s.perspective)
|
||||
}
|
||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
|
||||
s.packer = newPacketPacker(s.connectionID,
|
||||
initialPacketNumber,
|
||||
s.cryptoSetup,
|
||||
|
@ -324,6 +335,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
|
|||
s.perspective,
|
||||
s.version,
|
||||
)
|
||||
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame)
|
||||
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
|
||||
return nil
|
||||
}
|
||||
|
@ -339,7 +351,7 @@ func (s *session) run() error {
|
|||
}()
|
||||
|
||||
var closeErr closeError
|
||||
aeadChanged := s.aeadChanged
|
||||
handshakeEvent := s.handshakeEvent
|
||||
|
||||
runLoop:
|
||||
for {
|
||||
|
@ -377,16 +389,20 @@ runLoop:
|
|||
putPacketBuffer(p.header.Raw)
|
||||
case p := <-s.paramsChan:
|
||||
s.processTransportParameters(&p)
|
||||
case l, ok := <-aeadChanged:
|
||||
case _, ok := <-handshakeEvent:
|
||||
if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
|
||||
s.handshakeComplete = true
|
||||
aeadChanged = nil // prevent this case from ever being selected again
|
||||
handshakeEvent = nil // prevent this case from ever being selected again
|
||||
s.sentPacketHandler.SetHandshakeComplete()
|
||||
if !s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient {
|
||||
// In gQUIC, there's no equivalent to the Finished message in TLS
|
||||
// The server knows that the handshake is complete when it receives the first forward-secure packet sent by the client.
|
||||
// We need to make sure that the client actually sends such a packet.
|
||||
s.packer.QueueControlFrame(&wire.PingFrame{})
|
||||
}
|
||||
close(s.handshakeChan)
|
||||
close(s.handshakeCompleteChan)
|
||||
} else {
|
||||
s.tryDecryptingQueuedPackets()
|
||||
s.handshakeChan <- handshakeEvent{encLevel: l}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -403,9 +419,26 @@ runLoop:
|
|||
s.keepAlivePingSent = true
|
||||
}
|
||||
|
||||
if err := s.sendPacket(); err != nil {
|
||||
sendingAllowed := s.sentPacketHandler.SendingAllowed()
|
||||
if !sendingAllowed { // if congestion limited, at least try sending an ACK frame
|
||||
if err := s.maybeSendAckOnlyPacket(); err != nil {
|
||||
s.closeLocal(err)
|
||||
}
|
||||
} else {
|
||||
// repeatedly try sending until we don't have any more data, or run out of the congestion window
|
||||
for sendingAllowed {
|
||||
sentPacket, err := s.sendPacket()
|
||||
if err != nil {
|
||||
s.closeLocal(err)
|
||||
break
|
||||
}
|
||||
if !sentPacket {
|
||||
break
|
||||
}
|
||||
sendingAllowed = s.sentPacketHandler.SendingAllowed()
|
||||
}
|
||||
}
|
||||
|
||||
if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 {
|
||||
s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
|
||||
}
|
||||
|
@ -415,17 +448,12 @@ runLoop:
|
|||
if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout {
|
||||
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
|
||||
}
|
||||
|
||||
if err := s.streamsMap.DeleteClosedStreams(); err != nil {
|
||||
s.closeLocal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// only send the error the handshakeChan when the handshake is not completed yet
|
||||
// otherwise this chan will already be closed
|
||||
if !s.handshakeComplete {
|
||||
s.handshakeCompleteChan <- closeErr.err
|
||||
s.handshakeChan <- handshakeEvent{err: closeErr.err}
|
||||
s.handshakeChan <- closeErr.err
|
||||
}
|
||||
s.handleCloseError(closeErr)
|
||||
return closeErr.err
|
||||
|
@ -435,6 +463,10 @@ func (s *session) Context() context.Context {
|
|||
return s.ctx
|
||||
}
|
||||
|
||||
func (s *session) ConnectionState() ConnectionState {
|
||||
return s.cryptoSetup.ConnectionState()
|
||||
}
|
||||
|
||||
func (s *session) maybeResetTimer() {
|
||||
var deadline time.Time
|
||||
if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent {
|
||||
|
@ -504,7 +536,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
|
|||
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
|
||||
|
||||
isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames)
|
||||
if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, isRetransmittable); err != nil {
|
||||
if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, p.rcvTime, isRetransmittable); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -524,8 +556,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
|
|||
s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
|
||||
case *wire.GoawayFrame:
|
||||
err = errors.New("unimplemented: handling GOAWAY frames")
|
||||
case *wire.StopWaitingFrame:
|
||||
s.receivedPacketHandler.IgnoreBelow(frame.LeastUnacked)
|
||||
case *wire.StopWaitingFrame: // ignore STOP_WAITINGs
|
||||
case *wire.RstStreamFrame:
|
||||
err = s.handleRstStreamFrame(frame)
|
||||
case *wire.MaxDataFrame:
|
||||
|
@ -534,6 +565,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
|
|||
err = s.handleMaxStreamDataFrame(frame)
|
||||
case *wire.BlockedFrame:
|
||||
case *wire.StreamBlockedFrame:
|
||||
case *wire.StopSendingFrame:
|
||||
err = s.handleStopSendingFrame(frame)
|
||||
case *wire.PingFrame:
|
||||
default:
|
||||
return errors.New("Session BUG: unexpected frame type")
|
||||
|
@ -563,9 +596,12 @@ func (s *session) handlePacket(p *receivedPacket) {
|
|||
|
||||
func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
|
||||
if frame.StreamID == s.version.CryptoStreamID() {
|
||||
return s.cryptoStream.AddStreamFrame(frame)
|
||||
if frame.FinBit {
|
||||
return errors.New("Received STREAM frame with FIN bit for the crypto stream")
|
||||
}
|
||||
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
|
||||
return s.cryptoStream.handleStreamFrame(frame)
|
||||
}
|
||||
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -574,7 +610,7 @@ func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
|
|||
// ignore this StreamFrame
|
||||
return nil
|
||||
}
|
||||
return str.AddStreamFrame(frame)
|
||||
return str.handleStreamFrame(frame)
|
||||
}
|
||||
|
||||
func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) {
|
||||
|
@ -582,7 +618,11 @@ func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) {
|
|||
}
|
||||
|
||||
func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error {
|
||||
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
|
||||
if frame.StreamID == s.version.CryptoStreamID() {
|
||||
s.cryptoStream.handleMaxStreamDataFrame(frame)
|
||||
return nil
|
||||
}
|
||||
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -590,12 +630,15 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error
|
|||
// stream is closed and already garbage collected
|
||||
return nil
|
||||
}
|
||||
str.UpdateSendWindow(frame.ByteOffset)
|
||||
str.handleMaxStreamDataFrame(frame)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
||||
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
|
||||
if frame.StreamID == s.version.CryptoStreamID() {
|
||||
return errors.New("Received RST_STREAM frame for the crypto stream")
|
||||
}
|
||||
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -603,11 +646,31 @@ func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
|||
// stream is closed and already garbage collected
|
||||
return nil
|
||||
}
|
||||
return str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode), frame.ByteOffset)
|
||||
return str.handleRstStreamFrame(frame)
|
||||
}
|
||||
|
||||
func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error {
|
||||
if frame.StreamID == s.version.CryptoStreamID() {
|
||||
return errors.New("Received a STOP_SENDING frame for the crypto stream")
|
||||
}
|
||||
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil {
|
||||
// stream is closed and already garbage collected
|
||||
return nil
|
||||
}
|
||||
str.handleStopSendingFrame(frame)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error {
|
||||
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime)
|
||||
if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil {
|
||||
return err
|
||||
}
|
||||
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) closeLocal(e error) {
|
||||
|
@ -647,7 +710,7 @@ func (s *session) handleCloseError(closeErr closeError) error {
|
|||
utils.Errorf("Closing session with error: %s", closeErr.err.Error())
|
||||
}
|
||||
|
||||
s.cryptoStream.Cancel(quicErr)
|
||||
s.cryptoStream.closeForShutdown(quicErr)
|
||||
s.streamsMap.CloseWithError(quicErr)
|
||||
|
||||
if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry {
|
||||
|
@ -669,42 +732,27 @@ func (s *session) handleCloseError(closeErr closeError) error {
|
|||
|
||||
func (s *session) processTransportParameters(params *handshake.TransportParameters) {
|
||||
s.peerParams = params
|
||||
s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams)
|
||||
s.streamsMap.UpdateLimits(params)
|
||||
if params.OmitConnectionID {
|
||||
s.packer.SetOmitConnectionID()
|
||||
}
|
||||
s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow)
|
||||
s.streamsMap.Range(func(str streamI) {
|
||||
str.UpdateSendWindow(params.StreamFlowControlWindow)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *session) sendPacket() error {
|
||||
s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
|
||||
|
||||
// Get MAX_DATA and MAX_STREAM_DATA frames
|
||||
// this call triggers the flow controller to increase the flow control windows, if necessary
|
||||
windowUpdates := s.getWindowUpdates()
|
||||
for _, f := range windowUpdates {
|
||||
s.packer.QueueControlFrame(f)
|
||||
// the crypto stream is the only open stream at this moment
|
||||
// so we don't need to update stream flow control windows
|
||||
}
|
||||
|
||||
func (s *session) maybeSendAckOnlyPacket() error {
|
||||
ack := s.receivedPacketHandler.GetAckFrame()
|
||||
if ack != nil {
|
||||
s.packer.QueueControlFrame(ack)
|
||||
}
|
||||
|
||||
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
|
||||
for {
|
||||
if !s.sentPacketHandler.SendingAllowed() {
|
||||
if ack == nil {
|
||||
return nil
|
||||
}
|
||||
// If we aren't allowed to send, at least try sending an ACK frame
|
||||
swf := s.sentPacketHandler.GetStopWaitingFrame(false)
|
||||
if swf != nil {
|
||||
s.packer.QueueControlFrame(ack)
|
||||
|
||||
if !s.version.UsesIETFFrameFormat() { // for gQUIC, maybe add a STOP_WAITING
|
||||
if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil {
|
||||
s.packer.QueueControlFrame(swf)
|
||||
}
|
||||
}
|
||||
packet, err := s.packer.PackAckPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -712,6 +760,22 @@ func (s *session) sendPacket() error {
|
|||
return s.sendPackedPacket(packet)
|
||||
}
|
||||
|
||||
func (s *session) sendPacket() (bool, error) {
|
||||
s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
|
||||
|
||||
if offset := s.connFlowController.GetWindowUpdate(); offset != 0 {
|
||||
s.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: offset})
|
||||
}
|
||||
if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked {
|
||||
s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset})
|
||||
}
|
||||
s.windowUpdateQueue.QueueAll()
|
||||
|
||||
ack := s.receivedPacketHandler.GetAckFrame()
|
||||
if ack != nil {
|
||||
s.packer.QueueControlFrame(ack)
|
||||
}
|
||||
|
||||
// check for retransmissions first
|
||||
for {
|
||||
retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission()
|
||||
|
@ -719,21 +783,23 @@ func (s *session) sendPacket() error {
|
|||
break
|
||||
}
|
||||
|
||||
// retransmit handshake packets
|
||||
if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure {
|
||||
if s.handshakeComplete {
|
||||
// Don't retransmit handshake packets when the handshake is complete
|
||||
continue
|
||||
}
|
||||
utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber)
|
||||
if !s.version.UsesIETFFrameFormat() {
|
||||
s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true))
|
||||
}
|
||||
packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
if err = s.sendPackedPacket(packet); err != nil {
|
||||
return err
|
||||
if err := s.sendPackedPacket(packet); err != nil {
|
||||
return false, err
|
||||
}
|
||||
} else {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// queue all retransmittable frames sent in forward-secure packets
|
||||
utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber)
|
||||
// resend the frames that were in the packet
|
||||
for _, frame := range retransmitPacket.GetFramesForRetransmission() {
|
||||
|
@ -746,34 +812,25 @@ func (s *session) sendPacket() error {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hasRetransmission := s.streamFramer.HasFramesForRetransmission()
|
||||
if ack != nil || hasRetransmission {
|
||||
swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission)
|
||||
if swf != nil {
|
||||
if !s.version.UsesIETFFrameFormat() && (ack != nil || hasRetransmission) {
|
||||
if swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission); swf != nil {
|
||||
s.packer.QueueControlFrame(swf)
|
||||
}
|
||||
}
|
||||
// add a retransmittable frame
|
||||
if s.sentPacketHandler.ShouldSendRetransmittablePacket() {
|
||||
s.packer.QueueControlFrame(&wire.PingFrame{})
|
||||
s.packer.MakeNextPacketRetransmittable()
|
||||
}
|
||||
packet, err := s.packer.PackPacket()
|
||||
if err != nil || packet == nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
if err = s.sendPackedPacket(packet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// send every window update twice
|
||||
for _, f := range windowUpdates {
|
||||
s.packer.QueueControlFrame(f)
|
||||
}
|
||||
windowUpdates = nil
|
||||
ack = nil
|
||||
if err := s.sendPackedPacket(packet); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *session) sendPackedPacket(packet *packedPacket) error {
|
||||
|
@ -824,7 +881,7 @@ func (s *session) GetOrOpenStream(id protocol.StreamID) (Stream, error) {
|
|||
return str, err
|
||||
}
|
||||
// make sure to return an actual nil value here, not an Stream with value nil
|
||||
return nil, err
|
||||
return str, err
|
||||
}
|
||||
|
||||
// AcceptStream returns the next stream openend by the peer
|
||||
|
@ -841,18 +898,6 @@ func (s *session) OpenStreamSync() (Stream, error) {
|
|||
return s.streamsMap.OpenStreamSync()
|
||||
}
|
||||
|
||||
func (s *session) WaitUntilHandshakeComplete() error {
|
||||
return <-s.handshakeCompleteChan
|
||||
}
|
||||
|
||||
func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {
|
||||
s.packer.QueueControlFrame(&wire.RstStreamFrame{
|
||||
StreamID: id,
|
||||
ByteOffset: offset,
|
||||
})
|
||||
s.scheduleSending()
|
||||
}
|
||||
|
||||
func (s *session) newStream(id protocol.StreamID) streamI {
|
||||
var initialSendWindow protocol.ByteCount
|
||||
if s.peerParams != nil {
|
||||
|
@ -867,7 +912,21 @@ func (s *session) newStream(id protocol.StreamID) streamI {
|
|||
initialSendWindow,
|
||||
s.rttStats,
|
||||
)
|
||||
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version)
|
||||
return newStream(id, s, flowController, s.version)
|
||||
}
|
||||
|
||||
func (s *session) newCryptoStream() cryptoStreamI {
|
||||
id := s.version.CryptoStreamID()
|
||||
flowController := flowcontrol.NewStreamFlowController(
|
||||
id,
|
||||
s.version.StreamContributesToConnectionFlowControl(id),
|
||||
s.connFlowController,
|
||||
protocol.ReceiveStreamFlowControlWindow,
|
||||
protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow),
|
||||
0,
|
||||
s.rttStats,
|
||||
)
|
||||
return newCryptoStream(s, flowController, s.version)
|
||||
}
|
||||
|
||||
func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {
|
||||
|
@ -908,22 +967,25 @@ func (s *session) tryDecryptingQueuedPackets() {
|
|||
s.undecryptablePackets = s.undecryptablePackets[:0]
|
||||
}
|
||||
|
||||
func (s *session) getWindowUpdates() []wire.Frame {
|
||||
var res []wire.Frame
|
||||
s.streamsMap.Range(func(str streamI) {
|
||||
if offset := str.GetWindowUpdate(); offset != 0 {
|
||||
res = append(res, &wire.MaxStreamDataFrame{
|
||||
StreamID: str.StreamID(),
|
||||
ByteOffset: offset,
|
||||
})
|
||||
func (s *session) queueControlFrame(f wire.Frame) {
|
||||
s.packer.QueueControlFrame(f)
|
||||
s.scheduleSending()
|
||||
}
|
||||
})
|
||||
if offset := s.connFlowController.GetWindowUpdate(); offset != 0 {
|
||||
res = append(res, &wire.MaxDataFrame{
|
||||
ByteOffset: offset,
|
||||
})
|
||||
|
||||
func (s *session) onHasWindowUpdate(id protocol.StreamID) {
|
||||
s.windowUpdateQueue.Add(id)
|
||||
s.scheduleSending()
|
||||
}
|
||||
|
||||
func (s *session) onHasStreamData(id protocol.StreamID) {
|
||||
s.streamFramer.AddActiveStream(id)
|
||||
s.scheduleSending()
|
||||
}
|
||||
|
||||
func (s *session) onStreamCompleted(id protocol.StreamID) {
|
||||
if err := s.streamsMap.DeleteStream(id); err != nil {
|
||||
s.Close(err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (s *session) LocalAddr() net.Addr {
|
||||
|
@ -935,11 +997,11 @@ func (s *session) RemoteAddr() net.Addr {
|
|||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *session) handshakeStatus() <-chan handshakeEvent {
|
||||
func (s *session) handshakeStatus() <-chan error {
|
||||
return s.handshakeChan
|
||||
}
|
||||
|
||||
func (s *session) getCryptoStream() cryptoStream {
|
||||
func (s *session) getCryptoStream() cryptoStreamI {
|
||||
return s.cryptoStream
|
||||
}
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,88 +1,85 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
const (
|
||||
errorCodeStopping protocol.ApplicationErrorCode = 0
|
||||
errorCodeStoppingGQUIC protocol.ApplicationErrorCode = 7
|
||||
)
|
||||
|
||||
// The streamSender is notified by the stream about various events.
|
||||
type streamSender interface {
|
||||
queueControlFrame(wire.Frame)
|
||||
onHasWindowUpdate(protocol.StreamID)
|
||||
onHasStreamData(protocol.StreamID)
|
||||
onStreamCompleted(protocol.StreamID)
|
||||
}
|
||||
|
||||
// Each of the both stream halves gets its own uniStreamSender.
|
||||
// This is necessary in order to keep track when both halves have been completed.
|
||||
type uniStreamSender struct {
|
||||
streamSender
|
||||
onStreamCompletedImpl func()
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
|
||||
s.streamSender.queueControlFrame(f)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) {
|
||||
s.streamSender.onHasWindowUpdate(id)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
|
||||
s.streamSender.onHasStreamData(id)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
|
||||
s.onStreamCompletedImpl()
|
||||
}
|
||||
|
||||
var _ streamSender = &uniStreamSender{}
|
||||
|
||||
type streamI interface {
|
||||
Stream
|
||||
|
||||
AddStreamFrame(*wire.StreamFrame) error
|
||||
RegisterRemoteError(error, protocol.ByteCount) error
|
||||
HasDataForWriting() bool
|
||||
GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool)
|
||||
GetWriteOffset() protocol.ByteCount
|
||||
Finished() bool
|
||||
Cancel(error)
|
||||
// methods needed for flow control
|
||||
GetWindowUpdate() protocol.ByteCount
|
||||
UpdateSendWindow(protocol.ByteCount)
|
||||
IsFlowControlBlocked() bool
|
||||
closeForShutdown(error)
|
||||
// for receiving
|
||||
handleStreamFrame(*wire.StreamFrame) error
|
||||
handleRstStreamFrame(*wire.RstStreamFrame) error
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
// for sending
|
||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||
}
|
||||
|
||||
type cryptoStream interface {
|
||||
streamI
|
||||
SetReadOffset(protocol.ByteCount)
|
||||
}
|
||||
var _ receiveStreamI = (streamI)(nil)
|
||||
var _ sendStreamI = (streamI)(nil)
|
||||
|
||||
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
|
||||
//
|
||||
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
|
||||
type stream struct {
|
||||
mutex sync.Mutex
|
||||
receiveStream
|
||||
sendStream
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
completedMutex sync.Mutex
|
||||
sender streamSender
|
||||
receiveStreamCompleted bool
|
||||
sendStreamCompleted bool
|
||||
|
||||
streamID protocol.StreamID
|
||||
onData func()
|
||||
// onReset is a callback that should send a RST_STREAM
|
||||
onReset func(protocol.StreamID, protocol.ByteCount)
|
||||
|
||||
readPosInFrame int
|
||||
writeOffset protocol.ByteCount
|
||||
readOffset protocol.ByteCount
|
||||
|
||||
// Once set, the errors must not be changed!
|
||||
err error
|
||||
|
||||
// cancelled is set when Cancel() is called
|
||||
cancelled utils.AtomicBool
|
||||
// finishedReading is set once we read a frame with a FinBit
|
||||
finishedReading utils.AtomicBool
|
||||
// finisedWriting is set once Close() is called
|
||||
finishedWriting utils.AtomicBool
|
||||
// resetLocally is set if Reset() is called
|
||||
resetLocally utils.AtomicBool
|
||||
// resetRemotely is set if RegisterRemoteError() is called
|
||||
resetRemotely utils.AtomicBool
|
||||
|
||||
frameQueue *streamFrameSorter
|
||||
readChan chan struct{}
|
||||
readDeadline time.Time
|
||||
|
||||
dataForWriting []byte
|
||||
finSent utils.AtomicBool
|
||||
rstSent utils.AtomicBool
|
||||
writeChan chan struct{}
|
||||
writeDeadline time.Time
|
||||
|
||||
flowController flowcontrol.StreamFlowController
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
var _ Stream = &stream{}
|
||||
var _ streamI = &stream{}
|
||||
|
||||
type deadlineError struct{}
|
||||
|
||||
|
@ -92,290 +89,58 @@ func (deadlineError) Timeout() bool { return true }
|
|||
|
||||
var errDeadline net.Error = &deadlineError{}
|
||||
|
||||
type streamCanceledError struct {
|
||||
error
|
||||
errorCode protocol.ApplicationErrorCode
|
||||
}
|
||||
|
||||
func (streamCanceledError) Canceled() bool { return true }
|
||||
func (e streamCanceledError) ErrorCode() protocol.ApplicationErrorCode { return e.errorCode }
|
||||
|
||||
var _ StreamError = &streamCanceledError{}
|
||||
|
||||
// newStream creates a new Stream
|
||||
func newStream(StreamID protocol.StreamID,
|
||||
onData func(),
|
||||
onReset func(protocol.StreamID, protocol.ByteCount),
|
||||
func newStream(streamID protocol.StreamID,
|
||||
sender streamSender,
|
||||
flowController flowcontrol.StreamFlowController,
|
||||
version protocol.VersionNumber,
|
||||
) *stream {
|
||||
s := &stream{
|
||||
onData: onData,
|
||||
onReset: onReset,
|
||||
streamID: StreamID,
|
||||
flowController: flowController,
|
||||
frameQueue: newStreamFrameSorter(),
|
||||
readChan: make(chan struct{}, 1),
|
||||
writeChan: make(chan struct{}, 1),
|
||||
version: version,
|
||||
s := &stream{sender: sender}
|
||||
senderForSendStream := &uniStreamSender{
|
||||
streamSender: sender,
|
||||
onStreamCompletedImpl: func() {
|
||||
s.completedMutex.Lock()
|
||||
s.sendStreamCompleted = true
|
||||
s.checkIfCompleted()
|
||||
s.completedMutex.Unlock()
|
||||
},
|
||||
}
|
||||
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
|
||||
s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version)
|
||||
senderForReceiveStream := &uniStreamSender{
|
||||
streamSender: sender,
|
||||
onStreamCompletedImpl: func() {
|
||||
s.completedMutex.Lock()
|
||||
s.receiveStreamCompleted = true
|
||||
s.checkIfCompleted()
|
||||
s.completedMutex.Unlock()
|
||||
},
|
||||
}
|
||||
s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController)
|
||||
return s
|
||||
}
|
||||
|
||||
// Read implements io.Reader. It is not thread safe!
|
||||
func (s *stream) Read(p []byte) (int, error) {
|
||||
s.mutex.Lock()
|
||||
err := s.err
|
||||
s.mutex.Unlock()
|
||||
if s.cancelled.Get() || s.resetLocally.Get() {
|
||||
return 0, err
|
||||
}
|
||||
if s.finishedReading.Get() {
|
||||
return 0, io.EOF
|
||||
// need to define StreamID() here, since both receiveStream and readStream have a StreamID()
|
||||
func (s *stream) StreamID() protocol.StreamID {
|
||||
// the result is same for receiveStream and sendStream
|
||||
return s.sendStream.StreamID()
|
||||
}
|
||||
|
||||
bytesRead := 0
|
||||
for bytesRead < len(p) {
|
||||
s.mutex.Lock()
|
||||
frame := s.frameQueue.Head()
|
||||
if frame == nil && bytesRead > 0 {
|
||||
err = s.err
|
||||
s.mutex.Unlock()
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
var err error
|
||||
for {
|
||||
// Stop waiting on errors
|
||||
if s.resetLocally.Get() || s.cancelled.Get() {
|
||||
err = s.err
|
||||
break
|
||||
}
|
||||
|
||||
deadline := s.readDeadline
|
||||
if !deadline.IsZero() && !time.Now().Before(deadline) {
|
||||
err = errDeadline
|
||||
break
|
||||
}
|
||||
|
||||
if frame != nil {
|
||||
s.readPosInFrame = int(s.readOffset - frame.Offset)
|
||||
break
|
||||
}
|
||||
|
||||
s.mutex.Unlock()
|
||||
if deadline.IsZero() {
|
||||
<-s.readChan
|
||||
} else {
|
||||
select {
|
||||
case <-s.readChan:
|
||||
case <-time.After(deadline.Sub(time.Now())):
|
||||
}
|
||||
}
|
||||
s.mutex.Lock()
|
||||
frame = s.frameQueue.Head()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame)
|
||||
|
||||
if bytesRead > len(p) {
|
||||
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
|
||||
}
|
||||
if s.readPosInFrame > int(frame.DataLen()) {
|
||||
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen())
|
||||
}
|
||||
copy(p[bytesRead:], frame.Data[s.readPosInFrame:])
|
||||
|
||||
s.readPosInFrame += m
|
||||
bytesRead += m
|
||||
s.readOffset += protocol.ByteCount(m)
|
||||
|
||||
// when a RST_STREAM was received, the was already informed about the final byteOffset for this stream
|
||||
if !s.resetRemotely.Get() {
|
||||
s.flowController.AddBytesRead(protocol.ByteCount(m))
|
||||
}
|
||||
s.onData() // so that a possible WINDOW_UPDATE is sent
|
||||
|
||||
if s.readPosInFrame >= int(frame.DataLen()) {
|
||||
fin := frame.FinBit
|
||||
s.mutex.Lock()
|
||||
s.frameQueue.Pop()
|
||||
s.mutex.Unlock()
|
||||
if fin {
|
||||
s.finishedReading.Set(true)
|
||||
return bytesRead, io.EOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, nil
|
||||
}
|
||||
|
||||
func (s *stream) Write(p []byte) (int, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.resetLocally.Get() || s.err != nil {
|
||||
return 0, s.err
|
||||
}
|
||||
if s.finishedWriting.Get() {
|
||||
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
s.dataForWriting = make([]byte, len(p))
|
||||
copy(s.dataForWriting, p)
|
||||
s.onData()
|
||||
|
||||
var err error
|
||||
for {
|
||||
deadline := s.writeDeadline
|
||||
if !deadline.IsZero() && !time.Now().Before(deadline) {
|
||||
err = errDeadline
|
||||
break
|
||||
}
|
||||
if s.dataForWriting == nil || s.err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
s.mutex.Unlock()
|
||||
if deadline.IsZero() {
|
||||
<-s.writeChan
|
||||
} else {
|
||||
select {
|
||||
case <-s.writeChan:
|
||||
case <-time.After(deadline.Sub(time.Now())):
|
||||
}
|
||||
}
|
||||
s.mutex.Lock()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if s.err != nil {
|
||||
return len(p) - len(s.dataForWriting), s.err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *stream) GetWriteOffset() protocol.ByteCount {
|
||||
return s.writeOffset
|
||||
}
|
||||
|
||||
// HasDataForWriting says if there's stream available to be dequeued for writing
|
||||
func (s *stream) HasDataForWriting() bool {
|
||||
s.mutex.Lock()
|
||||
hasData := s.err == nil && // nothing should be sent if an error occurred
|
||||
(len(s.dataForWriting) > 0 || // there is data queued for sending
|
||||
s.finishedWriting.Get() && !s.finSent.Get()) // if there is no data, but writing finished and the FIN hasn't been sent yet
|
||||
s.mutex.Unlock()
|
||||
return hasData
|
||||
}
|
||||
|
||||
func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
|
||||
data, shouldSendFin := s.getDataForWritingImpl(maxBytes)
|
||||
if shouldSendFin {
|
||||
s.finSent.Set(true)
|
||||
}
|
||||
return data, shouldSendFin
|
||||
}
|
||||
|
||||
func (s *stream) getDataForWritingImpl(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.err != nil || s.dataForWriting == nil {
|
||||
return nil, s.finishedWriting.Get() && !s.finSent.Get()
|
||||
}
|
||||
|
||||
// TODO(#657): Flow control for the crypto stream
|
||||
if s.streamID != s.version.CryptoStreamID() {
|
||||
maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize())
|
||||
}
|
||||
if maxBytes == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var ret []byte
|
||||
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
|
||||
ret = s.dataForWriting[:maxBytes]
|
||||
s.dataForWriting = s.dataForWriting[maxBytes:]
|
||||
} else {
|
||||
ret = s.dataForWriting
|
||||
s.dataForWriting = nil
|
||||
s.signalWrite()
|
||||
}
|
||||
s.writeOffset += protocol.ByteCount(len(ret))
|
||||
s.flowController.AddBytesSent(protocol.ByteCount(len(ret)))
|
||||
return ret, s.finishedWriting.Get() && s.dataForWriting == nil && !s.finSent.Get()
|
||||
}
|
||||
|
||||
// Close implements io.Closer
|
||||
func (s *stream) Close() error {
|
||||
s.finishedWriting.Set(true)
|
||||
s.ctxCancel()
|
||||
s.onData()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) shouldSendReset() bool {
|
||||
if s.rstSent.Get() {
|
||||
return false
|
||||
}
|
||||
return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin()
|
||||
}
|
||||
|
||||
// AddStreamFrame adds a new stream frame
|
||||
func (s *stream) AddStreamFrame(frame *wire.StreamFrame) error {
|
||||
maxOffset := frame.Offset + frame.DataLen()
|
||||
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil {
|
||||
if err := s.sendStream.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData {
|
||||
return err
|
||||
}
|
||||
s.signalRead()
|
||||
return nil
|
||||
}
|
||||
|
||||
// signalRead performs a non-blocking send on the readChan
|
||||
func (s *stream) signalRead() {
|
||||
select {
|
||||
case s.readChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// signalRead performs a non-blocking send on the writeChan
|
||||
func (s *stream) signalWrite() {
|
||||
select {
|
||||
case s.writeChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stream) SetReadDeadline(t time.Time) error {
|
||||
s.mutex.Lock()
|
||||
oldDeadline := s.readDeadline
|
||||
s.readDeadline = t
|
||||
s.mutex.Unlock()
|
||||
// if the new deadline is before the currently set deadline, wake up Read()
|
||||
if t.Before(oldDeadline) {
|
||||
s.signalRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) SetWriteDeadline(t time.Time) error {
|
||||
s.mutex.Lock()
|
||||
oldDeadline := s.writeDeadline
|
||||
s.writeDeadline = t
|
||||
s.mutex.Unlock()
|
||||
if t.Before(oldDeadline) {
|
||||
s.signalWrite()
|
||||
}
|
||||
// in gQUIC, we need to send a RST_STREAM with the final offset if CancelRead() was called
|
||||
s.receiveStream.onClose(s.sendStream.getWriteOffset())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -385,107 +150,31 @@ func (s *stream) SetDeadline(t time.Time) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset
|
||||
func (s *stream) CloseRemote(offset protocol.ByteCount) {
|
||||
s.AddStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset})
|
||||
// CloseForShutdown closes a stream abruptly.
|
||||
// It makes Read and Write unblock (and return the error) immediately.
|
||||
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
|
||||
func (s *stream) closeForShutdown(err error) {
|
||||
s.sendStream.closeForShutdown(err)
|
||||
s.receiveStream.closeForShutdown(err)
|
||||
}
|
||||
|
||||
// Cancel is called by session to indicate that an error occurred
|
||||
// The stream should will be closed immediately
|
||||
func (s *stream) Cancel(err error) {
|
||||
s.mutex.Lock()
|
||||
s.cancelled.Set(true)
|
||||
s.ctxCancel()
|
||||
// errors must not be changed!
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
s.signalRead()
|
||||
s.signalWrite()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
// resets the stream locally
|
||||
func (s *stream) Reset(err error) {
|
||||
if s.resetLocally.Get() {
|
||||
return
|
||||
}
|
||||
s.mutex.Lock()
|
||||
s.resetLocally.Set(true)
|
||||
s.ctxCancel()
|
||||
// errors must not be changed!
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
s.signalRead()
|
||||
s.signalWrite()
|
||||
}
|
||||
if s.shouldSendReset() {
|
||||
s.onReset(s.streamID, s.writeOffset)
|
||||
s.rstSent.Set(true)
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
// resets the stream remotely
|
||||
func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error {
|
||||
if s.resetRemotely.Get() {
|
||||
return nil
|
||||
}
|
||||
s.mutex.Lock()
|
||||
s.resetRemotely.Set(true)
|
||||
s.ctxCancel()
|
||||
// errors must not be changed!
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
s.signalWrite()
|
||||
}
|
||||
if err := s.flowController.UpdateHighestReceived(offset, true); err != nil {
|
||||
func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
||||
if err := s.receiveStream.handleRstStreamFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.shouldSendReset() {
|
||||
s.onReset(s.streamID, s.writeOffset)
|
||||
s.rstSent.Set(true)
|
||||
if !s.version.UsesIETFFrameFormat() {
|
||||
s.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: s.StreamID(),
|
||||
ErrorCode: frame.ErrorCode,
|
||||
})
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) finishedWriteAndSentFin() bool {
|
||||
return s.finishedWriting.Get() && s.finSent.Get()
|
||||
// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed.
|
||||
// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed.
|
||||
func (s *stream) checkIfCompleted() {
|
||||
if s.sendStreamCompleted && s.receiveStreamCompleted {
|
||||
s.sender.onStreamCompleted(s.StreamID())
|
||||
}
|
||||
|
||||
func (s *stream) Finished() bool {
|
||||
return s.cancelled.Get() ||
|
||||
(s.finishedReading.Get() && s.finishedWriteAndSentFin()) ||
|
||||
(s.resetRemotely.Get() && s.rstSent.Get()) ||
|
||||
(s.finishedReading.Get() && s.rstSent.Get()) ||
|
||||
(s.finishedWriteAndSentFin() && s.resetRemotely.Get())
|
||||
}
|
||||
|
||||
func (s *stream) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
func (s *stream) StreamID() protocol.StreamID {
|
||||
return s.streamID
|
||||
}
|
||||
|
||||
func (s *stream) UpdateSendWindow(n protocol.ByteCount) {
|
||||
s.flowController.UpdateSendWindow(n)
|
||||
}
|
||||
|
||||
func (s *stream) IsFlowControlBlocked() bool {
|
||||
return s.flowController.IsBlocked()
|
||||
}
|
||||
|
||||
func (s *stream) GetWindowUpdate() protocol.ByteCount {
|
||||
return s.flowController.GetWindowUpdate()
|
||||
}
|
||||
|
||||
// SetReadOffset sets the read offset.
|
||||
// It is only needed for the crypto stream.
|
||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||
func (s *stream) SetReadOffset(offset protocol.ByteCount) {
|
||||
s.readOffset = offset
|
||||
s.frameQueue.readPosition = offset
|
||||
}
|
||||
|
|
|
@ -1,32 +1,34 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type streamFramer struct {
|
||||
streamsMap *streamsMap
|
||||
cryptoStream streamI
|
||||
streamGetter streamGetter
|
||||
cryptoStream cryptoStreamI
|
||||
version protocol.VersionNumber
|
||||
|
||||
connFlowController flowcontrol.ConnectionFlowController
|
||||
|
||||
retransmissionQueue []*wire.StreamFrame
|
||||
blockedFrameQueue []wire.Frame
|
||||
|
||||
streamQueueMutex sync.Mutex
|
||||
activeStreams map[protocol.StreamID]struct{}
|
||||
streamQueue []protocol.StreamID
|
||||
hasCryptoStreamData bool
|
||||
}
|
||||
|
||||
func newStreamFramer(
|
||||
cryptoStream streamI,
|
||||
streamsMap *streamsMap,
|
||||
cfc flowcontrol.ConnectionFlowController,
|
||||
cryptoStream cryptoStreamI,
|
||||
streamGetter streamGetter,
|
||||
v protocol.VersionNumber,
|
||||
) *streamFramer {
|
||||
return &streamFramer{
|
||||
streamsMap: streamsMap,
|
||||
streamGetter: streamGetter,
|
||||
cryptoStream: cryptoStream,
|
||||
connFlowController: cfc,
|
||||
activeStreams: make(map[protocol.StreamID]struct{}),
|
||||
version: v,
|
||||
}
|
||||
}
|
||||
|
@ -35,114 +37,101 @@ func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) {
|
|||
f.retransmissionQueue = append(f.retransmissionQueue, frame)
|
||||
}
|
||||
|
||||
func (f *streamFramer) AddActiveStream(id protocol.StreamID) {
|
||||
if id == f.version.CryptoStreamID() { // the crypto stream is handled separately
|
||||
f.streamQueueMutex.Lock()
|
||||
f.hasCryptoStreamData = true
|
||||
f.streamQueueMutex.Unlock()
|
||||
return
|
||||
}
|
||||
f.streamQueueMutex.Lock()
|
||||
if _, ok := f.activeStreams[id]; !ok {
|
||||
f.streamQueue = append(f.streamQueue, id)
|
||||
f.activeStreams[id] = struct{}{}
|
||||
}
|
||||
f.streamQueueMutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame {
|
||||
fs, currentLen := f.maybePopFramesForRetransmission(maxLen)
|
||||
return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...)
|
||||
}
|
||||
|
||||
func (f *streamFramer) PopBlockedFrame() wire.Frame {
|
||||
if len(f.blockedFrameQueue) == 0 {
|
||||
return nil
|
||||
}
|
||||
frame := f.blockedFrameQueue[0]
|
||||
f.blockedFrameQueue = f.blockedFrameQueue[1:]
|
||||
return frame
|
||||
}
|
||||
|
||||
func (f *streamFramer) HasFramesForRetransmission() bool {
|
||||
return len(f.retransmissionQueue) > 0
|
||||
}
|
||||
|
||||
func (f *streamFramer) HasCryptoStreamFrame() bool {
|
||||
return f.cryptoStream.HasDataForWriting()
|
||||
func (f *streamFramer) HasCryptoStreamData() bool {
|
||||
f.streamQueueMutex.Lock()
|
||||
hasCryptoStreamData := f.hasCryptoStreamData
|
||||
f.streamQueueMutex.Unlock()
|
||||
return hasCryptoStreamData
|
||||
}
|
||||
|
||||
// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames.
|
||||
func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame {
|
||||
if !f.HasCryptoStreamFrame() {
|
||||
return nil
|
||||
}
|
||||
frame := &wire.StreamFrame{
|
||||
StreamID: f.cryptoStream.StreamID(),
|
||||
Offset: f.cryptoStream.GetWriteOffset(),
|
||||
}
|
||||
frameHeaderBytes, _ := frame.MinLength(f.version) // can never error
|
||||
frame.Data, frame.FinBit = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes)
|
||||
f.streamQueueMutex.Lock()
|
||||
frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen)
|
||||
f.hasCryptoStreamData = hasMoreData
|
||||
f.streamQueueMutex.Unlock()
|
||||
return frame
|
||||
}
|
||||
|
||||
func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) {
|
||||
func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) {
|
||||
for len(f.retransmissionQueue) > 0 {
|
||||
frame := f.retransmissionQueue[0]
|
||||
frame.DataLenPresent = true
|
||||
|
||||
frameHeaderLen, _ := frame.MinLength(f.version) // can never error
|
||||
if currentLen+frameHeaderLen >= maxLen {
|
||||
frameHeaderLen := frame.MinLength(f.version) // can never error
|
||||
maxLen := maxTotalLen - currentLen
|
||||
if frameHeaderLen+frame.DataLen() > maxLen && maxLen < protocol.MinStreamFrameSize {
|
||||
break
|
||||
}
|
||||
|
||||
currentLen += frameHeaderLen
|
||||
|
||||
splitFrame := maybeSplitOffFrame(frame, maxLen-currentLen)
|
||||
splitFrame := maybeSplitOffFrame(frame, maxLen-frameHeaderLen)
|
||||
if splitFrame != nil { // StreamFrame was split
|
||||
res = append(res, splitFrame)
|
||||
currentLen += splitFrame.DataLen()
|
||||
currentLen += frameHeaderLen + splitFrame.DataLen()
|
||||
break
|
||||
}
|
||||
|
||||
f.retransmissionQueue = f.retransmissionQueue[1:]
|
||||
res = append(res, frame)
|
||||
currentLen += frame.DataLen()
|
||||
currentLen += frameHeaderLen + frame.DataLen()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*wire.StreamFrame) {
|
||||
frame := &wire.StreamFrame{DataLenPresent: true}
|
||||
func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame {
|
||||
var currentLen protocol.ByteCount
|
||||
|
||||
fn := func(s streamI) (bool, error) {
|
||||
if s == nil {
|
||||
return true, nil
|
||||
var frames []*wire.StreamFrame
|
||||
f.streamQueueMutex.Lock()
|
||||
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
|
||||
numActiveStreams := len(f.streamQueue)
|
||||
for i := 0; i < numActiveStreams; i++ {
|
||||
if maxTotalLen-currentLen < protocol.MinStreamFrameSize {
|
||||
break
|
||||
}
|
||||
|
||||
frame.StreamID = s.StreamID()
|
||||
frame.Offset = s.GetWriteOffset()
|
||||
// not perfect, but thread-safe since writeOffset is only written when getting data
|
||||
frameHeaderBytes, _ := frame.MinLength(f.version) // can never error
|
||||
if currentLen+frameHeaderBytes > maxBytes {
|
||||
return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here
|
||||
id := f.streamQueue[0]
|
||||
f.streamQueue = f.streamQueue[1:]
|
||||
str, err := f.streamGetter.GetOrOpenSendStream(id)
|
||||
if err != nil { // can happen if the stream completed after it said it had data
|
||||
delete(f.activeStreams, id)
|
||||
continue
|
||||
}
|
||||
maxLen := maxBytes - currentLen - frameHeaderBytes
|
||||
|
||||
if s.HasDataForWriting() {
|
||||
frame.Data, frame.FinBit = s.GetDataForWriting(maxLen)
|
||||
frame, hasMoreData := str.popStreamFrame(maxTotalLen - currentLen)
|
||||
if hasMoreData { // put the stream back in the queue (at the end)
|
||||
f.streamQueue = append(f.streamQueue, id)
|
||||
} else { // no more data to send. Stream is not active any more
|
||||
delete(f.activeStreams, id)
|
||||
}
|
||||
if len(frame.Data) == 0 && !frame.FinBit {
|
||||
return true, nil
|
||||
if frame == nil { // can happen if the receiveStream was canceled after it said it had data
|
||||
continue
|
||||
}
|
||||
|
||||
// Finally, check if we are now FC blocked and should queue a BLOCKED frame
|
||||
if !frame.FinBit && s.IsFlowControlBlocked() {
|
||||
f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.StreamBlockedFrame{StreamID: s.StreamID()})
|
||||
frames = append(frames, frame)
|
||||
currentLen += frame.MinLength(f.version) + frame.DataLen()
|
||||
}
|
||||
if f.connFlowController.IsBlocked() {
|
||||
f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{})
|
||||
}
|
||||
|
||||
res = append(res, frame)
|
||||
currentLen += frameHeaderBytes + frame.DataLen()
|
||||
|
||||
if currentLen == maxBytes {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
frame = &wire.StreamFrame{DataLenPresent: true}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
f.streamsMap.RoundRobinIterate(fn)
|
||||
return
|
||||
f.streamQueueMutex.Unlock()
|
||||
return frames
|
||||
}
|
||||
|
||||
// maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified.
|
||||
|
|
|
@ -2,9 +2,9 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
@ -21,12 +21,13 @@ var _ = Describe("Stream Framer", func() {
|
|||
var (
|
||||
retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame
|
||||
framer *streamFramer
|
||||
streamsMap *streamsMap
|
||||
stream1, stream2 *mocks.MockStreamI
|
||||
connFC *mocks.MockConnectionFlowController
|
||||
cryptoStream *MockCryptoStream
|
||||
stream1, stream2 *MockSendStreamI
|
||||
streamGetter *MockStreamGetter
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
streamGetter = NewMockStreamGetter(mockCtrl)
|
||||
retransmittedFrame1 = &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte{0x13, 0x37},
|
||||
|
@ -36,25 +37,14 @@ var _ = Describe("Stream Framer", func() {
|
|||
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
|
||||
}
|
||||
|
||||
stream1 = mocks.NewMockStreamI(mockCtrl)
|
||||
stream1 = NewMockSendStreamI(mockCtrl)
|
||||
stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes()
|
||||
stream2 = mocks.NewMockStreamI(mockCtrl)
|
||||
stream2 = NewMockSendStreamI(mockCtrl)
|
||||
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
|
||||
|
||||
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames)
|
||||
streamsMap.putStream(stream1)
|
||||
streamsMap.putStream(stream2)
|
||||
|
||||
connFC = mocks.NewMockConnectionFlowController(mockCtrl)
|
||||
framer = newStreamFramer(nil, streamsMap, connFC, versionGQUICFrames)
|
||||
cryptoStream = NewMockCryptoStream(mockCtrl)
|
||||
framer = newStreamFramer(cryptoStream, streamGetter, versionGQUICFrames)
|
||||
})
|
||||
|
||||
setNoData := func(str *mocks.MockStreamI) {
|
||||
str.EXPECT().HasDataForWriting().Return(false).AnyTimes()
|
||||
str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, false).AnyTimes()
|
||||
str.EXPECT().GetWriteOffset().AnyTimes()
|
||||
}
|
||||
|
||||
It("says if it has retransmissions", func() {
|
||||
Expect(framer.HasFramesForRetransmission()).To(BeFalse())
|
||||
framer.AddFrameForRetransmission(retransmittedFrame1)
|
||||
|
@ -62,119 +52,220 @@ var _ = Describe("Stream Framer", func() {
|
|||
})
|
||||
|
||||
It("sets the DataLenPresent for dequeued retransmitted frames", func() {
|
||||
setNoData(stream1)
|
||||
setNoData(stream2)
|
||||
framer.AddFrameForRetransmission(retransmittedFrame1)
|
||||
fs := framer.PopStreamFrames(protocol.MaxByteCount)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].DataLenPresent).To(BeTrue())
|
||||
})
|
||||
|
||||
It("sets the DataLenPresent for dequeued normal frames", func() {
|
||||
connFC.EXPECT().IsBlocked()
|
||||
setNoData(stream2)
|
||||
stream1.EXPECT().GetWriteOffset()
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
|
||||
stream1.EXPECT().IsFlowControlBlocked()
|
||||
fs := framer.PopStreamFrames(protocol.MaxByteCount)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].DataLenPresent).To(BeTrue())
|
||||
Context("handling the crypto stream", func() {
|
||||
It("says if it has crypto stream data", func() {
|
||||
Expect(framer.HasCryptoStreamData()).To(BeFalse())
|
||||
framer.AddActiveStream(framer.version.CryptoStreamID())
|
||||
Expect(framer.HasCryptoStreamData()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("says that it doesn't have crypto stream data after popping all data", func() {
|
||||
streamID := framer.version.CryptoStreamID()
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: streamID,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, false)
|
||||
framer.AddActiveStream(streamID)
|
||||
Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f))
|
||||
Expect(framer.HasCryptoStreamData()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("says that it has more crypto stream data if not all data was popped", func() {
|
||||
streamID := framer.version.CryptoStreamID()
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: streamID,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, true)
|
||||
framer.AddActiveStream(streamID)
|
||||
Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f))
|
||||
Expect(framer.HasCryptoStreamData()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Popping", func() {
|
||||
BeforeEach(func() {
|
||||
// nothing is blocked here
|
||||
connFC.EXPECT().IsBlocked().AnyTimes()
|
||||
stream1.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes()
|
||||
stream2.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes()
|
||||
})
|
||||
|
||||
It("returns nil when popping an empty framer", func() {
|
||||
setNoData(stream1)
|
||||
setNoData(stream2)
|
||||
Expect(framer.PopStreamFrames(1000)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("pops frames for retransmission", func() {
|
||||
setNoData(stream1)
|
||||
setNoData(stream2)
|
||||
framer.AddFrameForRetransmission(retransmittedFrame1)
|
||||
framer.AddFrameForRetransmission(retransmittedFrame2)
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(HaveLen(2))
|
||||
Expect(fs[0]).To(Equal(retransmittedFrame1))
|
||||
Expect(fs[1]).To(Equal(retransmittedFrame2))
|
||||
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, retransmittedFrame2}))
|
||||
// make sure the frames are actually removed, and not returned a second time
|
||||
Expect(framer.PopStreamFrames(1000)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns normal frames", func() {
|
||||
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetWriteOffset()
|
||||
setNoData(stream2)
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
|
||||
Expect(fs[0].Data).To(Equal([]byte("foobar")))
|
||||
Expect(fs[0].FinBit).To(BeFalse())
|
||||
It("doesn't pop frames for retransmission, if the size would be smaller than the minimum STREAM frame size", func() {
|
||||
framer.AddFrameForRetransmission(&wire.StreamFrame{
|
||||
StreamID: id1,
|
||||
Data: bytes.Repeat([]byte{'a'}, int(protocol.MinStreamFrameSize)),
|
||||
})
|
||||
|
||||
It("returns multiple normal frames", func() {
|
||||
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetWriteOffset()
|
||||
stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz"), false)
|
||||
stream2.EXPECT().HasDataForWriting().Return(true)
|
||||
stream2.EXPECT().GetWriteOffset()
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(HaveLen(2))
|
||||
// Swap if we dequeued in other order
|
||||
if fs[0].StreamID != stream1.StreamID() {
|
||||
fs[0], fs[1] = fs[1], fs[0]
|
||||
}
|
||||
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
|
||||
Expect(fs[0].Data).To(Equal([]byte("foobar")))
|
||||
Expect(fs[1].StreamID).To(Equal(stream2.StreamID()))
|
||||
Expect(fs[1].Data).To(Equal([]byte("foobaz")))
|
||||
})
|
||||
|
||||
It("returns retransmission frames before normal frames", func() {
|
||||
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetWriteOffset()
|
||||
setNoData(stream2)
|
||||
framer.AddFrameForRetransmission(retransmittedFrame1)
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(HaveLen(2))
|
||||
Expect(fs[0]).To(Equal(retransmittedFrame1))
|
||||
Expect(fs[1].StreamID).To(Equal(stream1.StreamID()))
|
||||
})
|
||||
|
||||
It("does not pop empty frames", func() {
|
||||
stream1.EXPECT().HasDataForWriting().Return(false)
|
||||
stream1.EXPECT().GetWriteOffset()
|
||||
setNoData(stream2)
|
||||
fs := framer.PopStreamFrames(5)
|
||||
fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
|
||||
Expect(fs).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("uses the round-robin scheduling", func() {
|
||||
streamFrameHeaderLen := protocol.ByteCount(4)
|
||||
stream1.EXPECT().GetDataForWriting(10-streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen)), false)
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetWriteOffset()
|
||||
stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10-streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen)), false)
|
||||
stream2.EXPECT().HasDataForWriting().Return(true)
|
||||
stream2.EXPECT().GetWriteOffset()
|
||||
fs := framer.PopStreamFrames(10)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
// it doesn't matter here if this data is from stream1 or from stream2...
|
||||
firstStreamID := fs[0].StreamID
|
||||
fs = framer.PopStreamFrames(10)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
// ... but the data popped this time has to be from the other stream
|
||||
Expect(fs[0].StreamID).ToNot(Equal(firstStreamID))
|
||||
It("pops frames for retransmission, even if the remaining space in the packet is too small, if the frame doesn't need to be split", func() {
|
||||
framer.AddFrameForRetransmission(retransmittedFrame1)
|
||||
fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
|
||||
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1}))
|
||||
})
|
||||
|
||||
It("pops frames for retransmission, if the remaining size is the miniumum STREAM frame size", func() {
|
||||
framer.AddFrameForRetransmission(retransmittedFrame1)
|
||||
fs := framer.PopStreamFrames(protocol.MinStreamFrameSize)
|
||||
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1}))
|
||||
})
|
||||
|
||||
It("returns normal frames", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id1,
|
||||
Data: []byte("foobar"),
|
||||
Offset: 42,
|
||||
}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
|
||||
framer.AddActiveStream(id1)
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(Equal([]*wire.StreamFrame{f}))
|
||||
})
|
||||
|
||||
It("skips a stream that was reported active, but was completed shortly after", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(nil, errors.New("stream was already deleted"))
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id2,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id2)
|
||||
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f}))
|
||||
})
|
||||
|
||||
It("skips a stream that was reported active, but doesn't have any data", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id2,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(nil, false)
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id2)
|
||||
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f}))
|
||||
})
|
||||
|
||||
It("pops from a stream multiple times, if it has enough data", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
|
||||
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
|
||||
f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true)
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
|
||||
framer.AddActiveStream(id1) // only add it once
|
||||
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f1}))
|
||||
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2}))
|
||||
// no further calls to popStreamFrame, after popStreamFrame said there's no more data
|
||||
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(BeNil())
|
||||
})
|
||||
|
||||
It("re-queues a stream at the end, if it has enough data", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
|
||||
f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
|
||||
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f11, true)
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f12, false)
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
|
||||
framer.AddActiveStream(id1) // only add it once
|
||||
framer.AddActiveStream(id2)
|
||||
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f11})) // first a frame from stream 1
|
||||
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) // then a frame from stream 2
|
||||
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f12})) // then another frame from stream 1
|
||||
})
|
||||
|
||||
It("only dequeues data from each stream once per packet", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
|
||||
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
|
||||
// both streams have more data, and will be re-queued
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true)
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, true)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id2)
|
||||
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f1, f2}))
|
||||
})
|
||||
|
||||
It("returns multiple normal frames in the order they were reported active", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
|
||||
f1 := &wire.StreamFrame{Data: []byte("foobar")}
|
||||
f2 := &wire.StreamFrame{Data: []byte("foobaz")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
|
||||
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
|
||||
framer.AddActiveStream(id2)
|
||||
framer.AddActiveStream(id1)
|
||||
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f2, f1}))
|
||||
})
|
||||
|
||||
It("only asks a stream for data once, even if it was reported active multiple times", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
f := &wire.StreamFrame{Data: []byte("foobar")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function
|
||||
framer.AddActiveStream(id1)
|
||||
framer.AddActiveStream(id1)
|
||||
Expect(framer.PopStreamFrames(1000)).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("returns retransmission frames before normal frames", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
framer.AddActiveStream(id1)
|
||||
f1 := &wire.StreamFrame{Data: []byte("foobar")}
|
||||
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
|
||||
framer.AddFrameForRetransmission(retransmittedFrame1)
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, f1}))
|
||||
})
|
||||
|
||||
It("does not pop empty frames", func() {
|
||||
fs := framer.PopStreamFrames(500)
|
||||
Expect(fs).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("pops frames that have the minimum size", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}, false)
|
||||
framer.AddActiveStream(id1)
|
||||
framer.PopStreamFrames(protocol.MinStreamFrameSize)
|
||||
})
|
||||
|
||||
It("does not pop frames smaller than the mimimum size", func() {
|
||||
// don't expect a call to PopStreamFrame()
|
||||
framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
|
||||
})
|
||||
|
||||
It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() {
|
||||
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
|
||||
// pop a frame such that the remaining size is one byte less than the minimum STREAM frame size
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: id1,
|
||||
Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)),
|
||||
}
|
||||
stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f, false)
|
||||
framer.AddActiveStream(id1)
|
||||
fs := framer.PopStreamFrames(500)
|
||||
Expect(fs).To(Equal([]*wire.StreamFrame{f}))
|
||||
})
|
||||
|
||||
Context("splitting of frames", func() {
|
||||
|
@ -212,139 +303,28 @@ var _ = Describe("Stream Framer", func() {
|
|||
})
|
||||
|
||||
It("splits a frame", func() {
|
||||
setNoData(stream1)
|
||||
setNoData(stream2)
|
||||
framer.AddFrameForRetransmission(retransmittedFrame2)
|
||||
origlen := retransmittedFrame2.DataLen()
|
||||
fs := framer.PopStreamFrames(6)
|
||||
frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, 600)}
|
||||
framer.AddFrameForRetransmission(frame)
|
||||
fs := framer.PopStreamFrames(500)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
minLength, _ := fs[0].MinLength(framer.version)
|
||||
Expect(minLength + fs[0].DataLen()).To(Equal(protocol.ByteCount(6)))
|
||||
Expect(framer.retransmissionQueue[0].Data).To(HaveLen(int(origlen - fs[0].DataLen())))
|
||||
minLength := fs[0].MinLength(framer.version)
|
||||
Expect(minLength + fs[0].DataLen()).To(Equal(protocol.ByteCount(500)))
|
||||
Expect(framer.retransmissionQueue[0].Data).To(HaveLen(int(600 - fs[0].DataLen())))
|
||||
Expect(framer.retransmissionQueue[0].Offset).To(Equal(fs[0].DataLen()))
|
||||
})
|
||||
|
||||
It("never returns an empty stream frame", func() {
|
||||
// this one frame will be split off from again and again in this test. Therefore, it has to be large enough (checked again at the end)
|
||||
origFrame := &wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Offset: 1,
|
||||
FinBit: false,
|
||||
Data: bytes.Repeat([]byte{'f'}, 30*30),
|
||||
}
|
||||
framer.AddFrameForRetransmission(origFrame)
|
||||
|
||||
minFrameDataLen := protocol.MaxPacketSize
|
||||
|
||||
for i := 0; i < 30; i++ {
|
||||
frames, currentLen := framer.maybePopFramesForRetransmission(protocol.ByteCount(i))
|
||||
if len(frames) == 0 {
|
||||
Expect(currentLen).To(BeZero())
|
||||
} else {
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(currentLen).ToNot(BeZero())
|
||||
dataLen := frames[0].DataLen()
|
||||
Expect(dataLen).ToNot(BeZero())
|
||||
if dataLen < minFrameDataLen {
|
||||
minFrameDataLen = dataLen
|
||||
}
|
||||
}
|
||||
}
|
||||
Expect(framer.retransmissionQueue).To(HaveLen(1)) // check that origFrame was large enough for this test and didn't get used up completely
|
||||
Expect(minFrameDataLen).To(Equal(protocol.ByteCount(1)))
|
||||
})
|
||||
|
||||
It("only removes a frame from the framer after returning all split parts", func() {
|
||||
setNoData(stream1)
|
||||
setNoData(stream2)
|
||||
framer.AddFrameForRetransmission(retransmittedFrame2)
|
||||
fs := framer.PopStreamFrames(6)
|
||||
frameHeaderLen := protocol.ByteCount(4)
|
||||
frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, int(501-frameHeaderLen))}
|
||||
framer.AddFrameForRetransmission(frame)
|
||||
fs := framer.PopStreamFrames(500)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(framer.retransmissionQueue).ToNot(BeEmpty())
|
||||
fs = framer.PopStreamFrames(1000)
|
||||
fs = framer.PopStreamFrames(500)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].DataLen()).To(BeEquivalentTo(1))
|
||||
Expect(framer.retransmissionQueue).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("sending FINs", func() {
|
||||
It("sends FINs when streams are closed", func() {
|
||||
offset := protocol.ByteCount(42)
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, true)
|
||||
stream1.EXPECT().GetWriteOffset().Return(offset)
|
||||
setNoData(stream2)
|
||||
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
|
||||
Expect(fs[0].Offset).To(Equal(offset))
|
||||
Expect(fs[0].FinBit).To(BeTrue())
|
||||
Expect(fs[0].Data).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("bundles FINs with data", func() {
|
||||
offset := protocol.ByteCount(42)
|
||||
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), true)
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetWriteOffset().Return(offset)
|
||||
setNoData(stream2)
|
||||
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
|
||||
Expect(fs[0].Data).To(Equal([]byte("foobar")))
|
||||
Expect(fs[0].FinBit).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("BLOCKED frames", func() {
|
||||
It("Pop returns nil if no frame is queued", func() {
|
||||
Expect(framer.PopBlockedFrame()).To(BeNil())
|
||||
})
|
||||
|
||||
It("queues and pops BLOCKED frames for individually blocked streams", func() {
|
||||
connFC.EXPECT().IsBlocked()
|
||||
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetWriteOffset()
|
||||
stream1.EXPECT().IsFlowControlBlocked().Return(true)
|
||||
setNoData(stream2)
|
||||
frames := framer.PopStreamFrames(1000)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
f := framer.PopBlockedFrame()
|
||||
Expect(f).To(BeAssignableToTypeOf(&wire.StreamBlockedFrame{}))
|
||||
bf := f.(*wire.StreamBlockedFrame)
|
||||
Expect(bf.StreamID).To(Equal(stream1.StreamID()))
|
||||
Expect(framer.PopBlockedFrame()).To(BeNil())
|
||||
})
|
||||
|
||||
It("does not queue a stream-level BLOCKED frame after sending the FinBit frame", func() {
|
||||
connFC.EXPECT().IsBlocked()
|
||||
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), true)
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetWriteOffset()
|
||||
setNoData(stream2)
|
||||
frames := framer.PopStreamFrames(1000)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0].FinBit).To(BeTrue())
|
||||
Expect(frames[0].DataLen()).To(Equal(protocol.ByteCount(3)))
|
||||
blockedFrame := framer.PopBlockedFrame()
|
||||
Expect(blockedFrame).To(BeNil())
|
||||
})
|
||||
|
||||
It("queues and pops BLOCKED frames for connection blocked streams", func() {
|
||||
connFC.EXPECT().IsBlocked().Return(true)
|
||||
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), false)
|
||||
stream1.EXPECT().HasDataForWriting().Return(true)
|
||||
stream1.EXPECT().GetWriteOffset()
|
||||
stream1.EXPECT().IsFlowControlBlocked().Return(false)
|
||||
setNoData(stream2)
|
||||
framer.PopStreamFrames(1000)
|
||||
f := framer.PopBlockedFrame()
|
||||
Expect(f).To(BeAssignableToTypeOf(&wire.BlockedFrame{}))
|
||||
Expect(framer.PopBlockedFrame()).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -5,8 +5,9 @@ import (
|
|||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
|
@ -16,11 +17,8 @@ type streamsMap struct {
|
|||
perspective protocol.Perspective
|
||||
|
||||
streams map[protocol.StreamID]streamI
|
||||
// needed for round-robin scheduling
|
||||
openStreams []protocol.StreamID
|
||||
roundRobinIndex int
|
||||
|
||||
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||
highestStreamOpenedByPeer protocol.StreamID
|
||||
nextStreamOrErrCond sync.Cond
|
||||
openStreamOrErrCond sync.Cond
|
||||
|
@ -29,47 +27,32 @@ type streamsMap struct {
|
|||
nextStreamToAccept protocol.StreamID
|
||||
|
||||
newStream newStreamLambda
|
||||
|
||||
numOutgoingStreams uint32
|
||||
numIncomingStreams uint32
|
||||
maxIncomingStreams uint32
|
||||
maxOutgoingStreams uint32
|
||||
}
|
||||
|
||||
type streamLambda func(streamI) (bool, error)
|
||||
var _ streamManager = &streamsMap{}
|
||||
|
||||
type newStreamLambda func(protocol.StreamID) streamI
|
||||
|
||||
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
||||
|
||||
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap {
|
||||
// add some tolerance to the maximum incoming streams value
|
||||
maxStreams := uint32(protocol.MaxIncomingStreams)
|
||||
maxIncomingStreams := utils.MaxUint32(
|
||||
maxStreams+protocol.MaxStreamsMinimumIncrement,
|
||||
uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
|
||||
)
|
||||
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) streamManager {
|
||||
sm := streamsMap{
|
||||
perspective: pers,
|
||||
streams: make(map[protocol.StreamID]streamI),
|
||||
openStreams: make([]protocol.StreamID, 0),
|
||||
newStream: newStream,
|
||||
maxIncomingStreams: maxIncomingStreams,
|
||||
}
|
||||
sm.nextStreamOrErrCond.L = &sm.mutex
|
||||
sm.openStreamOrErrCond.L = &sm.mutex
|
||||
|
||||
nextOddStream := protocol.StreamID(1)
|
||||
if ver.CryptoStreamID() == protocol.StreamID(1) {
|
||||
nextOddStream = 3
|
||||
}
|
||||
if pers == protocol.PerspectiveClient {
|
||||
sm.nextStream = nextOddStream
|
||||
sm.nextStreamToAccept = 2
|
||||
nextClientInitiatedStream := protocol.StreamID(1)
|
||||
nextServerInitiatedStream := protocol.StreamID(2)
|
||||
if pers == protocol.PerspectiveServer {
|
||||
sm.nextStreamToOpen = nextServerInitiatedStream
|
||||
sm.nextStreamToAccept = nextClientInitiatedStream
|
||||
} else {
|
||||
sm.nextStream = 2
|
||||
sm.nextStreamToAccept = nextOddStream
|
||||
sm.nextStreamToOpen = nextClientInitiatedStream
|
||||
sm.nextStreamToAccept = nextServerInitiatedStream
|
||||
}
|
||||
|
||||
return &sm
|
||||
}
|
||||
|
||||
|
@ -81,6 +64,23 @@ func (m *streamsMap) streamInitiatedBy(id protocol.StreamID) protocol.Perspectiv
|
|||
return protocol.PerspectiveClient
|
||||
}
|
||||
|
||||
func (m *streamsMap) nextStreamID(id protocol.StreamID) protocol.StreamID {
|
||||
if m.perspective == protocol.PerspectiveServer && id == 0 {
|
||||
return 1
|
||||
}
|
||||
return id + 2
|
||||
}
|
||||
|
||||
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
|
||||
// every bidirectional stream is also a receive stream
|
||||
return m.GetOrOpenStream(id)
|
||||
}
|
||||
|
||||
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
|
||||
// every bidirectional stream is also a send stream
|
||||
return m.GetOrOpenStream(id)
|
||||
}
|
||||
|
||||
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
|
||||
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
|
||||
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
|
||||
|
@ -88,7 +88,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
|
|||
s, ok := m.streams[id]
|
||||
m.mutex.RUnlock()
|
||||
if ok {
|
||||
return s, nil // s may be nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ... we don't have an existing stream
|
||||
|
@ -101,7 +101,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
|
|||
}
|
||||
|
||||
if m.perspective == m.streamInitiatedBy(id) {
|
||||
if id <= m.nextStream { // this is a stream opened by us. Must have been closed already
|
||||
if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already
|
||||
return nil, nil
|
||||
}
|
||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
|
||||
|
@ -110,14 +110,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// sid is the next stream that will be opened
|
||||
sid := m.highestStreamOpenedByPeer + 2
|
||||
// if there is no stream opened yet, and this is the server, stream 1 should be openend
|
||||
if sid == 2 && m.perspective == protocol.PerspectiveServer {
|
||||
sid = 1
|
||||
}
|
||||
|
||||
for ; sid <= id; sid += 2 {
|
||||
for sid := m.nextStreamID(m.highestStreamOpenedByPeer); sid <= id; sid = m.nextStreamID(sid) {
|
||||
if _, err := m.openRemoteStream(sid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -128,38 +121,26 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
|
|||
}
|
||||
|
||||
func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) {
|
||||
if m.numIncomingStreams >= m.maxIncomingStreams {
|
||||
return nil, qerr.TooManyOpenStreams
|
||||
}
|
||||
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
|
||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
|
||||
}
|
||||
|
||||
m.numIncomingStreams++
|
||||
if id > m.highestStreamOpenedByPeer {
|
||||
m.highestStreamOpenedByPeer = id
|
||||
}
|
||||
|
||||
s := m.newStream(id)
|
||||
m.putStream(s)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) openStreamImpl() (streamI, error) {
|
||||
id := m.nextStream
|
||||
if m.numOutgoingStreams >= m.maxOutgoingStreams {
|
||||
return nil, qerr.TooManyOpenStreams
|
||||
}
|
||||
|
||||
m.numOutgoingStreams++
|
||||
m.nextStream += 2
|
||||
s := m.newStream(id)
|
||||
s := m.newStream(m.nextStreamToOpen)
|
||||
m.putStream(s)
|
||||
m.nextStreamToOpen = m.nextStreamID(m.nextStreamToOpen)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// OpenStream opens the next available stream
|
||||
func (m *streamsMap) OpenStream() (streamI, error) {
|
||||
func (m *streamsMap) OpenStream() (Stream, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
|
@ -169,7 +150,7 @@ func (m *streamsMap) OpenStream() (streamI, error) {
|
|||
return m.openStreamImpl()
|
||||
}
|
||||
|
||||
func (m *streamsMap) OpenStreamSync() (streamI, error) {
|
||||
func (m *streamsMap) OpenStreamSync() (Stream, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
|
@ -190,7 +171,7 @@ func (m *streamsMap) OpenStreamSync() (streamI, error) {
|
|||
|
||||
// AcceptStream returns the next stream opened by the peer
|
||||
// it blocks until a new stream is opened
|
||||
func (m *streamsMap) AcceptStream() (streamI, error) {
|
||||
func (m *streamsMap) AcceptStream() (Stream, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
var str streamI
|
||||
|
@ -209,104 +190,24 @@ func (m *streamsMap) AcceptStream() (streamI, error) {
|
|||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) DeleteClosedStreams() error {
|
||||
func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var numDeletedStreams int
|
||||
// for every closed stream, the streamID is replaced by 0 in the openStreams slice
|
||||
for i, streamID := range m.openStreams {
|
||||
str, ok := m.streams[streamID]
|
||||
_, ok := m.streams[id]
|
||||
if !ok {
|
||||
return errMapAccess
|
||||
}
|
||||
if !str.Finished() {
|
||||
continue
|
||||
}
|
||||
numDeletedStreams++
|
||||
m.openStreams[i] = 0
|
||||
if m.streamInitiatedBy(streamID) == m.perspective {
|
||||
m.numOutgoingStreams--
|
||||
} else {
|
||||
m.numIncomingStreams--
|
||||
}
|
||||
delete(m.streams, streamID)
|
||||
}
|
||||
|
||||
if numDeletedStreams == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// remove all 0s (representing closed streams) from the openStreams slice
|
||||
// and adjust the roundRobinIndex
|
||||
var j int
|
||||
for i, id := range m.openStreams {
|
||||
if i != j {
|
||||
m.openStreams[j] = m.openStreams[i]
|
||||
}
|
||||
if id != 0 {
|
||||
j++
|
||||
} else if j < m.roundRobinIndex {
|
||||
m.roundRobinIndex--
|
||||
}
|
||||
}
|
||||
m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams]
|
||||
delete(m.streams, id)
|
||||
m.openStreamOrErrCond.Signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false
|
||||
// It uses a round-robin-like scheduling to ensure that every stream is considered fairly
|
||||
// It prioritizes the the header-stream (StreamID 3)
|
||||
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
numStreams := len(m.streams)
|
||||
startIndex := m.roundRobinIndex
|
||||
|
||||
for i := 0; i < numStreams; i++ {
|
||||
streamID := m.openStreams[(i+startIndex)%numStreams]
|
||||
cont, err := m.iterateFunc(streamID, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams
|
||||
if !cont {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Range executes a callback for all streams, in pseudo-random order
|
||||
func (m *streamsMap) Range(cb func(s streamI)) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for _, s := range m.streams {
|
||||
if s != nil {
|
||||
cb(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
|
||||
str, ok := m.streams[streamID]
|
||||
if !ok {
|
||||
return true, errMapAccess
|
||||
}
|
||||
return fn(str)
|
||||
}
|
||||
|
||||
func (m *streamsMap) putStream(s streamI) error {
|
||||
id := s.StreamID()
|
||||
if _, ok := m.streams[id]; ok {
|
||||
return fmt.Errorf("a stream with ID %d already exists", id)
|
||||
}
|
||||
|
||||
m.streams[id] = s
|
||||
m.openStreams = append(m.openStreams, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -316,14 +217,20 @@ func (m *streamsMap) CloseWithError(err error) {
|
|||
m.closeErr = err
|
||||
m.nextStreamOrErrCond.Broadcast()
|
||||
m.openStreamOrErrCond.Broadcast()
|
||||
for _, s := range m.openStreams {
|
||||
m.streams[s].Cancel(err)
|
||||
for _, s := range m.streams {
|
||||
s.closeForShutdown(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) {
|
||||
// TODO(#952): this won't be needed when gQUIC supports stateless handshakes
|
||||
func (m *streamsMap) UpdateLimits(params *handshake.TransportParameters) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
m.maxOutgoingStreams = limit
|
||||
for id, str := range m.streams {
|
||||
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
|
||||
StreamID: id,
|
||||
ByteOffset: params.StreamFlowControlWindow,
|
||||
})
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
m.openStreamOrErrCond.Broadcast()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,257 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type streamsMapLegacy struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
perspective protocol.Perspective
|
||||
|
||||
streams map[protocol.StreamID]streamI
|
||||
|
||||
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||
highestStreamOpenedByPeer protocol.StreamID
|
||||
nextStreamOrErrCond sync.Cond
|
||||
openStreamOrErrCond sync.Cond
|
||||
|
||||
closeErr error
|
||||
nextStreamToAccept protocol.StreamID
|
||||
|
||||
newStream newStreamLambda
|
||||
|
||||
numOutgoingStreams uint32
|
||||
numIncomingStreams uint32
|
||||
maxIncomingStreams uint32
|
||||
maxOutgoingStreams uint32
|
||||
}
|
||||
|
||||
var _ streamManager = &streamsMapLegacy{}
|
||||
|
||||
func newStreamsMapLegacy(newStream newStreamLambda, pers protocol.Perspective) streamManager {
|
||||
// add some tolerance to the maximum incoming streams value
|
||||
maxStreams := uint32(protocol.MaxIncomingStreams)
|
||||
maxIncomingStreams := utils.MaxUint32(
|
||||
maxStreams+protocol.MaxStreamsMinimumIncrement,
|
||||
uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
|
||||
)
|
||||
sm := streamsMapLegacy{
|
||||
perspective: pers,
|
||||
streams: make(map[protocol.StreamID]streamI),
|
||||
newStream: newStream,
|
||||
maxIncomingStreams: maxIncomingStreams,
|
||||
}
|
||||
sm.nextStreamOrErrCond.L = &sm.mutex
|
||||
sm.openStreamOrErrCond.L = &sm.mutex
|
||||
|
||||
nextServerInitiatedStream := protocol.StreamID(2)
|
||||
nextClientInitiatedStream := protocol.StreamID(3)
|
||||
if pers == protocol.PerspectiveServer {
|
||||
sm.highestStreamOpenedByPeer = 1
|
||||
}
|
||||
if pers == protocol.PerspectiveServer {
|
||||
sm.nextStreamToOpen = nextServerInitiatedStream
|
||||
sm.nextStreamToAccept = nextClientInitiatedStream
|
||||
} else {
|
||||
sm.nextStreamToOpen = nextClientInitiatedStream
|
||||
sm.nextStreamToAccept = nextServerInitiatedStream
|
||||
}
|
||||
return &sm
|
||||
}
|
||||
|
||||
// getStreamPerspective says which side should initiate a stream
|
||||
func (m *streamsMapLegacy) streamInitiatedBy(id protocol.StreamID) protocol.Perspective {
|
||||
if id%2 == 0 {
|
||||
return protocol.PerspectiveServer
|
||||
}
|
||||
return protocol.PerspectiveClient
|
||||
}
|
||||
|
||||
func (m *streamsMapLegacy) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
|
||||
// every bidirectional stream is also a receive stream
|
||||
return m.GetOrOpenStream(id)
|
||||
}
|
||||
|
||||
func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
|
||||
// every bidirectional stream is also a send stream
|
||||
return m.GetOrOpenStream(id)
|
||||
}
|
||||
|
||||
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
|
||||
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
|
||||
func (m *streamsMapLegacy) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
|
||||
m.mutex.RLock()
|
||||
s, ok := m.streams[id]
|
||||
m.mutex.RUnlock()
|
||||
if ok {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ... we don't have an existing stream
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
|
||||
s, ok = m.streams[id]
|
||||
if ok {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
if m.perspective == m.streamInitiatedBy(id) {
|
||||
if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already
|
||||
return nil, nil
|
||||
}
|
||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
|
||||
}
|
||||
if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for sid := m.highestStreamOpenedByPeer + 2; sid <= id; sid += 2 {
|
||||
if _, err := m.openRemoteStream(sid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
m.nextStreamOrErrCond.Broadcast()
|
||||
return m.streams[id], nil
|
||||
}
|
||||
|
||||
func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, error) {
|
||||
if m.numIncomingStreams >= m.maxIncomingStreams {
|
||||
return nil, qerr.TooManyOpenStreams
|
||||
}
|
||||
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
|
||||
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
|
||||
}
|
||||
|
||||
m.numIncomingStreams++
|
||||
if id > m.highestStreamOpenedByPeer {
|
||||
m.highestStreamOpenedByPeer = id
|
||||
}
|
||||
|
||||
s := m.newStream(id)
|
||||
m.putStream(s)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *streamsMapLegacy) openStreamImpl() (streamI, error) {
|
||||
if m.numOutgoingStreams >= m.maxOutgoingStreams {
|
||||
return nil, qerr.TooManyOpenStreams
|
||||
}
|
||||
|
||||
m.numOutgoingStreams++
|
||||
s := m.newStream(m.nextStreamToOpen)
|
||||
m.putStream(s)
|
||||
m.nextStreamToOpen += 2
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// OpenStream opens the next available stream
|
||||
func (m *streamsMapLegacy) OpenStream() (Stream, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.closeErr != nil {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
return m.openStreamImpl()
|
||||
}
|
||||
|
||||
func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for {
|
||||
if m.closeErr != nil {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
str, err := m.openStreamImpl()
|
||||
if err == nil {
|
||||
return str, err
|
||||
}
|
||||
if err != nil && err != qerr.TooManyOpenStreams {
|
||||
return nil, err
|
||||
}
|
||||
m.openStreamOrErrCond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// AcceptStream returns the next stream opened by the peer
|
||||
// it blocks until a new stream is opened
|
||||
func (m *streamsMapLegacy) AcceptStream() (Stream, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
var str streamI
|
||||
for {
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
str, ok = m.streams[m.nextStreamToAccept]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.nextStreamOrErrCond.Wait()
|
||||
}
|
||||
m.nextStreamToAccept += 2
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
_, ok := m.streams[id]
|
||||
if !ok {
|
||||
return errMapAccess
|
||||
}
|
||||
delete(m.streams, id)
|
||||
if m.streamInitiatedBy(id) == m.perspective {
|
||||
m.numOutgoingStreams--
|
||||
} else {
|
||||
m.numIncomingStreams--
|
||||
}
|
||||
m.openStreamOrErrCond.Signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *streamsMapLegacy) putStream(s streamI) error {
|
||||
id := s.StreamID()
|
||||
if _, ok := m.streams[id]; ok {
|
||||
return fmt.Errorf("a stream with ID %d already exists", id)
|
||||
}
|
||||
m.streams[id] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *streamsMapLegacy) CloseWithError(err error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
m.closeErr = err
|
||||
m.nextStreamOrErrCond.Broadcast()
|
||||
m.openStreamOrErrCond.Broadcast()
|
||||
for _, s := range m.streams {
|
||||
s.closeForShutdown(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(#952): this won't be needed when gQUIC supports stateless handshakes
|
||||
func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) {
|
||||
m.mutex.Lock()
|
||||
m.maxOutgoingStreams = params.MaxStreams
|
||||
for id, str := range m.streams {
|
||||
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
|
||||
StreamID: id,
|
||||
ByteOffset: params.StreamFlowControlWindow,
|
||||
})
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
m.openStreamOrErrCond.Broadcast()
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue