update quic

This commit is contained in:
Cadey Ratio 2018-01-20 10:07:01 -08:00
parent 3e711e63dc
commit d0695adfb6
104 changed files with 7622 additions and 4515 deletions

6
Gopkg.lock generated
View File

@ -421,6 +421,7 @@
revision = "393af48d391698c6ae4219566bfbdfef67269997"
[[projects]]
branch = "master"
name = "github.com/lucas-clemente/quic-go"
packages = [
".",
@ -435,8 +436,7 @@
"internal/wire",
"qerr"
]
revision = "ded0eb4f6f30a8049bd334a26ff7ff728648fe13"
version = "v0.6.0"
revision = "15bcc2579f7dab14c84f438741f2b535cf474df9"
[[projects]]
branch = "master"
@ -753,6 +753,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "a11e1692755a705514dbd401ba4795821d1ac221d6f9100124c38a29db98c568"
inputs-digest = "97c8282ef9b3abed71907d17ccf38379134714596610880b02d5ca03be634678"
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -0,0 +1,137 @@
# Gopkg.toml example
#
# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md
# for detailed Gopkg.toml documentation.
#
# required = ["github.com/user/thing/cmd/thing"]
# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
#
# [[constraint]]
# name = "github.com/user/project"
# version = "1.0.0"
#
# [[constraint]]
# name = "github.com/user/project2"
# branch = "dev"
# source = "github.com/myfork/project2"
#
# [[override]]
# name = "github.com/x/y"
# version = "2.4.0"
[[constraint]]
branch = "master"
name = "github.com/Xe/gopreload"
[[constraint]]
name = "github.com/Xe/ln"
version = "0.1.0"
[[constraint]]
branch = "master"
name = "github.com/Xe/uuid"
[[constraint]]
branch = "master"
name = "github.com/Xe/x"
[[constraint]]
name = "github.com/asdine/storm"
version = "2.0.2"
[[constraint]]
branch = "master"
name = "github.com/brandur/simplebox"
[[constraint]]
name = "github.com/caarlos0/env"
version = "3.2.0"
[[constraint]]
branch = "master"
name = "github.com/dgryski/go-failure"
[[constraint]]
branch = "master"
name = "github.com/dickeyxxx/netrc"
[[constraint]]
branch = "master"
name = "github.com/facebookgo/flagenv"
[[constraint]]
branch = "master"
name = "github.com/golang/protobuf"
[[constraint]]
name = "github.com/google/gops"
version = "0.3.2"
[[constraint]]
name = "github.com/hashicorp/terraform"
version = "0.11.2"
[[constraint]]
name = "github.com/joho/godotenv"
version = "1.2.0"
[[constraint]]
branch = "master"
name = "github.com/jtolds/qod"
[[constraint]]
branch = "master"
name = "github.com/kr/pretty"
[[constraint]]
name = "github.com/lucas-clemente/quic-go"
branch = "master"
[[constraint]]
name = "github.com/magefile/mage"
version = "2.0.1"
[[constraint]]
branch = "master"
name = "github.com/mtneug/pkg"
[[constraint]]
branch = "master"
name = "github.com/olekukonko/tablewriter"
[[constraint]]
name = "github.com/pkg/errors"
version = "0.8.0"
[[constraint]]
branch = "master"
name = "github.com/streamrail/concurrent-map"
[[constraint]]
name = "github.com/xtaci/kcp-go"
version = "3.23.0"
[[constraint]]
name = "github.com/xtaci/smux"
version = "1.0.6"
[[constraint]]
name = "go.uber.org/atomic"
version = "1.3.1"
[[constraint]]
branch = "master"
name = "golang.org/x/crypto"
[[constraint]]
branch = "master"
name = "golang.org/x/net"
[[constraint]]
name = "google.golang.org/grpc"
version = "1.9.2"
[[constraint]]
name = "gopkg.in/alecthomas/kingpin.v2"
version = "2.2.6"

View File

@ -1,4 +1,5 @@
dist: trusty
group: travis_latest
addons:
hosts:
@ -8,6 +9,7 @@ language: go
go:
- 1.9.2
- 1.10beta1
# first part of the GOARCH workaround
# setting the GOARCH directly doesn't work, since the value will be overwritten later
@ -30,6 +32,7 @@ before_install:
- export GOARCH=$TRAVIS_GOARCH
- go env # for debugging
- google-chrome --version
- "printf \"quic.clemente.io certificate valid until: \" && openssl x509 -in example/fullchain.pem -enddate -noout | cut -d = -f 2"
- "export DISPLAY=:99.0"
- "Xvfb $DISPLAY &> /dev/null &"

View File

@ -1,6 +1,10 @@
# Changelog
## v0.6.1 (unreleased)
## v0.7 (unreleased)
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
- Expose the `ConnectionState` in the `Session` (experimental API).
## v0.6.0 (2017-12-12)

View File

@ -1,6 +1,7 @@
package ackhandler
import (
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -11,3 +12,13 @@ func TestCrypto(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "AckHandler Suite")
}
var mockCtrl *gomock.Controller
var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())
})
var _ = AfterEach(func() {
mockCtrl.Finish()
})

View File

@ -16,6 +16,7 @@ type SentPacketHandler interface {
SendingAllowed() bool
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
ShouldSendRetransmittablePacket() bool
DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber
@ -26,7 +27,7 @@ type SentPacketHandler interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
GetAlarmTimeout() time.Time

View File

@ -15,7 +15,8 @@ type Packet struct {
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
sendTime time.Time
}
// GetFramesForRetransmission gets all the frames for retransmission

View File

@ -34,10 +34,10 @@ func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHand
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now()
h.largestObservedReceivedTime = rcvTime
}
if packetNumber < h.ignoreBelow {
@ -47,7 +47,7 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, shouldInstigateAck)
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck)
return nil
}
@ -58,7 +58,7 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
h.packetHistory.DeleteBelow(p)
}
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) {
h.packetsReceivedSinceLastAck++
if shouldInstigateAck {
@ -86,7 +86,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
h.ackQueued = true
} else {
if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay)
h.ackAlarm = rcvTime.Add(h.ackSendDelay)
}
}
}

View File

@ -21,34 +21,36 @@ var _ = Describe("receivedPacketHandler", func() {
Context("accepting packets", func() {
It("handles a packet that arrives late", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1), true)
err := handler.ReceivedPacket(protocol.PacketNumber(1), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(3), true)
err = handler.ReceivedPacket(protocol.PacketNumber(3), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2), true)
err = handler.ReceivedPacket(protocol.PacketNumber(2), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
})
It("saves the time when each packet arrived", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(3), true)
err := handler.ReceivedPacket(protocol.PacketNumber(3), time.Now(), true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
})
It("updates the largestObserved and the largestObservedReceivedTime", func() {
now := time.Now()
handler.largestObserved = 3
handler.largestObservedReceivedTime = time.Now().Add(-1 * time.Second)
err := handler.ReceivedPacket(5, true)
handler.largestObservedReceivedTime = now.Add(-1 * time.Second)
err := handler.ReceivedPacket(5, now, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
Expect(handler.largestObservedReceivedTime).To(Equal(now))
})
It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() {
timestamp := time.Now().Add(-1 * time.Second)
now := time.Now()
timestamp := now.Add(-1 * time.Second)
handler.largestObserved = 5
handler.largestObservedReceivedTime = timestamp
err := handler.ReceivedPacket(4, true)
err := handler.ReceivedPacket(4, now, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(Equal(timestamp))
@ -57,7 +59,7 @@ var _ = Describe("receivedPacketHandler", func() {
It("passes on errors from receivedPacketHistory", func() {
var err error
for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ {
err = handler.ReceivedPacket(2*i+1, true)
err = handler.ReceivedPacket(2*i+1, time.Time{}, true)
// this will eventually return an error
// details about when exactly the receivedPacketHistory errors are tested there
if err != nil {
@ -72,7 +74,7 @@ var _ = Describe("receivedPacketHandler", func() {
Context("queueing ACKs", func() {
receiveAndAck10Packets := func() {
for i := 1; i <= 10; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
}
Expect(handler.GetAckFrame()).ToNot(BeNil())
@ -80,14 +82,14 @@ var _ = Describe("receivedPacketHandler", func() {
}
It("always queues an ACK for the first packet", func() {
err := handler.ReceivedPacket(1, false)
err := handler.ReceivedPacket(1, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero())
})
It("works with packet number 0", func() {
err := handler.ReceivedPacket(0, false)
err := handler.ReceivedPacket(0, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero())
@ -95,11 +97,11 @@ var _ = Describe("receivedPacketHandler", func() {
It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, true)
err := handler.ReceivedPacket(11, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.GetAlarmTimeout()).NotTo(BeZero())
err = handler.ReceivedPacket(12, true)
err = handler.ReceivedPacket(12, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero())
@ -107,11 +109,11 @@ var _ = Describe("receivedPacketHandler", func() {
It("only sets the timer when receiving a retransmittable packets", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, false)
err := handler.ReceivedPacket(11, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.ackAlarm).To(BeZero())
err = handler.ReceivedPacket(12, true)
err = handler.ReceivedPacket(12, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.ackAlarm).ToNot(BeZero())
@ -120,15 +122,15 @@ var _ = Describe("receivedPacketHandler", func() {
It("queues an ACK if it was reported missing before", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, true)
err := handler.ReceivedPacket(11, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(13, true)
err = handler.ReceivedPacket(13, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() // ACK: 1 and 3, missing: 2
Expect(ack).ToNot(BeNil())
Expect(ack.HasMissingRanges()).To(BeTrue())
Expect(handler.ackQueued).To(BeFalse())
err = handler.ReceivedPacket(12, false)
err = handler.ReceivedPacket(12, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
})
@ -136,10 +138,10 @@ var _ = Describe("receivedPacketHandler", func() {
It("queues an ACK if it creates a new missing range", func() {
receiveAndAck10Packets()
for i := 11; i < 16; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
}
err := handler.ReceivedPacket(20, true) // we now know that packets 16 to 19 are missing
err := handler.ReceivedPacket(20, time.Time{}, true) // we now know that packets 16 to 19 are missing
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
ack := handler.GetAckFrame()
@ -154,9 +156,9 @@ var _ = Describe("receivedPacketHandler", func() {
})
It("generates a simple ACK frame", func() {
err := handler.ReceivedPacket(1, true)
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(2, true)
err = handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
@ -166,7 +168,7 @@ var _ = Describe("receivedPacketHandler", func() {
})
It("generates an ACK for packet number 0", func() {
err := handler.ReceivedPacket(0, true)
err := handler.ReceivedPacket(0, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
@ -176,12 +178,12 @@ var _ = Describe("receivedPacketHandler", func() {
})
It("saves the last sent ACK", func() {
err := handler.ReceivedPacket(1, true)
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack))
err = handler.ReceivedPacket(2, true)
err = handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = true
ack = handler.GetAckFrame()
@ -190,9 +192,9 @@ var _ = Describe("receivedPacketHandler", func() {
})
It("generates an ACK frame with missing packets", func() {
err := handler.ReceivedPacket(1, true)
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(4, true)
err = handler.ReceivedPacket(4, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
@ -205,11 +207,11 @@ var _ = Describe("receivedPacketHandler", func() {
})
It("generates an ACK for packet number 0 and other packets", func() {
err := handler.ReceivedPacket(0, true)
err := handler.ReceivedPacket(0, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(1, true)
err = handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(3, true)
err = handler.ReceivedPacket(3, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
@ -223,15 +225,15 @@ var _ = Describe("receivedPacketHandler", func() {
It("accepts packets below the lower limit", func() {
handler.IgnoreBelow(6)
err := handler.ReceivedPacket(2, true)
err := handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't add delayed packets to the packetHistory", func() {
handler.IgnoreBelow(7)
err := handler.ReceivedPacket(4, true)
err := handler.ReceivedPacket(4, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(10, true)
err = handler.ReceivedPacket(10, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
@ -241,7 +243,7 @@ var _ = Describe("receivedPacketHandler", func() {
It("deletes packets from the packetHistory when a lower limit is set", func() {
for i := 1; i <= 12; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
}
handler.IgnoreBelow(7)
@ -256,7 +258,7 @@ var _ = Describe("receivedPacketHandler", func() {
// TODO: remove this test when dropping support for STOP_WAITINGs
It("handles a lower limit of 0", func() {
handler.IgnoreBelow(0)
err := handler.ReceivedPacket(1337, true)
err := handler.ReceivedPacket(1337, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
@ -264,7 +266,7 @@ var _ = Describe("receivedPacketHandler", func() {
})
It("resets all counters needed for the ACK queueing decision when sending an ACK", func() {
err := handler.ReceivedPacket(1, true)
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil())
@ -275,7 +277,7 @@ var _ = Describe("receivedPacketHandler", func() {
})
It("doesn't generate an ACK when none is queued and the timer is not set", func() {
err := handler.ReceivedPacket(1, true)
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Time{}
@ -283,7 +285,7 @@ var _ = Describe("receivedPacketHandler", func() {
})
It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() {
err := handler.ReceivedPacket(1, true)
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(time.Minute)
@ -291,7 +293,7 @@ var _ = Describe("receivedPacketHandler", func() {
})
It("generates an ACK when the timer has expired", func() {
err := handler.ReceivedPacket(1, true)
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(-time.Minute)

View File

@ -40,6 +40,10 @@ type sentPacketHandler struct {
largestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
lowestPacketNotConfirmedAcked protocol.PacketNumber
packetHistory *PacketList
stopWaitingManager stopWaitingManager
@ -95,6 +99,13 @@ func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool {
}
func (h *sentPacketHandler) SetHandshakeComplete() {
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
queue = append(queue, packet)
}
}
h.retransmissionQueue = queue
h.handshakeComplete = true
}
@ -114,11 +125,19 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
h.lastSentPacketNumber = packet.PacketNumber
now := time.Now()
var largestAcked protocol.PacketNumber
if len(packet.Frames) > 0 {
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
largestAcked = ackFrame.LargestAcked
}
}
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
packet.SendTime = now
packet.sendTime = now
packet.largestAcked = largestAcked
h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet)
h.numNonRetransmittablePackets = 0
@ -134,7 +153,7 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
isRetransmittable,
)
h.updateLossDetectionAlarm()
h.updateLossDetectionAlarm(now)
return nil
}
@ -146,14 +165,12 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
// duplicate or out-of-order ACK
// if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 {
if withPacketNumber <= h.largestReceivedPacketWithAck {
utils.Debugf("ignoring ack because duplicate")
return ErrDuplicateOrOutOfOrderAck
}
h.largestReceivedPacketWithAck = withPacketNumber
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
if ackFrame.LargestAcked < h.lowestUnacked() {
utils.Debugf("ignoring ack because repeated")
return nil
}
h.largestAcked = ackFrame.LargestAcked
@ -178,13 +195,19 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
if encLevel < p.Value.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
}
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
if p.Value.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1)
}
h.onPacketAcked(p)
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
}
h.detectLostPackets()
h.updateLossDetectionAlarm()
h.detectLostPackets(rcvTime)
h.updateLossDetectionAlarm(rcvTime)
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
@ -192,6 +215,10 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
return nil
}
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestPacketNotConfirmedAcked
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
var ackedPackets []*PacketElement
ackRangeIndex := 0
@ -233,7 +260,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber == largestAcked {
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime)
return true
}
// Packets are sorted by number, so we can stop searching
@ -244,7 +271,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm() {
func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
// Cancel the alarm if no packets are outstanding
if h.packetHistory.Len() == 0 {
h.alarm = time.Time{}
@ -253,19 +280,18 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() {
// TODO(#497): TLP
if !h.handshakeComplete {
h.alarm = time.Now().Add(h.computeHandshakeTimeout())
h.alarm = now.Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO
h.alarm = time.Now().Add(h.computeRTOTimeout())
h.alarm = now.Add(h.computeRTOTimeout())
}
}
func (h *sentPacketHandler) detectLostPackets() {
func (h *sentPacketHandler) detectLostPackets(now time.Time) {
h.lossTime = time.Time{}
now := time.Now()
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
@ -278,7 +304,7 @@ func (h *sentPacketHandler) detectLostPackets() {
break
}
timeSinceSent := now.Sub(packet.SendTime)
timeSinceSent := now.Sub(packet.sendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, el)
} else if h.lossTime.IsZero() {
@ -296,20 +322,22 @@ func (h *sentPacketHandler) detectLostPackets() {
}
func (h *sentPacketHandler) OnAlarm() {
now := time.Now()
// TODO(#497): TLP
if !h.handshakeComplete {
h.queueHandshakePacketsForRetransmission()
h.handshakeCount++
} else if !h.lossTime.IsZero() {
// Early retransmit or time loss detection
h.detectLostPackets()
h.detectLostPackets(now)
} else {
// RTO
h.retransmitOldestTwoPackets()
h.rtoCount++
}
h.updateLossDetectionAlarm()
h.updateLossDetectionAlarm(now)
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
@ -345,12 +373,11 @@ func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFra
}
func (h *sentPacketHandler) SendingAllowed() bool {
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
cwnd := h.congestion.GetCongestionWindow()
congestionLimited := h.bytesInFlight > cwnd
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
if congestionLimited {
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
h.bytesInFlight,
h.congestion.GetCongestionWindow())
utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
}
// Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited

View File

@ -3,60 +3,15 @@ package ackhandler
import (
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockCongestion struct {
argsOnPacketSent []interface{}
maybeExitSlowStart bool
onRetransmissionTimeout bool
getCongestionWindow bool
packetsAcked [][]interface{}
packetsLost [][]interface{}
}
func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
panic("not implemented")
}
func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable}
return false
}
func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount {
m.getCongestionWindow = true
return protocol.DefaultTCPMSS
}
func (m *mockCongestion) MaybeExitSlowStart() {
m.maybeExitSlowStart = true
}
func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) {
m.onRetransmissionTimeout = true
}
func (m *mockCongestion) RetransmissionDelay() time.Duration {
return defaultRTOTimeout
}
func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not implemented") }
func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") }
func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") }
func (m *mockCongestion) OnPacketAcked(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
m.packetsAcked = append(m.packetsAcked, []interface{}{n, l, bif})
}
func (m *mockCongestion) OnPacketLost(n protocol.PacketNumber, l protocol.ByteCount, bif protocol.ByteCount) {
m.packetsLost = append(m.packetsLost, []interface{}{n, l, bif})
}
func retransmittablePacket(num protocol.PacketNumber) *Packet {
return &Packet{
PacketNumber: num,
@ -143,7 +98,7 @@ var _ = Describe("SentPacketHandler", func() {
packet := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1}
err := handler.SentPacket(&packet)
Expect(err).ToNot(HaveOccurred())
Expect(handler.packetHistory.Front().Value.SendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1))
Expect(handler.packetHistory.Front().Value.sendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1))
})
It("does not store non-retransmittable packets", func() {
@ -553,9 +508,9 @@ var _ = Describe("SentPacketHandler", func() {
It("computes the RTT", func() {
now := time.Now()
// First, fake the sent times of the first, second and last packet
getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute)
getPacketElement(2).Value.SendTime = now.Add(-5 * time.Minute)
getPacketElement(6).Value.SendTime = now.Add(-1 * time.Minute)
getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute)
getPacketElement(2).Value.sendTime = now.Add(-5 * time.Minute)
getPacketElement(6).Value.sendTime = now.Add(-1 * time.Minute)
// Now, check that the proper times are used when calculating the deltas
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).NotTo(HaveOccurred())
@ -570,12 +525,50 @@ var _ = Describe("SentPacketHandler", func() {
It("uses the DelayTime in the ack frame", func() {
now := time.Now()
getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute)
getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute)
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, DelayTime: 5 * time.Minute}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).NotTo(HaveOccurred())
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
})
})
Context("determinining, which ACKs we have received an ACK for", func() {
BeforeEach(func() {
morePackets := []*Packet{
&Packet{PacketNumber: 13, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 80, LargestAcked: 100}, &streamFrame}, Length: 1},
&Packet{PacketNumber: 14, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 50, LargestAcked: 200}, &streamFrame}, Length: 1},
&Packet{PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1},
}
for _, packet := range morePackets {
err := handler.SentPacket(packet)
Expect(err).NotTo(HaveOccurred())
}
})
It("determines which ACK we have received an ACK for", func() {
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 12}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
})
It("doesn't do anything when the acked packet didn't contain an ACK", func() {
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 15}, 2, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
})
It("doesn't decrease the value", func() {
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 14, LowestAcked: 14}, 1, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 2, protocol.EncryptionUnencrypted, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
})
})
})
Context("Retransmission handling", func() {
@ -583,13 +576,13 @@ var _ = Describe("SentPacketHandler", func() {
BeforeEach(func() {
packets = []*Packet{
{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1},
{PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1},
{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1},
{PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1},
{PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1},
{PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1},
{PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1},
{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted},
{PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted},
{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted},
{PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure},
{PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure},
{PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure},
{PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure},
}
for _, packet := range packets {
handler.SentPacket(packet)
@ -597,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() {
// Increase RTT, because the tests would be flaky otherwise
handler.rttStats.UpdateRTT(time.Minute, 0, time.Now())
// Ack a single packet so that we have non-RTO timings
handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionUnencrypted, time.Now())
handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now())
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6)))
})
@ -606,7 +599,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("dequeues a packet for retransmission", func() {
getPacketElement(1).Value.SendTime = time.Now().Add(-time.Hour)
getPacketElement(1).Value.sendTime = time.Now().Add(-time.Hour)
handler.OnAlarm()
Expect(getPacketElement(1)).To(BeNil())
Expect(handler.retransmissionQueue).To(HaveLen(1))
@ -617,15 +610,33 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
})
Context("StopWaitings", func() {
It("gets a StopWaitingFrame", func() {
It("deletes non forward-secure packets when the handshake completes", func() {
for i := protocol.PacketNumber(1); i <= 7; i++ {
if i == 2 { // packet 2 was already acked in BeforeEach
continue
}
handler.queuePacketForRetransmission(getPacketElement(i))
}
Expect(handler.retransmissionQueue).To(HaveLen(6))
handler.SetHandshakeComplete()
packet := handler.DequeuePacketForRetransmission()
Expect(packet).ToNot(BeNil())
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(6)))
packet = handler.DequeuePacketForRetransmission()
Expect(packet).ToNot(BeNil())
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(7)))
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
})
Context("STOP_WAITINGs", func() {
It("gets a STOP_WAITING frame", func() {
ack := wire.AckFrame{LargestAcked: 5, LowestAcked: 5}
err := handler.ReceivedAck(&ack, 2, protocol.EncryptionUnencrypted, time.Now())
err := handler.ReceivedAck(&ack, 2, protocol.EncryptionForwardSecure, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6}))
})
It("gets a StopWaitingFrame after queueing a retransmission", func() {
It("gets a STOP_WAITING frame after queueing a retransmission", func() {
handler.queuePacketForRetransmission(getPacketElement(5))
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6}))
})
@ -662,7 +673,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(err).NotTo(HaveOccurred())
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
handler.packetHistory.Front().Value.SendTime = time.Now().Add(-time.Hour)
handler.packetHistory.Front().Value.sendTime = time.Now().Add(-time.Hour)
handler.OnAlarm()
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(0)))
@ -670,15 +681,23 @@ var _ = Describe("SentPacketHandler", func() {
Context("congestion", func() {
var (
cong *mockCongestion
cong *mocks.MockSendAlgorithm
)
BeforeEach(func() {
cong = &mockCongestion{}
cong = mocks.NewMockSendAlgorithm(mockCtrl)
cong.EXPECT().RetransmissionDelay().AnyTimes()
handler.congestion = cong
})
It("should call OnSent", func() {
cong.EXPECT().OnPacketSent(
gomock.Any(),
protocol.ByteCount(42),
protocol.PacketNumber(1),
protocol.ByteCount(42),
true,
)
p := &Packet{
PacketNumber: 1,
Length: 42,
@ -686,62 +705,60 @@ var _ = Describe("SentPacketHandler", func() {
}
err := handler.SentPacket(p)
Expect(err).NotTo(HaveOccurred())
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(42)))
Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1)))
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(42)))
Expect(cong.argsOnPacketSent[4]).To(BeTrue())
})
It("should call MaybeExitSlowStart and OnPacketAcked", func() {
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
cong.EXPECT().MaybeExitSlowStart()
cong.EXPECT().OnPacketAcked(
protocol.PacketNumber(1),
protocol.ByteCount(1),
protocol.ByteCount(1),
)
handler.SentPacket(retransmittablePacket(1))
handler.SentPacket(retransmittablePacket(2))
err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionForwardSecure, time.Now())
Expect(err).NotTo(HaveOccurred())
Expect(cong.maybeExitSlowStart).To(BeTrue())
Expect(cong.packetsAcked).To(BeEquivalentTo([][]interface{}{
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(1)},
}))
Expect(cong.packetsLost).To(BeEmpty())
})
It("should call MaybeExitSlowStart and OnPacketLost", func() {
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
cong.EXPECT().OnRetransmissionTimeout(true).Times(2)
cong.EXPECT().OnPacketLost(
protocol.PacketNumber(1),
protocol.ByteCount(1),
protocol.ByteCount(2),
)
cong.EXPECT().OnPacketLost(
protocol.PacketNumber(2),
protocol.ByteCount(1),
protocol.ByteCount(1),
)
handler.SentPacket(retransmittablePacket(1))
handler.SentPacket(retransmittablePacket(2))
handler.SentPacket(retransmittablePacket(3))
handler.OnAlarm() // RTO, meaning 2 lost packets
Expect(cong.maybeExitSlowStart).To(BeFalse())
Expect(cong.onRetransmissionTimeout).To(BeTrue())
Expect(cong.packetsAcked).To(BeEmpty())
Expect(cong.packetsLost).To(BeEquivalentTo([][]interface{}{
{protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)},
{protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(1)},
}))
})
It("allows or denies sending based on congestion", func() {
Expect(handler.retransmissionQueue).To(BeEmpty())
handler.bytesInFlight = 100
cong.EXPECT().GetCongestionWindow().Return(protocol.MaxByteCount)
Expect(handler.SendingAllowed()).To(BeTrue())
err := handler.SentPacket(&Packet{
PacketNumber: 1,
Frames: []wire.Frame{&wire.PingFrame{}},
Length: protocol.DefaultTCPMSS + 1,
})
Expect(err).NotTo(HaveOccurred())
cong.EXPECT().GetCongestionWindow().Return(protocol.ByteCount(0))
Expect(handler.SendingAllowed()).To(BeFalse())
})
It("allows or denies sending based on the number of tracked packets", func() {
cong.EXPECT().GetCongestionWindow().Return(protocol.MaxByteCount).AnyTimes()
Expect(handler.SendingAllowed()).To(BeTrue())
handler.retransmissionQueue = make([]*Packet, protocol.MaxTrackedSentPackets)
Expect(handler.SendingAllowed()).To(BeFalse())
})
It("allows sending if there are retransmisisons outstanding", func() {
err := handler.SentPacket(&Packet{
PacketNumber: 1,
Frames: []wire.Frame{&wire.PingFrame{}},
Length: protocol.DefaultTCPMSS + 1,
})
Expect(err).NotTo(HaveOccurred())
handler.bytesInFlight = 100
cong.EXPECT().GetCongestionWindow().Return(protocol.ByteCount(0)).AnyTimes()
Expect(handler.SendingAllowed()).To(BeFalse())
handler.retransmissionQueue = []*Packet{nil}
Expect(handler.SendingAllowed()).To(BeTrue())
@ -799,7 +816,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.lossTime.Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute))
Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute))
handler.packetHistory.Front().Value.SendTime = time.Now().Add(-2 * time.Hour)
handler.packetHistory.Front().Value.sendTime = time.Now().Add(-2 * time.Hour)
handler.OnAlarm()
Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil())
})
@ -843,7 +860,7 @@ var _ = Describe("SentPacketHandler", func() {
err = handler.SentPacket(handshakePacket(4))
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now().Add(time.Hour))
err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now())
Expect(err).NotTo(HaveOccurred())
Expect(handler.lossTime.IsZero()).To(BeTrue())
handshakeTimeout := handler.computeHandshakeTimeout()

View File

@ -60,33 +60,15 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
return Dial(udpConn, udpAddr, addr, tlsConf, config)
}
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func DialNonFWSecure(
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
) (Session, error) {
connID, err := generateConnectionID()
if err != nil {
return nil, err
@ -115,31 +97,11 @@ func DialNonFWSecure(
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
go c.listen()
if err := c.dial(); err != nil {
return nil, err
}
return c.session.(NonFWSession), nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
if err != nil {
return nil, err
}
if err := sess.WaitUntilHandshakeComplete(); err != nil {
return nil, err
}
return sess, nil
return c.session, nil
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
@ -199,6 +161,7 @@ func (c *client) dialGQUIC() error {
if err := c.createNewGQUICSession(); err != nil {
return err
}
go c.listen()
return c.establishSecureConnection()
}
@ -224,6 +187,7 @@ func (c *client) dialTLS() error {
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
return err
}
go c.listen()
if err := c.establishSecureConnection(); err != nil {
if err != handshake.ErrCloseSessionForRetry {
return err
@ -267,14 +231,8 @@ func (c *client) establishSecureConnection() error {
select {
case <-errorChan:
return runErr
case ev := <-c.session.handshakeStatus():
if ev.err != nil {
return ev.err
}
if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure {
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
}
return nil
case err := <-c.session.handshakeStatus():
return err
}
}

View File

@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"net"
"os"
"sync/atomic"
"time"
@ -100,57 +101,7 @@ var _ = Describe("Client", func() {
generateConnectionID = origGenerateConnectionID
})
It("dials non-forward-secure", func() {
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(dialed).Should(BeClosed())
})
It("dials a non-forward-secure address", func() {
serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
server, err := net.ListenUDP("udp", serverAddr)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
for {
_, clientAddr, err := server.ReadFromUDP(make([]byte, 200))
if err != nil {
return
}
_, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr)
Expect(err).ToNot(HaveOccurred())
}
}()
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(dialed).Should(BeClosed())
server.Close()
Eventually(done).Should(BeClosed())
})
It("Dial only returns after the handshake is complete", func() {
It("returns after the handshake is complete", func() {
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{})
go func() {
@ -160,13 +111,14 @@ var _ = Describe("Client", func() {
Expect(s).ToNot(BeNil())
close(dialed)
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Consistently(dialed).ShouldNot(BeClosed())
close(sess.handshakeComplete)
close(sess.handshakeChan)
Eventually(dialed).Should(BeClosed())
})
It("resolves the address", func() {
if os.Getenv("APPVEYOR") == "True" {
Skip("This test is flaky on AppVeyor.")
}
closeErr := errors.New("peer doesn't reply")
remoteAddrChan := make(chan string)
newClientSession = func(
@ -245,22 +197,7 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(testErr))
close(done)
}()
sess.handshakeChan <- handshakeEvent{err: testErr}
Eventually(done).Should(BeClosed())
})
It("returns an error that occurs while waiting for the handshake to complete", func() {
testErr := errors.New("late handshake error")
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr))
close(done)
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
sess.handshakeComplete <- testErr
sess.handshakeChan <- testErr
Eventually(done).Should(BeClosed())
})
@ -305,7 +242,7 @@ var _ = Describe("Client", func() {
) (packetHandler, error) {
return nil, testErr
}
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr))
})
@ -331,7 +268,7 @@ var _ = Describe("Client", func() {
Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion))
sessionChan := make(chan *mockSession)
handshakeChan := make(chan handshakeEvent)
handshakeChan := make(chan error)
newClientSession = func(
_ connection,
_ string,
@ -382,7 +319,7 @@ var _ = Describe("Client", func() {
Expect(negotiatedVersions).To(ContainElement(newVersion))
Expect(initialVersion).To(Equal(actualInitialVersion))
handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
close(handshakeChan)
Eventually(established).Should(BeClosed())
})

View File

@ -0,0 +1,41 @@
package quic
import (
"io"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStreamI interface {
StreamID() protocol.StreamID
io.Reader
io.Writer
handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
setReadOffset(protocol.ByteCount)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStream struct {
*stream
}
var _ cryptoStreamI = &cryptoStream{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStream{str}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPosition = offset
}

View File

@ -0,0 +1,26 @@
package quic
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Crypto Stream", func() {
var (
str *cryptoStream
mockSender *MockStreamSender
)
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStream)
})
It("sets the read offset", func() {
str.setReadOffset(0x42)
Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42)))
Expect(str.receiveStream.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42)))
})
})

View File

@ -97,45 +97,44 @@ func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream)
var lastStream protocol.StreamID
var err error
for err == nil {
err = c.readResponse(h2framer, decoder)
}
utils.Debugf("Error handling header stream: %s", err)
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request
close(c.headerErrored)
}
for {
frame, err := h2framer.ReadFrame()
if err != nil {
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
break
}
lastStream = protocol.StreamID(frame.Header().StreamID)
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
break
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
break
}
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
break
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
}
responseChan <- rsp
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
frame, err := h2framer.ReadFrame()
if err != nil {
return err
}
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
return errors.New("not a headers frame")
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
return fmt.Errorf("cannot read header fields: %s", err.Error())
}
// stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
close(c.headerErrored)
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
return err
}
responseChan <- rsp
return nil
}
// Roundtrip executes a request and returns a response

View File

@ -188,41 +188,31 @@ var _ = Describe("Client", func() {
close(done)
})
It("closes the quic client when encountering an error on the header stream", func(done Done) {
It("closes the quic client when encountering an error on the header stream", func() {
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
var doReturned bool
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(client.headerErr))
Expect(rsp).To(BeNil())
doReturned = true
close(done)
}()
Eventually(func() bool { return doReturned }).Should(BeTrue())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
Eventually(done).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
close(done)
}, 2)
})
It("returns subsequent request if there was an error on the header stream before", func(done Done) {
expectedErr := qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
It("returns subsequent request if there was an error on the header stream before", func() {
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
var firstReqReturned bool
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request)
Expect(err).To(MatchError(expectedErr))
firstReqReturned = true
}()
Eventually(func() bool { return firstReqReturned }).Should(BeTrue())
// now that the first request failed due to an error on the header stream, try another request
_, err := client.RoundTrip(request)
Expect(err).To(MatchError(expectedErr))
close(done)
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
// now that the first request failed due to an error on the header stream, try another request
_, nextErr := client.RoundTrip(request)
Expect(nextErr).To(MatchError(err))
})
It("blocks if no stream is available", func() {
@ -479,16 +469,9 @@ var _ = Describe("Client", func() {
It("errors if the H2 frame is not a HeadersFrame", func() {
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
var handlerReturned bool
go func() {
client.handleHeaderStream()
handlerReturned = true
}()
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
})
It("errors if it can't read the HPACK encoded header fields", func() {
@ -497,16 +480,26 @@ var _ = Describe("Client", func() {
EndHeaders: true,
BlockFragment: []byte("invalid HPACK data"),
})
var handlerReturned bool
go func() {
client.handleHeaderStream()
handlerReturned = true
}()
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")))
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
})
It("errors if the stream cannot be found", func() {
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1337,
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
Expect(err).ToNot(HaveOccurred())
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
})
})
})

View File

@ -11,6 +11,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -29,6 +30,8 @@ type mockStream struct {
ctxCancel context.CancelFunc
}
var _ quic.Stream = &mockStream{}
func newMockStream(id protocol.StreamID) *mockStream {
s := &mockStream{
id: id,
@ -39,7 +42,8 @@ func newMockStream(id protocol.StreamID) *mockStream {
}
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
func (s *mockStream) Reset(error) { s.reset = true }
func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil }
func (s *mockStream) CancelWrite(quic.ErrorCode) error { panic("not implemented") }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
func (s mockStream) StreamID() protocol.StreamID { return s.id }
func (s *mockStream) Context() context.Context { return s.ctx }

View File

@ -50,6 +50,7 @@ type Server struct {
listenerMutex sync.Mutex
listener quic.Listener
closed bool
supportedVersionsAsString string
}
@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
return errors.New("use of h2quic.Server without http.Server")
}
s.listenerMutex.Lock()
if s.closed {
s.listenerMutex.Unlock()
return errors.New("Server is already closed")
}
if s.listener != nil {
s.listenerMutex.Unlock()
return errors.New("ListenAndServe may only be called once")
@ -223,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
}
if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead {
responseWriter.dataStream.Reset(nil)
// in gQUIC, the error code doesn't matter, so just use 0 here
responseWriter.dataStream.CancelRead(0)
}
responseWriter.dataStream.Close()
}
@ -241,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
func (s *Server) Close() error {
s.listenerMutex.Lock()
defer s.listenerMutex.Unlock()
s.closed = true
if s.listener != nil {
err := s.listener.Close()
s.listener = nil

View File

@ -70,6 +70,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
func (s *mockSession) Context() context.Context {
return s.ctx
}
func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") }
var _ = Describe("H2 server", func() {
var (
@ -410,6 +411,13 @@ var _ = Describe("H2 server", func() {
Expect(err).NotTo(HaveOccurred())
})
It("errors when ListenAndServer is called after Close", func() {
serv := &Server{Server: &http.Server{}}
Expect(serv.Close()).To(Succeed())
err := serv.ListenAndServe()
Expect(err).To(MatchError("Server is already closed"))
})
Context("ListenAndServe", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"

View File

@ -19,20 +19,42 @@ type VersionNumber = protocol.VersionNumber
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
// An ErrorCode is an application-defined error code.
type ErrorCode = protocol.ApplicationErrorCode
// Stream is the interface implemented by QUIC streams
type Stream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
io.Closer
StreamID() StreamID
// Reset closes the stream with an error.
Reset(error)
// CancelWrite aborts sending on this stream.
// It must not be called after Close.
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
CancelWrite(ErrorCode) error
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
CancelRead(ErrorCode) error
// The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon.
@ -53,6 +75,41 @@ type Stream interface {
SetDeadline(t time.Time) error
}
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Read
io.Reader
// see Stream.CancelRead
CancelRead(ErrorCode) error
// see Stream.SetReadDealine
SetReadDeadline(t time.Time) error
}
// A SendStream is a unidirectional Send Stream.
type SendStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Write
io.Writer
// see Stream.Close
io.Closer
// see Stream.CancelWrite
CancelWrite(ErrorCode) error
// see Stream.Context
Context() context.Context
// see Stream.SetWriteDeadline
SetWriteDeadline(t time.Time) error
}
// StreamError is returned by Read and Write when the peer cancels the stream.
type StreamError interface {
error
Canceled() bool
ErrorCode() ErrorCode
}
// A Session is a QUIC connection between two peers.
type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
@ -74,13 +131,9 @@ type Session interface {
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
}
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
// The communication is encrypted, but not yet forward secure.
type NonFWSession interface {
Session
WaitUntilHandshakeComplete() error
// ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState
}
// Config contains all configuration data needed for a QUIC server or client.

View File

@ -18,6 +18,7 @@ type CertManager interface {
GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error
GetChain() []*x509.Certificate
}
type certManager struct {
@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error {
return nil
}
func (c *certManager) GetChain() []*x509.Certificate {
return c.chain
}
func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes()
}

View File

@ -10,35 +10,30 @@ import (
)
type baseFlowController struct {
mutex sync.RWMutex
rttStats *congestion.RTTStats
// for sending data
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
lastWindowUpdateTime time.Time
// for receiving data
mutex sync.RWMutex
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount
receiveWindowSize protocol.ByteCount
maxReceiveWindowSize protocol.ByteCount
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount
receiveWindowIncrement protocol.ByteCount
maxReceiveWindowIncrement protocol.ByteCount
epochStartTime time.Time
epochStartOffset protocol.ByteCount
rttStats *congestion.RTTStats
}
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.bytesSent += n
}
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
c.mutex.Lock()
defer c.mutex.Unlock()
if offset > c.sendWindow {
c.sendWindow = offset
}
@ -57,52 +52,55 @@ func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
defer c.mutex.Unlock()
// pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window increment already works for the first WindowUpdate
// this way auto-tuning of the window size already works for the first WindowUpdate
if c.bytesRead == 0 {
c.lastWindowUpdateTime = time.Now()
c.startNewAutoTuningEpoch()
}
c.bytesRead += n
}
func (c *baseFlowController) hasWindowUpdate() bool {
bytesRemaining := c.receiveWindow - c.bytesRead
// update the window when more than the threshold was consumed
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
}
// getWindowUpdate updates the receive window, if necessary
// it returns the new offset
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
diff := c.receiveWindow - c.bytesRead
// update the window when more than half of it was already consumed
if diff >= (c.receiveWindowIncrement / 2) {
if !c.hasWindowUpdate() {
return 0
}
c.maybeAdjustWindowIncrement()
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
c.lastWindowUpdateTime = time.Now()
c.maybeAdjustWindowSize()
c.receiveWindow = c.bytesRead + c.receiveWindowSize
return c.receiveWindow
}
func (c *baseFlowController) IsBlocked() bool {
c.mutex.RLock()
defer c.mutex.RUnlock()
return c.sendWindowSize() == 0
}
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
func (c *baseFlowController) maybeAdjustWindowIncrement() {
if c.lastWindowUpdateTime.IsZero() {
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
func (c *baseFlowController) maybeAdjustWindowSize() {
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
// don't do anything if less than half the window has been consumed
if bytesReadInEpoch <= c.receiveWindowSize/2 {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
// interval between the window updates is sufficiently large, no need to increase the increment
if timeSinceLastWindowUpdate >= 2*rtt {
return
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
// window is consumed too fast, try to increase the window size
c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
}
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
c.startNewAutoTuningEpoch()
}
func (c *baseFlowController) startNewAutoTuningEpoch() {
c.epochStartTime = time.Now()
c.epochStartOffset = c.bytesRead
}
func (c *baseFlowController) checkFlowControlViolation() bool {

View File

@ -1,6 +1,8 @@
package flowcontrol
import (
"os"
"strconv"
"time"
"github.com/lucas-clemente/quic-go/congestion"
@ -9,6 +11,16 @@ import (
. "github.com/onsi/gomega"
)
// on the CIs, the timing is a lot less precise, so scale every duration by this factor
func scaleDuration(t time.Duration) time.Duration {
scaleFactor := 1
if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
scaleFactor = f
}
Expect(scaleFactor).ToNot(BeZero())
return time.Duration(scaleFactor) * t
}
var _ = Describe("Base Flow controller", func() {
var controller *baseFlowController
@ -49,22 +61,18 @@ var _ = Describe("Base Flow controller", func() {
controller.UpdateSendWindow(10)
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20)))
})
It("says when it's blocked", func() {
controller.UpdateSendWindow(100)
Expect(controller.IsBlocked()).To(BeFalse())
controller.AddBytesSent(100)
Expect(controller.IsBlocked()).To(BeTrue())
})
})
Context("receive flow control", func() {
var receiveWindow protocol.ByteCount = 10000
var receiveWindowIncrement protocol.ByteCount = 600
var (
receiveWindow protocol.ByteCount = 10000
receiveWindowSize protocol.ByteCount = 1000
)
BeforeEach(func() {
controller.bytesRead = receiveWindow - receiveWindowSize
controller.receiveWindow = receiveWindow
controller.receiveWindowIncrement = receiveWindowIncrement
controller.receiveWindowSize = receiveWindowSize
})
It("adds bytes read", func() {
@ -74,31 +82,30 @@ var _ = Describe("Base Flow controller", func() {
})
It("triggers a window update when necessary", func() {
controller.lastWindowUpdateTime = time.Now().Add(-time.Hour)
readPosition := receiveWindow - receiveWindowIncrement/2 + 1
bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold + 1 // consumed 1 byte more than the threshold
bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
readPosition := receiveWindow - bytesRemaining
controller.bytesRead = readPosition
offset := controller.getWindowUpdate()
Expect(offset).To(Equal(readPosition + receiveWindowIncrement))
Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowIncrement))
Expect(controller.lastWindowUpdateTime).To(BeTemporally("~", time.Now(), 20*time.Millisecond))
Expect(offset).To(Equal(readPosition + receiveWindowSize))
Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowSize))
})
It("doesn't trigger a window update when not necessary", func() {
lastWindowUpdateTime := time.Now().Add(-time.Hour)
controller.lastWindowUpdateTime = lastWindowUpdateTime
readPosition := receiveWindow - receiveWindow/2 - 1
bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold - 1 // consumed 1 byte less than the threshold
bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
readPosition := receiveWindow - bytesRemaining
controller.bytesRead = readPosition
offset := controller.getWindowUpdate()
Expect(offset).To(BeZero())
Expect(controller.lastWindowUpdateTime).To(Equal(lastWindowUpdateTime))
})
Context("receive window increment auto-tuning", func() {
var oldIncrement protocol.ByteCount
Context("receive window size auto-tuning", func() {
var oldWindowSize protocol.ByteCount
BeforeEach(func() {
oldIncrement = controller.receiveWindowIncrement
controller.maxReceiveWindowIncrement = 3000
oldWindowSize = controller.receiveWindowSize
controller.maxReceiveWindowSize = 5000
})
// update the congestion such that it returns a given value for the smoothed RTT
@ -107,72 +114,98 @@ var _ = Describe("Base Flow controller", func() {
Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked
}
It("doesn't increase the increment for a new stream", func() {
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
It("doesn't increase the window size for a new stream", func() {
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
})
It("doesn't increase the increment when no RTT estimate is available", func() {
It("doesn't increase the window size when no RTT estimate is available", func() {
setRtt(0)
controller.lastWindowUpdateTime = time.Now()
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
controller.startNewAutoTuningEpoch()
controller.AddBytesRead(400)
offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero()) // make sure a window update is sent
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
})
It("increases the increment when the last WindowUpdate was sent less than two RTTs ago", func() {
setRtt(20 * time.Millisecond)
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement))
})
It("doesn't increase the increase increment when the last WindowUpdate was sent more than two RTTs ago", func() {
setRtt(20 * time.Millisecond)
controller.lastWindowUpdateTime = time.Now().Add(-45 * time.Millisecond)
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
})
It("doesn't increase the increment to a value higher than the maxReceiveWindowIncrement", func() {
setRtt(20 * time.Millisecond)
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) // 1200
// because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowIncrement() multiple times and get an increase of the increment every time
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(2 * 2 * oldIncrement)) // 2400
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(controller.maxReceiveWindowIncrement)) // 3000
controller.maybeAdjustWindowIncrement()
Expect(controller.receiveWindowIncrement).To(Equal(controller.maxReceiveWindowIncrement)) // 3000
})
It("returns the new increment when updating the window", func() {
setRtt(20 * time.Millisecond)
controller.AddBytesRead(9900) // receive window is 10000
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
It("increases the window size if read so fast that the window would be consumed in less than 4 RTTs", func() {
bytesRead := controller.bytesRead
rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
// consume more than 2/3 of the window...
dataRead := receiveWindowSize*2/3 + 1
// ... in 4*2/3 of the RTT
controller.epochStartOffset = controller.bytesRead
controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
controller.AddBytesRead(dataRead)
offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero())
newIncrement := controller.receiveWindowIncrement
Expect(newIncrement).To(Equal(2 * oldIncrement))
Expect(offset).To(Equal(protocol.ByteCount(9900 + newIncrement)))
// check that the window size was increased
newWindowSize := controller.receiveWindowSize
Expect(newWindowSize).To(Equal(2 * oldWindowSize))
// check that the new window size was used to increase the offset
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
})
It("increases the increment sent in the first WindowUpdate, if data is read fast enough", func() {
setRtt(20 * time.Millisecond)
controller.AddBytesRead(9900)
It("doesn't increase the window size if data is read so fast that the window would be consumed in less than 4 RTTs, but less than half the window has been read", func() {
// this test only makes sense if a window update is triggered before half of the window has been consumed
Expect(protocol.WindowUpdateThreshold).To(BeNumerically(">", 1/3))
bytesRead := controller.bytesRead
rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
// consume more than 2/3 of the window...
dataRead := receiveWindowSize*1/3 + 1
// ... in 4*2/3 of the RTT
controller.epochStartOffset = controller.bytesRead
controller.epochStartTime = time.Now().Add(-rtt * 4 * 1 / 3)
controller.AddBytesRead(dataRead)
offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero())
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement))
// check that the window size was not increased
newWindowSize := controller.receiveWindowSize
Expect(newWindowSize).To(Equal(oldWindowSize))
// check that the new window size was used to increase the offset
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
})
It("doesn't increamse the increment sent in the first WindowUpdate, if data is read slowly", func() {
setRtt(5 * time.Millisecond)
controller.AddBytesRead(9900)
time.Sleep(15 * time.Millisecond) // more than 2x RTT
It("doesn't increase the window size if read too slowly", func() {
bytesRead := controller.bytesRead
rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
// consume less than 2/3 of the window...
dataRead := receiveWindowSize*2/3 - 1
// ... in 4*2/3 of the RTT
controller.epochStartOffset = controller.bytesRead
controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
controller.AddBytesRead(dataRead)
offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero())
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
// check that the window size was not increased
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
// check that the new window size was used to increase the offset
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + oldWindowSize)))
})
It("doesn't increase the window size to a value higher than the maxReceiveWindowSize", func() {
resetEpoch := func() {
// make sure the next call to maybeAdjustWindowSize will increase the window
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.epochStartOffset = controller.bytesRead
controller.AddBytesRead(controller.receiveWindowSize/2 + 1)
}
setRtt(scaleDuration(20 * time.Millisecond))
resetEpoch()
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) // 2000
// because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowSize() multiple times and get an increase of the window size every time
resetEpoch()
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(2 * 2 * oldWindowSize)) // 4000
resetEpoch()
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
})
})
})

View File

@ -2,7 +2,6 @@ package flowcontrol
import (
"fmt"
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -11,6 +10,7 @@ import (
)
type connectionFlowController struct {
lastBlockedAt protocol.ByteCount
baseFlowController
}
@ -25,21 +25,29 @@ func NewConnectionFlowController(
) ConnectionFlowController {
return &connectionFlowController{
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowIncrement: receiveWindow,
maxReceiveWindowIncrement: maxReceiveWindow,
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
},
}
}
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
c.mutex.RLock()
defer c.mutex.RUnlock()
return c.baseFlowController.sendWindowSize()
}
// IsNewlyBlocked says if it is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
// IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
c.mutex.Lock()
@ -54,24 +62,22 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock()
defer c.mutex.Unlock()
oldWindowIncrement := c.receiveWindowIncrement
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if oldWindowIncrement < c.receiveWindowIncrement {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10))
if oldWindowSize < c.receiveWindowSize {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
c.mutex.Unlock()
return offset
}
// EnsureMinimumWindowIncrement sets a minimum window increment
// EnsureMinimumWindowSize sets a minimum window size
// it should make sure that the connection-level window is increased when a stream-level window grows
func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
c.mutex.Lock()
defer c.mutex.Unlock()
if inc > c.receiveWindowIncrement {
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
if inc > c.receiveWindowSize {
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
c.startNewAutoTuningEpoch()
}
c.mutex.Unlock()
}

View File

@ -32,12 +32,12 @@ var _ = Describe("Connection Flow controller", func() {
fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, rttStats).(*connectionFlowController)
Expect(fc.receiveWindow).To(Equal(receiveWindow))
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow))
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
})
})
Context("receive flow control", func() {
It("increases the highestReceived by a given increment", func() {
It("increases the highestReceived by a given window size", func() {
controller.highestReceived = 1337
controller.IncrementHighestReceived(123)
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123)))
@ -46,64 +46,96 @@ var _ = Describe("Connection Flow controller", func() {
Context("getting window updates", func() {
BeforeEach(func() {
controller.receiveWindow = 100
controller.receiveWindowIncrement = 60
controller.maxReceiveWindowIncrement = 1000
controller.receiveWindowSize = 60
controller.maxReceiveWindowSize = 1000
controller.bytesRead = 100 - 60
})
It("gets a window update", func() {
controller.AddBytesRead(80)
windowSize := controller.receiveWindowSize
oldOffset := controller.bytesRead
dataRead := windowSize/2 - 1 // make sure not to trigger auto-tuning
controller.AddBytesRead(dataRead)
offset := controller.GetWindowUpdate()
Expect(offset).To(Equal(protocol.ByteCount(80 + 60)))
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + 60)))
})
It("autotunes the window", func() {
controller.AddBytesRead(80)
setRtt(20 * time.Millisecond)
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
oldOffset := controller.bytesRead
oldWindowSize := controller.receiveWindowSize
rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.epochStartOffset = oldOffset
dataRead := oldWindowSize/2 + 1
controller.AddBytesRead(dataRead)
offset := controller.GetWindowUpdate()
Expect(offset).To(Equal(protocol.ByteCount(80 + 2*60)))
newWindowSize := controller.receiveWindowSize
Expect(newWindowSize).To(Equal(2 * oldWindowSize))
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + newWindowSize)))
})
})
})
Context("setting the minimum increment", func() {
Context("send flow control", func() {
It("says when it's blocked", func() {
controller.UpdateSendWindow(100)
Expect(controller.IsNewlyBlocked()).To(BeFalse())
controller.AddBytesSent(100)
blocked, offset := controller.IsNewlyBlocked()
Expect(blocked).To(BeTrue())
Expect(offset).To(Equal(protocol.ByteCount(100)))
})
It("doesn't say that it's newly blocked multiple times for the same offset", func() {
controller.UpdateSendWindow(100)
controller.AddBytesSent(100)
newlyBlocked, offset := controller.IsNewlyBlocked()
Expect(newlyBlocked).To(BeTrue())
Expect(offset).To(Equal(protocol.ByteCount(100)))
newlyBlocked, _ = controller.IsNewlyBlocked()
Expect(newlyBlocked).To(BeFalse())
controller.UpdateSendWindow(150)
controller.AddBytesSent(150)
newlyBlocked, offset = controller.IsNewlyBlocked()
Expect(newlyBlocked).To(BeTrue())
})
})
Context("setting the minimum window size", func() {
var (
oldIncrement protocol.ByteCount
receiveWindow protocol.ByteCount = 10000
receiveWindowIncrement protocol.ByteCount = 600
oldWindowSize protocol.ByteCount
receiveWindow protocol.ByteCount = 10000
receiveWindowSize protocol.ByteCount = 1000
)
BeforeEach(func() {
controller.bytesRead = receiveWindowSize - receiveWindowSize
controller.receiveWindow = receiveWindow
controller.receiveWindowIncrement = receiveWindowIncrement
oldIncrement = controller.receiveWindowIncrement
controller.maxReceiveWindowIncrement = 3000
controller.receiveWindowSize = receiveWindowSize
oldWindowSize = controller.receiveWindowSize
controller.maxReceiveWindowSize = 3000
})
It("sets the minimum window increment", func() {
controller.EnsureMinimumWindowIncrement(1000)
Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(1000)))
It("sets the minimum window window size", func() {
controller.EnsureMinimumWindowSize(1800)
Expect(controller.receiveWindowSize).To(Equal(protocol.ByteCount(1800)))
})
It("doesn't reduce the window increment", func() {
controller.EnsureMinimumWindowIncrement(1)
Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement))
It("doesn't reduce the window window size", func() {
controller.EnsureMinimumWindowSize(1)
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
})
It("doens't increase the increment beyond the maxReceiveWindowIncrement", func() {
max := controller.maxReceiveWindowIncrement
controller.EnsureMinimumWindowIncrement(2 * max)
Expect(controller.receiveWindowIncrement).To(Equal(max))
It("doens't increase the window size beyond the maxReceiveWindowSize", func() {
max := controller.maxReceiveWindowSize
controller.EnsureMinimumWindowSize(2 * max)
Expect(controller.receiveWindowSize).To(Equal(max))
})
It("doesn't auto-tune the window after the increment was increased", func() {
setRtt(20 * time.Millisecond)
controller.bytesRead = 9900 // receive window is 10000
controller.lastWindowUpdateTime = time.Now().Add(-20 * time.Millisecond)
controller.EnsureMinimumWindowIncrement(912)
offset := controller.getWindowUpdate()
Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(912))) // no auto-tuning
Expect(offset).To(Equal(protocol.ByteCount(9900 + 912)))
It("starts a new epoch after the window size was increased", func() {
controller.EnsureMinimumWindowSize(1912)
Expect(controller.epochStartTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
})
})
})

View File

@ -5,7 +5,6 @@ import "github.com/lucas-clemente/quic-go/internal/protocol"
type flowController interface {
// for sending
SendWindowSize() protocol.ByteCount
IsBlocked() bool
UpdateSendWindow(protocol.ByteCount)
AddBytesSent(protocol.ByteCount)
// for receiving
@ -16,22 +15,28 @@ type flowController interface {
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
// for sending
IsBlocked() (bool, protocol.ByteCount)
// for receiving
// UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
// HasWindowUpdate says if it is necessary to update the window
HasWindowUpdate() bool
}
// The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface {
flowController
// for sending
IsNewlyBlocked() (bool, protocol.ByteCount)
}
type connectionFlowControllerI interface {
ConnectionFlowController
// The following two methods are not supposed to be called from outside this packet, but are needed internally
// for sending
EnsureMinimumWindowIncrement(protocol.ByteCount)
EnsureMinimumWindowSize(protocol.ByteCount)
// for receiving
IncrementHighestReceived(protocol.ByteCount) error
}

View File

@ -37,11 +37,11 @@ func NewStreamFlowController(
contributesToConnection: contributesToConnection,
connection: cfc.(connectionFlowControllerI),
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowIncrement: receiveWindow,
maxReceiveWindowIncrement: maxReceiveWindow,
sendWindow: initialSendWindow,
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow,
},
}
}
@ -102,9 +102,6 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
}
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
c.mutex.Lock()
defer c.mutex.Unlock()
window := c.baseFlowController.sendWindowSize()
if c.contributesToConnection {
window = utils.MinByteCount(window, c.connection.SendWindowSize())
@ -112,22 +109,39 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
return window
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock()
defer c.mutex.Unlock()
// IsBlocked says if it is blocked by stream-level flow control.
// If it is blocked, the offset is returned.
func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 {
return false, 0
}
return true, c.sendWindow
}
func (c *streamFlowController) HasWindowUpdate() bool {
c.mutex.Lock()
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
c.mutex.Unlock()
return hasWindowUpdate
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
c.mutex.Lock()
// if we already received the final offset for this stream, the peer won't need any additional flow control credit
if c.receivedFinalOffset {
c.mutex.Unlock()
return 0
}
oldWindowIncrement := c.receiveWindowIncrement
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10))
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
if c.contributesToConnection {
c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier))
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
}
c.mutex.Unlock()
return offset
}

View File

@ -19,7 +19,7 @@ var _ = Describe("Stream Flow controller", func() {
streamID: 10,
connection: NewConnectionFlowController(1000, 1000, rttStats).(*connectionFlowController),
}
controller.maxReceiveWindowIncrement = 10000
controller.maxReceiveWindowSize = 10000
controller.rttStats = rttStats
})
@ -35,7 +35,7 @@ var _ = Describe("Stream Flow controller", func() {
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats).(*streamFlowController)
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
Expect(fc.receiveWindow).To(Equal(receiveWindow))
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow))
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
Expect(fc.sendWindow).To(Equal(sendWindow))
Expect(fc.contributesToConnection).To(BeTrue())
})
@ -44,11 +44,11 @@ var _ = Describe("Stream Flow controller", func() {
Context("receiving data", func() {
Context("registering received offsets", func() {
var receiveWindow protocol.ByteCount = 10000
var receiveWindowIncrement protocol.ByteCount = 600
var receiveWindowSize protocol.ByteCount = 600
BeforeEach(func() {
controller.receiveWindow = receiveWindow
controller.receiveWindowIncrement = receiveWindowIncrement
controller.receiveWindowSize = receiveWindowSize
})
It("updates the highestReceived", func() {
@ -157,7 +157,7 @@ var _ = Describe("Stream Flow controller", func() {
})
Context("generating window updates", func() {
var oldIncrement protocol.ByteCount
var oldWindowSize protocol.ByteCount
// update the congestion such that it returns a given value for the smoothed RTT
setRtt := func(t time.Duration) {
@ -167,37 +167,51 @@ var _ = Describe("Stream Flow controller", func() {
BeforeEach(func() {
controller.receiveWindow = 100
controller.receiveWindowIncrement = 60
controller.connection.(*connectionFlowController).receiveWindowIncrement = 120
oldIncrement = controller.receiveWindowIncrement
controller.receiveWindowSize = 60
controller.bytesRead = 100 - 60
controller.connection.(*connectionFlowController).receiveWindowSize = 120
oldWindowSize = controller.receiveWindowSize
})
It("tells if it has window updates", func() {
Expect(controller.HasWindowUpdate()).To(BeFalse())
controller.AddBytesRead(30)
Expect(controller.HasWindowUpdate()).To(BeTrue())
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
Expect(controller.HasWindowUpdate()).To(BeFalse())
})
It("tells the connection flow controller when the window was autotuned", func() {
oldOffset := controller.bytesRead
controller.contributesToConnection = true
controller.AddBytesRead(75)
setRtt(20 * time.Millisecond)
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
setRtt(scaleDuration(20 * time.Millisecond))
controller.epochStartOffset = oldOffset
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.AddBytesRead(55)
offset := controller.GetWindowUpdate()
Expect(offset).To(Equal(protocol.ByteCount(75 + 2*60)))
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement))
Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(float64(controller.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier)))
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + 55 + 2*oldWindowSize)))
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)))
})
It("doesn't tell the connection flow controller if it doesn't contribute", func() {
oldOffset := controller.bytesRead
controller.contributesToConnection = false
controller.AddBytesRead(75)
setRtt(20 * time.Millisecond)
controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond)
setRtt(scaleDuration(20 * time.Millisecond))
controller.epochStartOffset = oldOffset
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.AddBytesRead(55)
offset := controller.GetWindowUpdate()
Expect(offset).ToNot(BeZero())
Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement))
Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(120))) // unchanged
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(2 * oldWindowSize))) // unchanged
})
It("doesn't increase the window after a final offset was already received", func() {
controller.AddBytesRead(80)
controller.AddBytesRead(30)
err := controller.UpdateHighestReceived(90, true)
Expect(err).ToNot(HaveOccurred())
Expect(controller.HasWindowUpdate()).To(BeFalse())
offset := controller.GetWindowUpdate()
Expect(offset).To(BeZero())
})
@ -231,7 +245,8 @@ var _ = Describe("Stream Flow controller", func() {
controller.connection.UpdateSendWindow(50)
controller.UpdateSendWindow(100)
controller.AddBytesSent(50)
Expect(controller.connection.IsBlocked()).To(BeTrue())
blocked, _ := controller.connection.IsNewlyBlocked()
Expect(blocked).To(BeTrue())
Expect(controller.IsBlocked()).To(BeFalse())
})
})

View File

@ -51,8 +51,8 @@ type cryptoSetupClient struct {
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
params *TransportParameters
}
@ -74,7 +74,7 @@ func NewCryptoSetupClient(
tlsConfig *tls.Config,
params *TransportParameters,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) {
@ -93,7 +93,7 @@ func NewCryptoSetupClient(
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
aeadChanged: aeadChanged,
handshakeEvent: handshakeEvent,
initialVersion: initialVersion,
negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte),
@ -159,8 +159,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
}
// blocks until the session has received the parameters
h.paramsChan <- *params
h.aeadChanged <- protocol.EncryptionForwardSecure
close(h.aeadChanged)
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
default:
return qerr.InvalidCryptoMessageType
}
@ -381,6 +381,15 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
h.divNonceChan <- data
}
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
HandshakeComplete: h.forwardSecureAEAD != nil,
PeerCertificates: h.certManager.GetChain(),
}
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {
@ -496,10 +505,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if err != nil {
return err
}
h.aeadChanged <- protocol.EncryptionSecure
h.handshakeEvent <- struct{}{}
}
return nil
}

View File

@ -2,6 +2,7 @@ package handshake
import (
"bytes"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
@ -10,6 +11,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
@ -34,6 +36,8 @@ type mockCertManager struct {
commonCertificateHashes []byte
chain []*x509.Certificate
leafCert []byte
leafCertHash uint64
leafCertHashError error
@ -45,6 +49,8 @@ type mockCertManager struct {
verifyCalled bool
}
var _ crypto.CertManager = &mockCertManager{}
func (m *mockCertManager) SetData(data []byte) error {
m.setDataCalledWith = data
return m.setDataError
@ -72,6 +78,10 @@ func (m *mockCertManager) Verify(hostname string) error {
return m.verifyError
}
func (m *mockCertManager) GetChain() []*x509.Certificate {
return m.chain
}
var _ = Describe("Client Crypto Setup", func() {
var (
cs *cryptoSetupClient
@ -79,7 +89,7 @@ var _ = Describe("Client Crypto Setup", func() {
stream *mockStream
keyDerivationCalledWith *keyDerivationValues
shloMap map[Tag][]byte
aeadChanged chan protocol.EncryptionLevel
handshakeEvent chan struct{}
paramsChan chan TransportParameters
)
@ -108,7 +118,7 @@ var _ = Describe("Client Crypto Setup", func() {
version := protocol.Version39
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2)
handshakeEvent = make(chan struct{}, 2)
csInt, err := NewCryptoSetupClient(
stream,
"hostname",
@ -117,7 +127,7 @@ var _ = Describe("Client Crypto Setup", func() {
nil,
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
paramsChan,
aeadChanged,
handshakeEvent,
protocol.Version39,
nil,
)
@ -130,10 +140,6 @@ var _ = Describe("Client Crypto Setup", func() {
cs.cryptoStream = stream
})
AfterEach(func() {
close(stream.unblockRead)
})
Context("Reading REJ", func() {
var tagMap map[Tag][]byte
@ -158,8 +164,17 @@ var _ = Describe("Client Crypto Setup", func() {
stk := []byte("foobar")
tagMap[TagSTK] = stk
HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead)
go cs.HandleCryptoStream()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
Eventually(func() []byte { return cs.stk }).Should(Equal(stk))
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
})
It("saves the proof", func() {
@ -380,22 +395,22 @@ var _ = Describe("Client Crypto Setup", func() {
cs.receivedSecurePacket = false
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")))
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
Expect(handshakeEvent).ToNot(Receive())
Expect(handshakeEvent).ToNot(BeClosed())
})
It("rejects SHLOs without a PUBS", func() {
delete(shloMap, TagPUBS)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")))
Expect(aeadChanged).ToNot(BeClosed())
Expect(handshakeEvent).ToNot(BeClosed())
})
It("rejects SHLOs without a version list", func() {
delete(shloMap, TagVER)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")))
Expect(aeadChanged).ToNot(BeClosed())
Expect(handshakeEvent).ToNot(BeClosed())
})
It("accepts a SHLO after a version negotiation", func() {
@ -430,28 +445,38 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(params.IdleTimeout).To(Equal(13 * time.Second))
})
It("closes the aeadChanged when receiving an SHLO", func() {
It("closes the handshakeEvent chan when receiving an SHLO", func() {
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure)))
Eventually(aeadChanged).Should(BeClosed())
Eventually(handshakeEvent).Should(Receive())
Eventually(handshakeEvent).Should(BeClosed())
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
})
It("passes the transport parameters on the channel", func() {
shloMap[TagSFCW] = []byte{0x0d, 0x00, 0xdf, 0xba}
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
var params TransportParameters
Eventually(paramsChan).Should(Receive(&params))
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xbadf000d)))
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
})
It("errors if it can't read a connection parameter", func() {
@ -637,9 +662,9 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert))
Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce))
Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
Expect(handshakeEvent).To(Receive())
Expect(handshakeEvent).ToNot(Receive())
Expect(handshakeEvent).ToNot(BeClosed())
})
It("uses the server nonce, if the server sent one", func() {
@ -649,51 +674,64 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...)))
Expect(aeadChanged).To(Receive())
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
Expect(handshakeEvent).To(Receive())
Expect(handshakeEvent).ToNot(Receive())
Expect(handshakeEvent).ToNot(BeClosed())
})
It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() {
err := cs.maybeUpgradeCrypto()
Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).To(BeNil())
Expect(aeadChanged).ToNot(Receive())
Expect(handshakeEvent).ToNot(Receive())
cs.serverVerified = true
// make sure we really had all necessary values before, and only serverVerified was missing
err = cs.maybeUpgradeCrypto()
Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
Expect(handshakeEvent).To(Receive())
Expect(handshakeEvent).ToNot(Receive())
Expect(handshakeEvent).ToNot(BeClosed())
})
It("tries to escalate before reading a handshake message", func() {
Expect(cs.secureAEAD).To(BeNil())
cs.serverVerified = true
go cs.HandleCryptoStream()
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
})
It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
cs.HandleCryptoStream()
Fail("HandleCryptoStream should not have returned")
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
Eventually(handshakeEvent).Should(Receive())
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(handshakeEvent).ToNot(Receive())
Expect(handshakeEvent).ToNot(BeClosed())
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
})
It("tries to escalate the crypto after receiving a diversification nonce", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
cs.diversificationNonce = nil
cs.serverVerified = true
Expect(cs.secureAEAD).To(BeNil())
cs.SetDiversificationNonce([]byte("div"))
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
Eventually(handshakeEvent).Should(Receive())
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
close(done)
Expect(handshakeEvent).ToNot(Receive())
Expect(handshakeEvent).ToNot(BeClosed())
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
})
Context("null encryption", func() {
@ -813,6 +851,22 @@ var _ = Describe("Client Crypto Setup", func() {
})
})
Context("reporting the connection state", func() {
It("reports the connection state before the handshake completes", func() {
chain := []*x509.Certificate{testdata.GetCertificate().Leaf}
certManager.chain = chain
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.PeerCertificates).To(Equal(chain))
})
It("reports the connection state after the handshake completes", func() {
doSHLO()
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeTrue())
})
})
Context("forcing encryption levels", func() {
It("forces null encryption", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(4), []byte{}).Return([]byte("foobar unencrypted"))
@ -862,32 +916,51 @@ var _ = Describe("Client Crypto Setup", func() {
Context("Diversification Nonces", func() {
It("sets a diversification nonce", func() {
go cs.HandleCryptoStream()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
nonce := []byte("foobar")
cs.SetDiversificationNonce(nonce)
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
})
It("doesn't do anything when called multiple times with the same nonce", func(done Done) {
go cs.HandleCryptoStream()
It("doesn't do anything when called multiple times with the same nonce", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done)
}()
nonce := []byte("foobar")
cs.SetDiversificationNonce(nonce)
cs.SetDiversificationNonce(nonce)
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
close(done)
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())
})
It("rejects a different diversification nonce", func() {
var err error
done := make(chan struct{})
go func() {
err = cs.HandleCryptoStream()
defer GinkgoRecover()
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(errConflictingDiversificationNonces))
close(done)
}()
nonce1 := []byte("foobar")
nonce2 := []byte("raboof")
cs.SetDiversificationNonce(nonce1)
cs.SetDiversificationNonce(nonce2)
Eventually(func() error { return err }).Should(MatchError(errConflictingDiversificationNonces))
Eventually(done).Should(BeClosed())
})
})

View File

@ -23,6 +23,8 @@ type KeyExchangeFunction func() crypto.KeyExchange
// The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct {
mutex sync.RWMutex
connID protocol.ConnectionID
remoteAddr net.Addr
scfg *ServerConfig
@ -42,7 +44,7 @@ type cryptoSetupServer struct {
receivedParams bool
paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel
handshakeEvent chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
@ -51,7 +53,7 @@ type cryptoSetupServer struct {
params *TransportParameters
mutex sync.RWMutex
sni string // need to fill out the ConnectionState
}
var _ CryptoSetup = &cryptoSetupServer{}
@ -76,7 +78,7 @@ func NewCryptoSetup(
supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
handshakeEvent chan<- struct{},
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
@ -96,7 +98,7 @@ func NewCryptoSetup(
acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}),
paramsChan: paramsChan,
aeadChanged: aeadChanged,
handshakeEvent: handshakeEvent,
}, nil
}
@ -139,6 +141,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
if sni == "" {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
}
h.sni = sni
// prevent version downgrade attacks
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
@ -182,7 +185,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
if _, err := h.cryptoStream.Write(reply); err != nil {
return false, err
}
h.aeadChanged <- protocol.EncryptionForwardSecure
h.handshakeEvent <- struct{}{}
close(h.sentSHLO)
return true, nil
}
@ -206,9 +209,9 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.receivedForwardSecurePacket = true
// wait until protocol.EncryptionForwardSecure was sent on the aeadChan
// wait for the send on the handshakeEvent chan
<-h.sentSHLO
close(h.aeadChanged)
close(h.handshakeEvent)
}
return res, protocol.EncryptionForwardSecure, nil
}
@ -396,8 +399,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
if err != nil {
return nil, err
}
h.aeadChanged <- protocol.EncryptionSecure
h.handshakeEvent <- struct{}{}
// Generate a new curve instance to derive the forward secure key
var fsNonce bytes.Buffer
@ -454,6 +456,15 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
ServerName: h.sni,
HandshakeComplete: h.receivedForwardSecurePacket,
}
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")

View File

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"time"
@ -63,35 +64,36 @@ func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte
}
type mockStream struct {
unblockRead chan struct{} // close this chan to unblock Read
unblockRead chan struct{}
dataToRead bytes.Buffer
dataWritten bytes.Buffer
}
var _ io.ReadWriter = &mockStream{}
var errMockStreamClosing = errors.New("mock stream closing")
func newMockStream() *mockStream {
return &mockStream{unblockRead: make(chan struct{})}
}
// call Close to make Read return
func (s *mockStream) Read(p []byte) (int, error) {
n, _ := s.dataToRead.Read(p)
if n == 0 { // block if there's no data
<-s.unblockRead
return 0, errMockStreamClosing
}
return n, nil // never return an EOF
}
func (s *mockStream) ReadByte() (byte, error) {
return s.dataToRead.ReadByte()
}
func (s *mockStream) Write(p []byte) (int, error) {
return s.dataWritten.Write(p)
}
func (s *mockStream) Close() error { panic("not implemented") }
func (s *mockStream) Reset(error) { panic("not implemented") }
func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") }
func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") }
func (s *mockStream) close() {
close(s.unblockRead)
}
type mockCookieProtector struct {
data []byte
@ -122,7 +124,7 @@ var _ = Describe("Server Crypto Setup", func() {
cs *cryptoSetupServer
stream *mockStream
paramsChan chan TransportParameters
aeadChanged chan protocol.EncryptionLevel
handshakeEvent chan struct{}
nonce32 []byte
versionTag []byte
validSTK []byte
@ -144,7 +146,7 @@ var _ = Describe("Server Crypto Setup", func() {
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2)
handshakeEvent = make(chan struct{}, 2)
stream = newMockStream()
kex = &mockKEX{}
signer = &mockSigner{}
@ -168,7 +170,7 @@ var _ = Describe("Server Crypto Setup", func() {
supportedVersions,
nil,
paramsChan,
aeadChanged,
handshakeEvent,
)
Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer)
@ -183,10 +185,6 @@ var _ = Describe("Server Crypto Setup", func() {
cs.cryptoStream = stream
})
AfterEach(func() {
close(stream.unblockRead)
})
Context("diversification nonce", func() {
BeforeEach(func() {
cs.secureAEAD = mockcrypto.NewMockAEAD(mockCtrl)
@ -345,10 +343,10 @@ var _ = Describe("Server Crypto Setup", func() {
err := cs.HandleCryptoStream()
Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
Expect(aeadChanged).ToNot(BeClosed())
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeEvent).ToNot(BeClosed())
})
It("rejects client nonces that have the wrong length", func() {
@ -379,9 +377,9 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
Expect(aeadChanged).ToNot(BeClosed())
Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeEvent).ToNot(BeClosed())
})
It("recognizes inchoate CHLOs missing SCID", func() {
@ -537,7 +535,7 @@ var _ = Describe("Server Crypto Setup", func() {
TagKEXS: kexs,
})
Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(handshakeEvent).To(Receive()) // for the switch to secure
close(cs.sentSHLO)
}
@ -659,7 +657,26 @@ var _ = Describe("Server Crypto Setup", func() {
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(BeClosed())
Expect(handshakeEvent).To(BeClosed())
})
})
Context("reporting the connection state", func() {
It("reports before the handshake completes", func() {
cs.sni = "server name"
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.ServerName).To(Equal("server name"))
})
It("reports after the handshake completes", func() {
doCHLO()
// receive a forward secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
Expect(err).ToNot(HaveOccurred())
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeTrue())
})
})
@ -723,6 +740,7 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse())
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
Expect(cs.sni).To(Equal("foo"))
})
It("works with proper STK", func() {

View File

@ -26,9 +26,9 @@ type cryptoSetupTLS struct {
nullAEAD crypto.AEAD
aead crypto.AEAD
tls MintTLS
cryptoStream *CryptoStreamConn
aeadChanged chan<- protocol.EncryptionLevel
tls MintTLS
cryptoStream *CryptoStreamConn
handshakeEvent chan<- struct{}
}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
@ -36,16 +36,16 @@ func NewCryptoSetupTLSServer(
tls MintTLS,
cryptoStream *CryptoStreamConn,
nullAEAD crypto.AEAD,
aeadChanged chan<- protocol.EncryptionLevel,
handshakeEvent chan<- struct{},
version protocol.VersionNumber,
) CryptoSetup {
return &cryptoSetupTLS{
tls: tls,
cryptoStream: cryptoStream,
nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
tls: tls,
cryptoStream: cryptoStream,
nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
}
}
@ -54,7 +54,7 @@ func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
hostname string,
aeadChanged chan<- protocol.EncryptionLevel,
handshakeEvent chan<- struct{},
tls MintTLS,
version protocol.VersionNumber,
) (CryptoSetup, error) {
@ -64,11 +64,11 @@ func NewCryptoSetupTLSClient(
}
return &cryptoSetupTLS{
perspective: protocol.PerspectiveClient,
tls: tls,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
perspective: protocol.PerspectiveClient,
tls: tls,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
}, nil
}
@ -102,9 +102,8 @@ handshakeLoop:
h.aead = aead
h.mutex.Unlock()
// signal to the outside world that the handshake completed
h.aeadChanged <- protocol.EncryptionForwardSecure
close(h.aeadChanged)
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
return nil
}
@ -165,3 +164,14 @@ func (h *cryptoSetupTLS) DiversificationNonce() []byte {
func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) {
panic("diversification nonce not needed for TLS")
}
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
mintConnState := h.tls.ConnectionState()
return ConnectionState{
// TODO: set the ServerName, once mint exports it
HandshakeComplete: h.aead != nil,
PeerCertificates: mintConnState.PeerCertificates,
}
}

View File

@ -20,17 +20,17 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
var _ = Describe("TLS Crypto Setup", func() {
var (
cs *cryptoSetupTLS
aeadChanged chan protocol.EncryptionLevel
cs *cryptoSetupTLS
handshakeEvent chan struct{}
)
BeforeEach(func() {
aeadChanged = make(chan protocol.EncryptionLevel, 2)
handshakeEvent = make(chan struct{}, 2)
cs = NewCryptoSetupTLSServer(
nil,
NewCryptoStreamConn(nil),
nil, // AEAD
aeadChanged,
handshakeEvent,
protocol.VersionTLS,
).(*cryptoSetupTLS)
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
@ -51,8 +51,8 @@ var _ = Describe("TLS Crypto Setup", func() {
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
Expect(aeadChanged).To(BeClosed())
Expect(handshakeEvent).To(Receive())
Expect(handshakeEvent).To(BeClosed())
})
It("handshakes until it is connected", func() {
@ -63,7 +63,30 @@ var _ = Describe("TLS Crypto Setup", func() {
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive())
Expect(handshakeEvent).To(Receive())
})
Context("reporting the handshake state", func() {
It("reports before the handshake compeletes", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.PeerCertificates).To(BeNil())
})
It("reports after the handshake completes", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeTrue())
Expect(state.PeerCertificates).To(BeNil())
})
})
Context("escalating crypto", func() {
@ -180,17 +203,17 @@ var _ = Describe("TLS Crypto Setup", func() {
var _ = Describe("TLS Crypto Setup, for the client", func() {
var (
cs *cryptoSetupTLS
aeadChanged chan protocol.EncryptionLevel
cs *cryptoSetupTLS
handshakeEvent chan struct{}
)
BeforeEach(func() {
aeadChanged = make(chan protocol.EncryptionLevel, 2)
handshakeEvent = make(chan struct{})
csInt, err := NewCryptoSetupTLSClient(
nil,
0,
"quic.clemente.io",
aeadChanged,
handshakeEvent,
nil, // mintTLS
protocol.VersionTLS,
)

View File

@ -1,6 +1,7 @@
package handshake
import (
"crypto/x509"
"io"
"github.com/bifurcation/mint"
@ -29,6 +30,7 @@ type MintTLS interface {
// additional methods
Handshake() mint.Alert
State() mint.State
ConnectionState() mint.ConnectionState
SetCryptoStream(io.ReadWriter)
SetExtensionHandler(mint.AppExtensionHandler) error
@ -41,8 +43,17 @@ type CryptoSetup interface {
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
ConnectionState() ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
}
// ConnectionState records basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
type ConnectionState struct {
HandshakeComplete bool // handshake is complete
ServerName string // server name requested by client, if any (server side only)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
}

View File

@ -24,6 +24,7 @@ type extensionHandlerClient struct {
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
var _ TLSExtensionHandler = &extensionHandlerClient{}
// NewExtensionHandlerClient creates a new extension handler for the client.
func NewExtensionHandlerClient(
params *TransportParameters,
initialVersion protocol.VersionNumber,
@ -57,7 +58,10 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{}
found, _ := el.Find(ext)
found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket {
if found {

View File

@ -39,7 +39,8 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
Expect(err).ToNot(HaveOccurred())
Expect(el).To(HaveLen(1))
ext := &tlsExtensionBody{}
found := el.Find(ext)
found, err := el.Find(ext)
Expect(err).ToNot(HaveOccurred())
Expect(found).To(BeTrue())
chtp := &clientHelloTransportParameters{}
_, err = syntax.Unmarshal(ext.data, chtp)

View File

@ -24,6 +24,7 @@ type extensionHandlerServer struct {
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
var _ TLSExtensionHandler = &extensionHandlerServer{}
// NewExtensionHandlerServer creates a new extension handler for the server
func NewExtensionHandlerServer(
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
@ -66,7 +67,10 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{}
found, _ := el.Find(ext)
found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeClientHello {
if found {

View File

@ -48,7 +48,8 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
Expect(err).ToNot(HaveOccurred())
Expect(el).To(HaveLen(1))
ext := &tlsExtensionBody{}
found := el.Find(ext)
found, err := el.Find(ext)
Expect(err).ToNot(HaveOccurred())
Expect(found).To(BeTrue())
eetp := &encryptedExtensionsTransportParameters{}
_, err = syntax.Unmarshal(ext.data, eetp)

View File

@ -64,6 +64,9 @@ type ByteCount uint64
// MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = ByteCount(1<<62 - 1)
// An ApplicationErrorCode is an application-defined error code.
type ApplicationErrorCode uint16
// MaxReceivePacketSize maximum packet size of any QUIC packet, based on
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.

View File

@ -56,6 +56,9 @@ const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15
// This is the value that Chromium is using
const ConnectionFlowControlMultiplier = 1.5
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
const WindowUpdateThreshold = 0.25
// MaxIncomingStreams is the maximum number of streams that a peer may open
const MaxIncomingStreams = 100
@ -122,3 +125,9 @@ const ClosedSessionDeleteTimeout = time.Minute
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
const NumCachedCertificates = 128
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
// 1. it reduces the framing overhead
// 2. it reduces the head-of-line blocking, when a packet is lost
const MinStreamFrameSize ByteCount = 128

View File

@ -139,7 +139,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
}
// MinLength of a written frame
func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
func (f *AckFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() {
return f.minLengthLegacy(version)
}
@ -157,7 +157,7 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount
length += utils.VarIntLen(uint64(f.LargestAcked - lowestInFirstRange))
if !f.HasMissingRanges() {
return length, nil
return length
}
var lowest protocol.PacketNumber
for i, ackRange := range f.AckRanges {
@ -169,7 +169,7 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount
length += utils.VarIntLen(uint64(ackRange.Last - ackRange.First))
lowest = ackRange.First
}
return length, nil
return length
}
// HasMissingRanges returns if this frame reports any missing packets

View File

@ -308,7 +308,7 @@ func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error
return nil
}
func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount, error) {
func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount {
length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp
length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked))
@ -320,7 +320,7 @@ func (f *AckFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount
length += missingSequenceNumberDeltaLen
}
// we don't write
return length, nil
return length
}
// numWritableNackRanges calculates the number of ACK blocks that are about to be written

View File

@ -4,17 +4,26 @@ import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A BlockedFrame is a BLOCKED frame
type BlockedFrame struct{}
type BlockedFrame struct {
Offset protocol.ByteCount
}
// ParseBlockedFrame parses a BLOCKED frame
func ParseBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*BlockedFrame, error) {
func ParseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
return &BlockedFrame{}, nil
offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &BlockedFrame{
Offset: protocol.ByteCount(offset),
}, nil
}
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
@ -23,13 +32,14 @@ func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er
}
typeByte := uint8(0x08)
b.WriteByte(typeByte)
utils.WriteVarInt(b, uint64(f.Offset))
return nil
}
// MinLength of a written frame
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
if !version.UsesIETFFrameFormat() { // writing this frame would result in a legacy BLOCKED being written, which is longer
return 1 + 4, nil
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() {
return 1 + 4
}
return 1, nil
return 1 + utils.VarIntLen(uint64(f.Offset))
}

View File

@ -2,8 +2,10 @@ package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -12,30 +14,41 @@ import (
var _ = Describe("BLOCKED frame", func() {
Context("when parsing", func() {
It("accepts sample frame", func() {
b := bytes.NewReader([]byte{0x08})
_, err := ParseBlockedFrame(b, protocol.VersionWhatever)
data := []byte{0x08}
data = append(data, encodeVarInt(0x12345678)...)
b := bytes.NewReader(data)
frame, err := ParseBlockedFrame(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.Offset).To(Equal(protocol.ByteCount(0x12345678)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
_, err := ParseBlockedFrame(bytes.NewReader(nil), protocol.VersionWhatever)
Expect(err).To(HaveOccurred())
data := []byte{0x08}
data = append(data, encodeVarInt(0x12345678)...)
_, err := ParseBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
for i := range data {
_, err := ParseBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("when writing", func() {
It("writes a sample frame", func() {
b := &bytes.Buffer{}
frame := BlockedFrame{}
frame := BlockedFrame{Offset: 0xdeadbeef}
err := frame.Write(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(b.Bytes()).To(Equal([]byte{0x08}))
expected := []byte{0x08}
expected = append(expected, encodeVarInt(0xdeadbeef)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := BlockedFrame{}
Expect(frame.MinLength(versionIETFFrames)).To(Equal(protocol.ByteCount(1)))
frame := BlockedFrame{Offset: 0x12345}
Expect(frame.MinLength(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0x12345)))
})
})
})

View File

@ -68,11 +68,11 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber)
}
// MinLength of a written frame
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if version.UsesIETFFrameFormat() {
return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)), nil
return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
}
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase))
}
// Write writes an CONNECTION_CLOSE frame.

View File

@ -9,5 +9,5 @@ import (
// A Frame in QUIC
type Frame interface {
Write(b *bytes.Buffer, version protocol.VersionNumber) error
MinLength(version protocol.VersionNumber) (protocol.ByteCount, error)
MinLength(version protocol.VersionNumber) protocol.ByteCount
}

View File

@ -63,6 +63,6 @@ func (f *GoawayFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
}
// MinLength of a written frame
func (f *GoawayFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)), nil
func (f *GoawayFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase))
}

View File

@ -43,9 +43,9 @@ func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er
}
// MinLength of a written frame
func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer
return 1 + 4 + 8, nil
return 1 + 4 + 8
}
return 1 + utils.VarIntLen(uint64(f.ByteOffset)), nil
return 1 + utils.VarIntLen(uint64(f.ByteOffset))
}

View File

@ -51,10 +51,10 @@ func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumb
}
// MinLength of a written frame
func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
// writing this frame would result in a gQUIC WINDOW_UPDATE being written, which has a different length
if !version.UsesIETFFrameFormat() {
return 1 + 4 + 8, nil
return 1 + 4 + 8
}
return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset)), nil
return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset))
}

View File

@ -0,0 +1,37 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A MaxStreamIDFrame is a MAX_STREAM_ID frame
type MaxStreamIDFrame struct {
StreamID protocol.StreamID
}
// ParseMaxStreamIDFrame parses a MAX_STREAM_ID frame
func ParseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) {
// read the Type byte
if _, err := r.ReadByte(); err != nil {
return nil, err
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil
}
func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x6)
utils.WriteVarInt(b, uint64(f.StreamID))
return nil
}
// MinLength of a written frame
func (f *MaxStreamIDFrame) MinLength(protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID))
}

View File

@ -0,0 +1,51 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("MAX_STREAM_ID frame", func() {
Context("parsing", func() {
It("accepts sample frame", func() {
data := []byte{0x6}
data = append(data, encodeVarInt(0xdecafbad)...)
b := bytes.NewReader(data)
f, err := ParseMaxStreamIDFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x06}
data = append(data, encodeVarInt(0xdeadbeefcafe13)...)
_, err := ParseMaxStreamIDFrame(bytes.NewReader(data), protocol.VersionWhatever)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := ParseMaxStreamIDFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever)
Expect(err).To(HaveOccurred())
}
})
})
Context("writing", func() {
It("writes a sample frame", func() {
b := &bytes.Buffer{}
frame := MaxStreamIDFrame{StreamID: 0x12345678}
frame.Write(b, protocol.VersionWhatever)
expected := []byte{0x6}
expected = append(expected, encodeVarInt(0x12345678)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := MaxStreamIDFrame{StreamID: 0x1337}
Expect(frame.MinLength(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337)))
})
})
})

View File

@ -28,6 +28,6 @@ func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
}
// MinLength of a written frame
func (f *PingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return 1, nil
func (f *PingFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
return 1
}

View File

@ -7,10 +7,12 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A RstStreamFrame in QUIC
// A RstStreamFrame is a RST_STREAM frame in QUIC
type RstStreamFrame struct {
StreamID protocol.StreamID
ErrorCode uint32
StreamID protocol.StreamID
// The error code is a uint32 in gQUIC, but a uint16 in IETF QUIC.
// protocol.ApplicaitonErrorCode is a uint16, so larger values in gQUIC frames will be truncated.
ErrorCode protocol.ApplicationErrorCode
ByteOffset protocol.ByteCount
}
@ -21,7 +23,7 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS
}
var streamID protocol.StreamID
var errorCode uint32
var errorCode uint16
var byteOffset protocol.ByteCount
if version.UsesIETFFrameFormat() {
sid, err := utils.ReadVarInt(r)
@ -29,11 +31,10 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS
return nil, err
}
streamID = protocol.StreamID(sid)
ec, err := utils.BigEndian.ReadUint16(r)
errorCode, err = utils.BigEndian.ReadUint16(r)
if err != nil {
return nil, err
}
errorCode = uint32(ec)
bo, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
@ -54,12 +55,12 @@ func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstS
if err != nil {
return nil, err
}
errorCode = uint32(ec)
errorCode = uint16(ec)
}
return &RstStreamFrame{
StreamID: streamID,
ErrorCode: errorCode,
ErrorCode: protocol.ApplicationErrorCode(errorCode),
ByteOffset: byteOffset,
}, nil
}
@ -74,15 +75,15 @@ func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber)
} else {
utils.BigEndian.WriteUint32(b, uint32(f.StreamID))
utils.BigEndian.WriteUint64(b, uint64(f.ByteOffset))
utils.BigEndian.WriteUint32(b, f.ErrorCode)
utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode))
}
return nil
}
// MinLength of a written frame
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if version.UsesIETFFrameFormat() {
return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset)), nil
return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset))
}
return 1 + 4 + 8 + 4, nil
return 1 + 4 + 8 + 4
}

View File

@ -22,7 +22,7 @@ var _ = Describe("RST_STREAM frame", func() {
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x987654321)))
Expect(frame.ErrorCode).To(Equal(uint32(0x1337)))
Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0x1337)))
})
It("errors on EOFs", func() {
@ -44,13 +44,13 @@ var _ = Describe("RST_STREAM frame", func() {
b := bytes.NewReader([]byte{0x1,
0xde, 0xad, 0xbe, 0xef, // stream id
0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, // byte offset
0x34, 0x12, 0x37, 0x13, // error code
0x0, 0x0, 0xca, 0xfe, // error code
})
frame, err := ParseRstStreamFrame(b, versionBigEndian)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x8877665544332211)))
Expect(frame.ErrorCode).To(Equal(uint32(0x34123713)))
Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0xcafe)))
})
It("errors on EOFs", func() {
@ -103,7 +103,7 @@ var _ = Describe("RST_STREAM frame", func() {
frame := RstStreamFrame{
StreamID: 0x1337,
ByteOffset: 0x11223344decafbad,
ErrorCode: 0xdeadbeef,
ErrorCode: 0xcafe,
}
b := &bytes.Buffer{}
err := frame.Write(b, versionBigEndian)
@ -111,7 +111,7 @@ var _ = Describe("RST_STREAM frame", func() {
Expect(b.Bytes()).To(Equal([]byte{0x01,
0x0, 0x0, 0x13, 0x37, // stream id
0x11, 0x22, 0x33, 0x44, 0xde, 0xca, 0xfb, 0xad, // byte offset
0xde, 0xad, 0xbe, 0xef, // error code
0x0, 0x0, 0xca, 0xfe, // error code
}))
})

View File

@ -0,0 +1,47 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A StopSendingFrame is a STOP_SENDING frame
type StopSendingFrame struct {
StreamID protocol.StreamID
ErrorCode protocol.ApplicationErrorCode
}
// ParseStopSendingFrame parses a STOP_SENDING frame
func ParseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
errorCode, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return nil, err
}
return &StopSendingFrame{
StreamID: protocol.StreamID(streamID),
ErrorCode: protocol.ApplicationErrorCode(errorCode),
}, nil
}
// MinLength of a written frame
func (f *StopSendingFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2
}
func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x0c)
utils.WriteVarInt(b, uint64(f.StreamID))
utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode))
return nil
}

View File

@ -0,0 +1,63 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STOP_SENDING frame", func() {
Context("when parsing", func() {
It("parses a sample frame", func() {
data := []byte{0x0c}
data = append(data, encodeVarInt(0xdecafbad)...) // stream ID
data = append(data, []byte{0x13, 0x37}...) // error code
b := bytes.NewReader(data)
frame, err := ParseStopSendingFrame(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(frame.ErrorCode).To(Equal(protocol.ApplicationErrorCode(0x1337)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x0c}
data = append(data, encodeVarInt(0xdecafbad)...) // stream ID
data = append(data, []byte{0x13, 0x37}...) // error code
_, err := ParseStopSendingFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := ParseStopSendingFrame(bytes.NewReader(data[:i]), versionIETFFrames)
Expect(err).To(HaveOccurred())
}
})
})
Context("when writing", func() {
It("writes", func() {
frame := &StopSendingFrame{
StreamID: 0xdeadbeefcafe,
ErrorCode: 0x10,
}
buf := &bytes.Buffer{}
err := frame.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x0c}
expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
expected = append(expected, []byte{0x0, 0x10}...)
Expect(buf.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := &StopSendingFrame{
StreamID: 0xdeadbeef,
ErrorCode: 0x10,
}
Expect(frame.MinLength(versionIETFFrames)).To(Equal(1 + 2 + utils.VarIntLen(0xdeadbeef)))
})
})
})

View File

@ -22,7 +22,10 @@ var (
errPacketNumberLenNotSet = errors.New("StopWaitingFrame: PacketNumberLen not set")
)
func (f *StopWaitingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
func (f *StopWaitingFrame) Write(b *bytes.Buffer, v protocol.VersionNumber) error {
if v.UsesIETFFrameFormat() {
return errors.New("STOP_WAITING not defined in IETF QUIC")
}
// make sure the PacketNumber was set
if f.PacketNumber == protocol.PacketNumber(0) {
return errPacketNumberNotSet
@ -49,14 +52,8 @@ func (f *StopWaitingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) erro
}
// MinLength of a written frame
func (f *StopWaitingFrame) MinLength(_ protocol.VersionNumber) (protocol.ByteCount, error) {
minLength := protocol.ByteCount(1) // typeByte
if f.PacketNumberLen == protocol.PacketNumberLenInvalid {
return 0, errPacketNumberLenNotSet
}
minLength += protocol.ByteCount(f.PacketNumberLen)
return minLength, nil
func (f *StopWaitingFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + protocol.ByteCount(f.PacketNumberLen)
}
// ParseStopWaitingFrame parses a StopWaiting frame

View File

@ -84,7 +84,7 @@ var _ = Describe("StopWaitingFrame", func() {
LeastUnacked: 10,
PacketNumberLen: protocol.PacketNumberLen1,
}
err := frame.Write(b, protocol.VersionWhatever)
err := frame.Write(b, versionBigEndian)
Expect(err).To(MatchError(errPacketNumberNotSet))
})
@ -94,7 +94,7 @@ var _ = Describe("StopWaitingFrame", func() {
LeastUnacked: 10,
PacketNumber: 13,
}
err := frame.Write(b, protocol.VersionWhatever)
err := frame.Write(b, versionBigEndian)
Expect(err).To(MatchError(errPacketNumberLenNotSet))
})
@ -105,10 +105,21 @@ var _ = Describe("StopWaitingFrame", func() {
PacketNumber: 5,
PacketNumberLen: protocol.PacketNumberLen1,
}
err := frame.Write(b, protocol.VersionWhatever)
err := frame.Write(b, versionBigEndian)
Expect(err).To(MatchError(errLeastUnackedHigherThanPacketNumber))
})
It("refuses to write for IETF QUIC", func() {
b := &bytes.Buffer{}
frame := &StopWaitingFrame{
LeastUnacked: 10,
PacketNumber: 13,
PacketNumberLen: protocol.PacketNumberLen6,
}
err := frame.Write(b, versionIETFFrames)
Expect(err).To(MatchError("STOP_WAITING not defined in IETF QUIC"))
})
Context("LeastUnackedDelta length", func() {
Context("in big endian", func() {
It("writes a 1-byte LeastUnackedDelta", func() {
@ -176,18 +187,10 @@ var _ = Describe("StopWaitingFrame", func() {
Expect(frame.MinLength(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(length + 1)))
}
})
It("errors when packetNumberLen is not set", func() {
frame := &StopWaitingFrame{
LeastUnacked: 10,
}
_, err := frame.MinLength(0)
Expect(err).To(MatchError(errPacketNumberLenNotSet))
})
})
Context("self consistency", func() {
It("reads a stop waiting frame that it wrote", func() {
It("reads a STOP_WAITING frame that it wrote", func() {
packetNumber := protocol.PacketNumber(13)
frame := &StopWaitingFrame{
LeastUnacked: 10,
@ -195,9 +198,9 @@ var _ = Describe("StopWaitingFrame", func() {
PacketNumberLen: protocol.PacketNumberLen4,
}
b := &bytes.Buffer{}
err := frame.Write(b, protocol.VersionWhatever)
err := frame.Write(b, versionBigEndian)
Expect(err).ToNot(HaveOccurred())
readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, protocol.VersionWhatever)
readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, protocol.PacketNumberLen4, versionBigEndian)
Expect(err).ToNot(HaveOccurred())
Expect(readframe.LeastUnacked).To(Equal(frame.LeastUnacked))
})

View File

@ -10,10 +10,11 @@ import (
// A StreamBlockedFrame in QUIC
type StreamBlockedFrame struct {
StreamID protocol.StreamID
Offset protocol.ByteCount
}
// ParseStreamBlockedFrame parses a STREAM_BLOCKED frame
func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamBlockedFrame, error) {
func ParseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err
}
@ -21,7 +22,14 @@ func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*
if err != nil {
return nil, err
}
return &StreamBlockedFrame{StreamID: protocol.StreamID(sid)}, nil
offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &StreamBlockedFrame{
StreamID: protocol.StreamID(sid),
Offset: protocol.ByteCount(offset),
}, nil
}
// Write writes a STREAM_BLOCKED frame
@ -31,13 +39,14 @@ func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumb
}
b.WriteByte(0x09)
utils.WriteVarInt(b, uint64(f.StreamID))
utils.WriteVarInt(b, uint64(f.Offset))
return nil
}
// MinLength of a written frame
func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() {
return 1 + 4, nil
return 1 + 4
}
return 1 + utils.VarIntLen(uint64(f.StreamID)), nil
return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.Offset))
}

View File

@ -14,17 +14,20 @@ var _ = Describe("STREAM_BLOCKED frame", func() {
Context("parsing", func() {
It("accepts sample frame", func() {
data := []byte{0x9}
data = append(data, encodeVarInt(0xdeadbeef)...)
data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID
data = append(data, encodeVarInt(0xdecafbad)...) // offset
b := bytes.NewReader(data)
frame, err := ParseStreamBlockedFrame(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x9}
data = append(data, encodeVarInt(0xdeadbeef)...)
data = append(data, encodeVarInt(0xc0010ff)...)
_, err := ParseStreamBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).NotTo(HaveOccurred())
for i := range data {
@ -38,19 +41,22 @@ var _ = Describe("STREAM_BLOCKED frame", func() {
It("has proper min length", func() {
f := &StreamBlockedFrame{
StreamID: 0x1337,
Offset: 0xdeadbeef,
}
Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337)))
Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337) + utils.VarIntLen(0xdeadbeef)))
})
It("writes a sample frame", func() {
b := &bytes.Buffer{}
f := &StreamBlockedFrame{
StreamID: 0xdecafbad,
Offset: 0x1337,
}
err := f.Write(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x9}
expected = append(expected, encodeVarInt(uint64(f.StreamID))...)
expected = append(expected, encodeVarInt(uint64(f.Offset))...)
Expect(b.Bytes()).To(Equal(expected))
})
})

View File

@ -117,7 +117,7 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err
// MinLength returns the length of the header of a StreamFrame
// the total length of the frame is frame.MinLength() + frame.DataLen()
func (f *StreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
func (f *StreamFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() {
return f.minLengthLegacy(version)
}
@ -128,5 +128,5 @@ func (f *StreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCo
if f.DataLenPresent {
length += utils.VarIntLen(uint64(f.DataLen()))
}
return length, nil
return length
}

View File

@ -183,12 +183,12 @@ func (f *StreamFrame) getOffsetLength() protocol.ByteCount {
return 8
}
func (f *StreamFrame) minLengthLegacy(_ protocol.VersionNumber) (protocol.ByteCount, error) {
func (f *StreamFrame) minLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount {
length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength()
if f.DataLenPresent {
length += 2
}
return length, nil
return length
}
// DataLen gives the length of data in bytes

View File

@ -210,7 +210,7 @@ var _ = Describe("STREAM frame (for gQUIC)", func() {
}
err := f.Write(b, versionBigEndian)
Expect(err).ToNot(HaveOccurred())
minLength, _ := f.MinLength(0)
minLength := f.MinLength(0)
Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0x20)))
Expect(b.Bytes()[minLength-2 : minLength]).To(Equal([]byte{0x13, 0x37}))
})
@ -229,9 +229,9 @@ var _ = Describe("STREAM frame (for gQUIC)", func() {
Expect(err).ToNot(HaveOccurred())
Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0)))
Expect(b.Bytes()[1 : b.Len()-dataLen]).ToNot(ContainSubstring(string([]byte{0x37, 0x13})))
minLength, _ := f.MinLength(versionBigEndian)
minLength := f.MinLength(versionBigEndian)
f.DataLenPresent = true
minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian)
minLengthWithoutDataLen := f.MinLength(versionBigEndian)
Expect(minLength).To(Equal(minLengthWithoutDataLen - 2))
})
@ -242,7 +242,7 @@ var _ = Describe("STREAM frame (for gQUIC)", func() {
DataLenPresent: false,
Offset: 0xdeadbeef,
}
minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian)
minLengthWithoutDataLen := f.MinLength(versionBigEndian)
f.DataLenPresent = true
Expect(f.MinLength(versionBigEndian)).To(Equal(minLengthWithoutDataLen + 2))
})

View File

@ -0,0 +1,37 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A StreamIDBlockedFrame is a STREAM_ID_BLOCKED frame
type StreamIDBlockedFrame struct {
StreamID protocol.StreamID
}
// ParseStreamIDBlockedFrame parses a STREAM_ID_BLOCKED frame
func ParseStreamIDBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamIDBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &StreamIDBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil
}
func (f *StreamIDBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
typeByte := uint8(0x0a)
b.WriteByte(typeByte)
utils.WriteVarInt(b, uint64(f.StreamID))
return nil
}
// MinLength of a written frame
func (f *StreamIDBlockedFrame) MinLength(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID))
}

View File

@ -0,0 +1,53 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STREAM_ID_BLOCKED frame", func() {
Context("parsing", func() {
It("accepts sample frame", func() {
expected := []byte{0xa}
expected = append(expected, encodeVarInt(0xdecafbad)...)
b := bytes.NewReader(expected)
frame, err := ParseStreamIDBlockedFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0xa}
data = append(data, encodeVarInt(0x12345678)...)
_, err := ParseStreamIDBlockedFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
for i := range data {
_, err := ParseStreamIDBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("writing", func() {
It("writes a sample frame", func() {
b := &bytes.Buffer{}
frame := StreamIDBlockedFrame{StreamID: 0xdeadbeefcafe}
err := frame.Write(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0xa}
expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := StreamIDBlockedFrame{StreamID: 0x123456}
Expect(frame.MinLength(0)).To(Equal(protocol.ByteCount(1) + utils.VarIntLen(0x123456)))
})
})
})

View File

@ -56,6 +56,10 @@ func (mc *mintController) State() mint.State {
return mc.conn.State().HandshakeState
}
func (mc *mintController) ConnectionState() mint.ConnectionState {
return mc.conn.State()
}
func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
mc.csc.SetStream(stream)
}
@ -73,6 +77,7 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
},
}
if tlsConf != nil {
mconf.ServerName = tlsConf.ServerName
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
for i, certChain := range tlsConf.Certificates {
mconf.Certificates[i] = &mint.Certificate{
@ -87,6 +92,13 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
mconf.Certificates[i].Chain[j] = c
}
}
switch tlsConf.ClientAuth {
case tls.NoClientCert:
case tls.RequireAnyClientCert:
mconf.RequireClientAuth = true
default:
return nil, errors.New("mint currently only support ClientAuthType RequireAnyClientCert")
}
}
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
return nil, err
@ -128,14 +140,14 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio
// packUnencryptedPacket provides a low-overhead way to pack a packet.
// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) {
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) {
raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw)
if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
if err := sf.Write(buffer, hdr.Version); err != nil {
if err := f.Write(buffer, hdr.Version); err != nil {
return nil, err
}
raw = raw[0:buffer.Len()]
@ -144,7 +156,7 @@ func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFr
if utils.Debug() {
utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted)
hdr.Log()
wire.LogFrame(sf, true)
wire.LogFrame(f, true)
}
return raw, nil
}

View File

@ -2,9 +2,11 @@ package quic
import (
"bytes"
"crypto/tls"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -33,6 +35,45 @@ var _ = Describe("Packing and unpacking Initial packets", func() {
hdr.Raw = buf.Bytes()
})
Context("generating a mint.Config", func() {
It("sets non-blocking mode", func() {
mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.NonBlocking).To(BeTrue())
})
It("sets the server name", func() {
conf := &tls.Config{ServerName: "www.example.com"}
mintConf, err := tlsToMintConfig(conf, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.ServerName).To(Equal("www.example.com"))
})
It("sets the certificate chain", func() {
tlsConf := testdata.GetTLSConfig()
mintConf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.Certificates).ToNot(BeEmpty())
Expect(mintConf.Certificates).To(HaveLen(len(tlsConf.Certificates)))
})
It("requires client authentication", func() {
mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.RequireClientAuth).To(BeFalse())
conf := &tls.Config{ClientAuth: tls.RequireAnyClientCert}
mintConf, err = tlsToMintConfig(conf, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(mintConf.RequireClientAuth).To(BeTrue())
})
It("rejects unsupported client auth types", func() {
conf := &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert}
_, err := tlsToMintConfig(conf, protocol.PerspectiveClient)
Expect(err).To(MatchError("mint currently only support ClientAuthType RequireAnyClientCert"))
})
})
Context("unpacking", func() {
packPacket := func(frames []wire.Frame) []byte {
buf := &bytes.Buffer{}

View File

@ -0,0 +1,141 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: CryptoStream)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockCryptoStream is a mock of CryptoStream interface
type MockCryptoStream struct {
ctrl *gomock.Controller
recorder *MockCryptoStreamMockRecorder
}
// MockCryptoStreamMockRecorder is the mock recorder for MockCryptoStream
type MockCryptoStreamMockRecorder struct {
mock *MockCryptoStream
}
// NewMockCryptoStream creates a new mock instance
func NewMockCryptoStream(ctrl *gomock.Controller) *MockCryptoStream {
mock := &MockCryptoStream{ctrl: ctrl}
mock.recorder = &MockCryptoStreamMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder {
return m.recorder
}
// Read mocks base method
func (m *MockCryptoStream) Read(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockCryptoStreamMockRecorder) Read(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockCryptoStream)(nil).Read), arg0)
}
// StreamID mocks base method
func (m *MockCryptoStream) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockCryptoStreamMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockCryptoStream)(nil).StreamID))
}
// Write mocks base method
func (m *MockCryptoStream) Write(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockCryptoStreamMockRecorder) Write(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0)
}
// closeForShutdown mocks base method
func (m *MockCryptoStream) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockCryptoStreamMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockCryptoStream)(nil).closeForShutdown), arg0)
}
// getWindowUpdate mocks base method
func (m *MockCryptoStream) getWindowUpdate() protocol.ByteCount {
ret := m.ctrl.Call(m, "getWindowUpdate")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// getWindowUpdate indicates an expected call of getWindowUpdate
func (mr *MockCryptoStreamMockRecorder) getWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockCryptoStream)(nil).getWindowUpdate))
}
// handleMaxStreamDataFrame mocks base method
func (m *MockCryptoStream) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
}
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
func (mr *MockCryptoStreamMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleMaxStreamDataFrame), arg0)
}
// handleStreamFrame mocks base method
func (m *MockCryptoStream) handleStreamFrame(arg0 *wire.StreamFrame) error {
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleStreamFrame indicates an expected call of handleStreamFrame
func (mr *MockCryptoStreamMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0)
}
// popStreamFrame mocks base method
func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// popStreamFrame indicates an expected call of popStreamFrame
func (mr *MockCryptoStreamMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).popStreamFrame), arg0)
}
// setReadOffset mocks base method
func (m *MockCryptoStream) setReadOffset(arg0 protocol.ByteCount) {
m.ctrl.Call(m, "setReadOffset", arg0)
}
// setReadOffset indicates an expected call of setReadOffset
func (mr *MockCryptoStreamMockRecorder) setReadOffset(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setReadOffset", reflect.TypeOf((*MockCryptoStream)(nil).setReadOffset), arg0)
}

View File

@ -0,0 +1,132 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: ReceiveStreamI)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockReceiveStreamI is a mock of ReceiveStreamI interface
type MockReceiveStreamI struct {
ctrl *gomock.Controller
recorder *MockReceiveStreamIMockRecorder
}
// MockReceiveStreamIMockRecorder is the mock recorder for MockReceiveStreamI
type MockReceiveStreamIMockRecorder struct {
mock *MockReceiveStreamI
}
// NewMockReceiveStreamI creates a new mock instance
func NewMockReceiveStreamI(ctrl *gomock.Controller) *MockReceiveStreamI {
mock := &MockReceiveStreamI{ctrl: ctrl}
mock.recorder = &MockReceiveStreamIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockReceiveStreamI) EXPECT() *MockReceiveStreamIMockRecorder {
return m.recorder
}
// CancelRead mocks base method
func (m *MockReceiveStreamI) CancelRead(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelRead", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelRead indicates an expected call of CancelRead
func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0)
}
// Read mocks base method
func (m *MockReceiveStreamI) Read(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockReceiveStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), arg0)
}
// SetReadDeadline mocks base method
func (m *MockReceiveStreamI) SetReadDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetReadDeadline indicates an expected call of SetReadDeadline
func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), arg0)
}
// StreamID mocks base method
func (m *MockReceiveStreamI) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockReceiveStreamIMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockReceiveStreamI)(nil).StreamID))
}
// closeForShutdown mocks base method
func (m *MockReceiveStreamI) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0)
}
// getWindowUpdate mocks base method
func (m *MockReceiveStreamI) getWindowUpdate() protocol.ByteCount {
ret := m.ctrl.Call(m, "getWindowUpdate")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// getWindowUpdate indicates an expected call of getWindowUpdate
func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate))
}
// handleRstStreamFrame mocks base method
func (m *MockReceiveStreamI) handleRstStreamFrame(arg0 *wire.RstStreamFrame) error {
ret := m.ctrl.Call(m, "handleRstStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleRstStreamFrame indicates an expected call of handleRstStreamFrame
func (mr *MockReceiveStreamIMockRecorder) handleRstStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleRstStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleRstStreamFrame), arg0)
}
// handleStreamFrame mocks base method
func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error {
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleStreamFrame indicates an expected call of handleStreamFrame
func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0)
}

View File

@ -0,0 +1,154 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: SendStreamI)
// Package quic is a generated GoMock package.
package quic
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockSendStreamI is a mock of SendStreamI interface
type MockSendStreamI struct {
ctrl *gomock.Controller
recorder *MockSendStreamIMockRecorder
}
// MockSendStreamIMockRecorder is the mock recorder for MockSendStreamI
type MockSendStreamIMockRecorder struct {
mock *MockSendStreamI
}
// NewMockSendStreamI creates a new mock instance
func NewMockSendStreamI(ctrl *gomock.Controller) *MockSendStreamI {
mock := &MockSendStreamI{ctrl: ctrl}
mock.recorder = &MockSendStreamIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder {
return m.recorder
}
// CancelWrite mocks base method
func (m *MockSendStreamI) CancelWrite(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelWrite", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelWrite indicates an expected call of CancelWrite
func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0)
}
// Close mocks base method
func (m *MockSendStreamI) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockSendStreamIMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close))
}
// Context mocks base method
func (m *MockSendStreamI) Context() context.Context {
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context
func (mr *MockSendStreamIMockRecorder) Context() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context))
}
// SetWriteDeadline mocks base method
func (m *MockSendStreamI) SetWriteDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetWriteDeadline indicates an expected call of SetWriteDeadline
func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), arg0)
}
// StreamID mocks base method
func (m *MockSendStreamI) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockSendStreamIMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID))
}
// Write mocks base method
func (m *MockSendStreamI) Write(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockSendStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), arg0)
}
// closeForShutdown mocks base method
func (m *MockSendStreamI) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0)
}
// handleMaxStreamDataFrame mocks base method
func (m *MockSendStreamI) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
}
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
func (mr *MockSendStreamIMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleMaxStreamDataFrame), arg0)
}
// handleStopSendingFrame mocks base method
func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) {
m.ctrl.Call(m, "handleStopSendingFrame", arg0)
}
// handleStopSendingFrame indicates an expected call of handleStopSendingFrame
func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0)
}
// popStreamFrame mocks base method
func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// popStreamFrame indicates an expected call of popStreamFrame
func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0)
}

View File

@ -0,0 +1,72 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamFrameSource)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockStreamFrameSource is a mock of StreamFrameSource interface
type MockStreamFrameSource struct {
ctrl *gomock.Controller
recorder *MockStreamFrameSourceMockRecorder
}
// MockStreamFrameSourceMockRecorder is the mock recorder for MockStreamFrameSource
type MockStreamFrameSourceMockRecorder struct {
mock *MockStreamFrameSource
}
// NewMockStreamFrameSource creates a new mock instance
func NewMockStreamFrameSource(ctrl *gomock.Controller) *MockStreamFrameSource {
mock := &MockStreamFrameSource{ctrl: ctrl}
mock.recorder = &MockStreamFrameSourceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamFrameSource) EXPECT() *MockStreamFrameSourceMockRecorder {
return m.recorder
}
// HasCryptoStreamData mocks base method
func (m *MockStreamFrameSource) HasCryptoStreamData() bool {
ret := m.ctrl.Call(m, "HasCryptoStreamData")
ret0, _ := ret[0].(bool)
return ret0
}
// HasCryptoStreamData indicates an expected call of HasCryptoStreamData
func (mr *MockStreamFrameSourceMockRecorder) HasCryptoStreamData() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCryptoStreamData", reflect.TypeOf((*MockStreamFrameSource)(nil).HasCryptoStreamData))
}
// PopCryptoStreamFrame mocks base method
func (m *MockStreamFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame {
ret := m.ctrl.Call(m, "PopCryptoStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
return ret0
}
// PopCryptoStreamFrame indicates an expected call of PopCryptoStreamFrame
func (mr *MockStreamFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockStreamFrameSource)(nil).PopCryptoStreamFrame), arg0)
}
// PopStreamFrames mocks base method
func (m *MockStreamFrameSource) PopStreamFrames(arg0 protocol.ByteCount) []*wire.StreamFrame {
ret := m.ctrl.Call(m, "PopStreamFrames", arg0)
ret0, _ := ret[0].([]*wire.StreamFrame)
return ret0
}
// PopStreamFrames indicates an expected call of PopStreamFrames
func (mr *MockStreamFrameSourceMockRecorder) PopStreamFrames(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopStreamFrames", reflect.TypeOf((*MockStreamFrameSource)(nil).PopStreamFrames), arg0)
}

View File

@ -0,0 +1,61 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamGetter)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockStreamGetter is a mock of StreamGetter interface
type MockStreamGetter struct {
ctrl *gomock.Controller
recorder *MockStreamGetterMockRecorder
}
// MockStreamGetterMockRecorder is the mock recorder for MockStreamGetter
type MockStreamGetterMockRecorder struct {
mock *MockStreamGetter
}
// NewMockStreamGetter creates a new mock instance
func NewMockStreamGetter(ctrl *gomock.Controller) *MockStreamGetter {
mock := &MockStreamGetter{ctrl: ctrl}
mock.recorder = &MockStreamGetterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder {
return m.recorder
}
// GetOrOpenReceiveStream mocks base method
func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
ret0, _ := ret[0].(receiveStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream
func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0)
}
// GetOrOpenSendStream mocks base method
func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
ret0, _ := ret[0].(sendStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream
func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0)
}

View File

@ -0,0 +1,239 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamI)
// Package quic is a generated GoMock package.
package quic
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockStreamI is a mock of StreamI interface
type MockStreamI struct {
ctrl *gomock.Controller
recorder *MockStreamIMockRecorder
}
// MockStreamIMockRecorder is the mock recorder for MockStreamI
type MockStreamIMockRecorder struct {
mock *MockStreamI
}
// NewMockStreamI creates a new mock instance
func NewMockStreamI(ctrl *gomock.Controller) *MockStreamI {
mock := &MockStreamI{ctrl: ctrl}
mock.recorder = &MockStreamIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamI) EXPECT() *MockStreamIMockRecorder {
return m.recorder
}
// CancelRead mocks base method
func (m *MockStreamI) CancelRead(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelRead", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelRead indicates an expected call of CancelRead
func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0)
}
// CancelWrite mocks base method
func (m *MockStreamI) CancelWrite(arg0 protocol.ApplicationErrorCode) error {
ret := m.ctrl.Call(m, "CancelWrite", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CancelWrite indicates an expected call of CancelWrite
func (mr *MockStreamIMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0)
}
// Close mocks base method
func (m *MockStreamI) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockStreamIMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStreamI)(nil).Close))
}
// Context mocks base method
func (m *MockStreamI) Context() context.Context {
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context
func (mr *MockStreamIMockRecorder) Context() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStreamI)(nil).Context))
}
// Read mocks base method
func (m *MockStreamI) Read(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0)
}
// SetDeadline mocks base method
func (m *MockStreamI) SetDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetDeadline indicates an expected call of SetDeadline
func (mr *MockStreamIMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), arg0)
}
// SetReadDeadline mocks base method
func (m *MockStreamI) SetReadDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetReadDeadline indicates an expected call of SetReadDeadline
func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), arg0)
}
// SetWriteDeadline mocks base method
func (m *MockStreamI) SetWriteDeadline(arg0 time.Time) error {
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetWriteDeadline indicates an expected call of SetWriteDeadline
func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0)
}
// StreamID mocks base method
func (m *MockStreamI) StreamID() protocol.StreamID {
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockStreamIMockRecorder) StreamID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID))
}
// Write mocks base method
func (m *MockStreamI) Write(arg0 []byte) (int, error) {
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), arg0)
}
// closeForShutdown mocks base method
func (m *MockStreamI) closeForShutdown(arg0 error) {
m.ctrl.Call(m, "closeForShutdown", arg0)
}
// closeForShutdown indicates an expected call of closeForShutdown
func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0)
}
// getWindowUpdate mocks base method
func (m *MockStreamI) getWindowUpdate() protocol.ByteCount {
ret := m.ctrl.Call(m, "getWindowUpdate")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// getWindowUpdate indicates an expected call of getWindowUpdate
func (mr *MockStreamIMockRecorder) getWindowUpdate() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).getWindowUpdate))
}
// handleMaxStreamDataFrame mocks base method
func (m *MockStreamI) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) {
m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0)
}
// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame
func (mr *MockStreamIMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockStreamI)(nil).handleMaxStreamDataFrame), arg0)
}
// handleRstStreamFrame mocks base method
func (m *MockStreamI) handleRstStreamFrame(arg0 *wire.RstStreamFrame) error {
ret := m.ctrl.Call(m, "handleRstStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleRstStreamFrame indicates an expected call of handleRstStreamFrame
func (mr *MockStreamIMockRecorder) handleRstStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleRstStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleRstStreamFrame), arg0)
}
// handleStopSendingFrame mocks base method
func (m *MockStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) {
m.ctrl.Call(m, "handleStopSendingFrame", arg0)
}
// handleStopSendingFrame indicates an expected call of handleStopSendingFrame
func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0)
}
// handleStreamFrame mocks base method
func (m *MockStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error {
ret := m.ctrl.Call(m, "handleStreamFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// handleStreamFrame indicates an expected call of handleStreamFrame
func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0)
}
// popStreamFrame mocks base method
func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// popStreamFrame indicates an expected call of popStreamFrame
func (mr *MockStreamIMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), arg0)
}

View File

@ -0,0 +1,146 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamManager)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockStreamManager is a mock of StreamManager interface
type MockStreamManager struct {
ctrl *gomock.Controller
recorder *MockStreamManagerMockRecorder
}
// MockStreamManagerMockRecorder is the mock recorder for MockStreamManager
type MockStreamManagerMockRecorder struct {
mock *MockStreamManager
}
// NewMockStreamManager creates a new mock instance
func NewMockStreamManager(ctrl *gomock.Controller) *MockStreamManager {
mock := &MockStreamManager{ctrl: ctrl}
mock.recorder = &MockStreamManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
return m.recorder
}
// AcceptStream mocks base method
func (m *MockStreamManager) AcceptStream() (Stream, error) {
ret := m.ctrl.Call(m, "AcceptStream")
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptStream indicates an expected call of AcceptStream
func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream))
}
// CloseWithError mocks base method
func (m *MockStreamManager) CloseWithError(arg0 error) {
m.ctrl.Call(m, "CloseWithError", arg0)
}
// CloseWithError indicates an expected call of CloseWithError
func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0)
}
// DeleteStream mocks base method
func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error {
ret := m.ctrl.Call(m, "DeleteStream", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteStream indicates an expected call of DeleteStream
func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0)
}
// GetOrOpenReceiveStream mocks base method
func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (receiveStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenReceiveStream", arg0)
ret0, _ := ret[0].(receiveStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream
func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0)
}
// GetOrOpenSendStream mocks base method
func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
ret0, _ := ret[0].(sendStreamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream
func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
}
// GetOrOpenStream mocks base method
func (m *MockStreamManager) GetOrOpenStream(arg0 protocol.StreamID) (streamI, error) {
ret := m.ctrl.Call(m, "GetOrOpenStream", arg0)
ret0, _ := ret[0].(streamI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrOpenStream indicates an expected call of GetOrOpenStream
func (mr *MockStreamManagerMockRecorder) GetOrOpenStream(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenStream), arg0)
}
// OpenStream mocks base method
func (m *MockStreamManager) OpenStream() (Stream, error) {
ret := m.ctrl.Call(m, "OpenStream")
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStream indicates an expected call of OpenStream
func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream))
}
// OpenStreamSync mocks base method
func (m *MockStreamManager) OpenStreamSync() (Stream, error) {
ret := m.ctrl.Call(m, "OpenStreamSync")
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync
func (mr *MockStreamManagerMockRecorder) OpenStreamSync() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync))
}
// UpdateLimits mocks base method
func (m *MockStreamManager) UpdateLimits(arg0 *handshake.TransportParameters) {
m.ctrl.Call(m, "UpdateLimits", arg0)
}
// UpdateLimits indicates an expected call of UpdateLimits
func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0)
}

View File

@ -0,0 +1,76 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: StreamSender)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockStreamSender is a mock of StreamSender interface
type MockStreamSender struct {
ctrl *gomock.Controller
recorder *MockStreamSenderMockRecorder
}
// MockStreamSenderMockRecorder is the mock recorder for MockStreamSender
type MockStreamSenderMockRecorder struct {
mock *MockStreamSender
}
// NewMockStreamSender creates a new mock instance
func NewMockStreamSender(ctrl *gomock.Controller) *MockStreamSender {
mock := &MockStreamSender{ctrl: ctrl}
mock.recorder = &MockStreamSenderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder {
return m.recorder
}
// onHasStreamData mocks base method
func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) {
m.ctrl.Call(m, "onHasStreamData", arg0)
}
// onHasStreamData indicates an expected call of onHasStreamData
func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0)
}
// onHasWindowUpdate mocks base method
func (m *MockStreamSender) onHasWindowUpdate(arg0 protocol.StreamID) {
m.ctrl.Call(m, "onHasWindowUpdate", arg0)
}
// onHasWindowUpdate indicates an expected call of onHasWindowUpdate
func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0)
}
// onStreamCompleted mocks base method
func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) {
m.ctrl.Call(m, "onStreamCompleted", arg0)
}
// onStreamCompleted indicates an expected call of onStreamCompleted
func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0)
}
// queueControlFrame mocks base method
func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) {
m.ctrl.Call(m, "queueControlFrame", arg0)
}
// queueControlFrame indicates an expected call of queueControlFrame
func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0)
}

12
vendor/github.com/lucas-clemente/quic-go/mockgen.go generated vendored Normal file
View File

@ -0,0 +1,12 @@
package quic
//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI ReceiveStreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI SendStreamI"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource"
//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager"
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go"
//go:generate sh -c "goimports -w mock*_test.go"

19
vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh generated vendored Executable file
View File

@ -0,0 +1,19 @@
#!/bin/bash
# Mockgen refuses to generate mocks private types.
# This script copies the quic package to a temporary directory, and adds an public alias for the private type.
# It then creates a mock for this public (alias) type.
TEMP_DIR=$(mktemp -d)
mkdir -p $TEMP_DIR/src/github.com/lucas-clemente/quic-go/
# copy all .go files to a temporary directory
# golang.org/x/crypto/curve25519/ uses Go compiler directives, which is confusing to mockgen
rsync -r --exclude 'vendor/golang.org/x/crypto/curve25519/' --include='*.go' --include '*/' --exclude '*' $GOPATH/src/github.com/lucas-clemente/quic-go/ $TEMP_DIR/src/github.com/lucas-clemente/quic-go/
echo "type $5 = $4" >> $TEMP_DIR/src/github.com/lucas-clemente/quic-go/interface.go
export GOPATH="$TEMP_DIR:$GOPATH"
mockgen -package $1 -self_package $1 -destination $2 $3 $5
rm -r "$TEMP_DIR"

View File

@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/ackhandler"
"github.com/lucas-clemente/quic-go/internal/handshake"
@ -18,6 +19,12 @@ type packedPacket struct {
encryptionLevel protocol.EncryptionLevel
}
type streamFrameSource interface {
HasCryptoStreamData() bool
PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame
PopStreamFrames(protocol.ByteCount) []*wire.StreamFrame
}
type packetPacker struct {
connectionID protocol.ConnectionID
perspective protocol.Perspective
@ -25,20 +32,23 @@ type packetPacker struct {
cryptoSetup handshake.CryptoSetup
packetNumberGenerator *packetNumberGenerator
streamFramer *streamFramer
streams streamFrameSource
controlFrames []wire.Frame
stopWaiting *wire.StopWaitingFrame
ackFrame *wire.AckFrame
leastUnacked protocol.PacketNumber
omitConnectionID bool
hasSentPacket bool // has the packetPacker already sent a packet
controlFrameMutex sync.Mutex
controlFrames []wire.Frame
stopWaiting *wire.StopWaitingFrame
ackFrame *wire.AckFrame
leastUnacked protocol.PacketNumber
omitConnectionID bool
hasSentPacket bool // has the packetPacker already sent a packet
makeNextPacketRetransmittable bool
}
func newPacketPacker(connectionID protocol.ConnectionID,
initialPacketNumber protocol.PacketNumber,
cryptoSetup handshake.CryptoSetup,
streamFramer *streamFramer,
streamFramer streamFrameSource,
perspective protocol.Perspective,
version protocol.VersionNumber,
) *packetPacker {
@ -47,7 +57,7 @@ func newPacketPacker(connectionID protocol.ConnectionID,
connectionID: connectionID,
perspective: perspective,
version: version,
streamFramer: streamFramer,
streams: streamFramer,
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
}
}
@ -73,7 +83,7 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) {
encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel)
frames := []wire.Frame{p.ackFrame}
if p.stopWaiting != nil {
if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC
p.stopWaiting.PacketNumber = header.PacketNumber
p.stopWaiting.PacketNumberLen = header.PacketNumberLen
frames = append(frames, p.stopWaiting)
@ -98,14 +108,20 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
if err != nil {
return nil, err
}
if p.stopWaiting == nil {
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")
}
header := p.getHeader(packet.EncryptionLevel)
p.stopWaiting.PacketNumber = header.PacketNumber
p.stopWaiting.PacketNumberLen = header.PacketNumberLen
frames := append([]wire.Frame{p.stopWaiting}, packet.Frames...)
p.stopWaiting = nil
var frames []wire.Frame
if !p.version.UsesIETFFrameFormat() { // for gQUIC: pack a STOP_WAITING first
if p.stopWaiting == nil {
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame")
}
swf := p.stopWaiting
swf.PacketNumber = header.PacketNumber
swf.PacketNumberLen = header.PacketNumberLen
p.stopWaiting = nil
frames = append([]wire.Frame{swf}, packet.Frames...)
} else {
frames = packet.Frames
}
raw, err := p.writeAndSealPacket(header, frames, sealer)
return &packedPacket{
header: header,
@ -118,7 +134,7 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
// PackPacket packs a new packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket() (*packedPacket, error) {
hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame()
hasCryptoStreamFrame := p.streams.HasCryptoStreamData()
// if this is the first packet to be send, make sure it contains stream data
if !p.hasSentPacket && !hasCryptoStreamFrame {
return nil, nil
@ -153,6 +169,15 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
if len(payloadFrames) == 1 && p.stopWaiting != nil {
return nil, nil
}
// check if this packet only contains an ACK and / or STOP_WAITING
if !ackhandler.HasRetransmittableFrames(payloadFrames) {
if p.makeNextPacketRetransmittable {
payloadFrames = append(payloadFrames, &wire.PingFrame{})
p.makeNextPacketRetransmittable = false
}
} else { // this packet already contains a retransmittable frame. No need to send a PING
p.makeNextPacketRetransmittable = false
}
p.stopWaiting = nil
p.ackFrame = nil
@ -176,7 +201,9 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) {
return nil, err
}
maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength
frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)}
sf := p.streams.PopCryptoStreamFrame(maxLen)
sf.DataLenPresent = false
frames := []wire.Frame{sf}
raw, err := p.writeAndSealPacket(header, frames, sealer)
if err != nil {
return nil, err
@ -197,29 +224,20 @@ func (p *packetPacker) composeNextPacket(
var payloadFrames []wire.Frame
// STOP_WAITING and ACK will always fit
if p.stopWaiting != nil {
payloadFrames = append(payloadFrames, p.stopWaiting)
l, err := p.stopWaiting.MinLength(p.version)
if err != nil {
return nil, err
}
if p.ackFrame != nil { // ACKs need to go first, so that the sentPacketHandler will recognize them
payloadFrames = append(payloadFrames, p.ackFrame)
l := p.ackFrame.MinLength(p.version)
payloadLength += l
}
if p.ackFrame != nil {
payloadFrames = append(payloadFrames, p.ackFrame)
l, err := p.ackFrame.MinLength(p.version)
if err != nil {
return nil, err
}
payloadLength += l
if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC
payloadFrames = append(payloadFrames, p.stopWaiting)
payloadLength += p.stopWaiting.MinLength(p.version)
}
p.controlFrameMutex.Lock()
for len(p.controlFrames) > 0 {
frame := p.controlFrames[len(p.controlFrames)-1]
minLength, err := frame.MinLength(p.version)
if err != nil {
return nil, err
}
minLength := frame.MinLength(p.version)
if payloadLength+minLength > maxFrameSize {
break
}
@ -227,6 +245,7 @@ func (p *packetPacker) composeNextPacket(
payloadLength += minLength
p.controlFrames = p.controlFrames[:len(p.controlFrames)-1]
}
p.controlFrameMutex.Unlock()
if payloadLength > maxFrameSize {
return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize)
@ -247,20 +266,14 @@ func (p *packetPacker) composeNextPacket(
maxFrameSize += 2
}
fs := p.streamFramer.PopStreamFrames(maxFrameSize - payloadLength)
fs := p.streams.PopStreamFrames(maxFrameSize - payloadLength)
if len(fs) != 0 {
fs[len(fs)-1].DataLenPresent = false
}
// TODO: Simplify
for _, f := range fs {
payloadFrames = append(payloadFrames, f)
}
for b := p.streamFramer.PopBlockedFrame(); b != nil; b = p.streamFramer.PopBlockedFrame() {
p.controlFrames = append(p.controlFrames, b)
}
return payloadFrames, nil
}
@ -271,7 +284,9 @@ func (p *packetPacker) QueueControlFrame(frame wire.Frame) {
case *wire.AckFrame:
p.ackFrame = f
default:
p.controlFrameMutex.Lock()
p.controlFrames = append(p.controlFrames, f)
p.controlFrameMutex.Unlock()
}
}
@ -377,3 +392,7 @@ func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) {
func (p *packetPacker) SetOmitConnectionID() {
p.omitConnectionID = true
}
func (p *packetPacker) MakeNextPacketRetransmittable() {
p.makeNextPacketRetransmittable = true
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"math"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/ackhandler"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/handshake"
@ -49,29 +50,32 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
}
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
func (m *mockCryptoSetup) ConnectionState() ConnectionState { panic("not implemented") }
var _ = Describe("Packet packer", func() {
var (
packer *packetPacker
publicHeaderLen protocol.ByteCount
maxFrameSize protocol.ByteCount
streamFramer *streamFramer
cryptoStream *stream
packer *packetPacker
publicHeaderLen protocol.ByteCount
maxFrameSize protocol.ByteCount
cryptoStream cryptoStreamI
mockStreamFramer *MockStreamFrameSource
)
BeforeEach(func() {
version := versionGQUICFrames
cryptoStream = &stream{streamID: version.CryptoStreamID(), flowController: flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)}
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames)
streamFramer = newStreamFramer(cryptoStream, streamsMap, nil, versionGQUICFrames)
mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version)
mockStreamFramer = NewMockStreamFrameSource(mockCtrl)
packer = &packetPacker{
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
connectionID: 0x1337,
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
streamFramer: streamFramer,
perspective: protocol.PerspectiveServer,
}
packer = newPacketPacker(
0x1337,
1,
&mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
mockStreamFramer,
protocol.PerspectiveServer,
version,
)
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
packer.hasSentPacket = true
@ -79,33 +83,36 @@ var _ = Describe("Packet packer", func() {
})
It("returns nil when no packet is queued", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
p, err := packer.PackPacket()
Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred())
})
It("packs single packets", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
f := &wire.StreamFrame{
StreamID: 5,
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
}
streamFramer.AddFrameForRetransmission(f)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f})
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
b := &bytes.Buffer{}
f.Write(b, packer.version)
Expect(p.frames).To(HaveLen(1))
Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
})
It("stores the encryption level a packet was sealed with", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure
f := &wire.StreamFrame{
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{
StreamID: 5,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
}})
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
@ -213,7 +220,7 @@ var _ = Describe("Packet packer", func() {
})
})
It("packs a ConnectionClose", func() {
It("packs a CONNECTION_CLOSE", func() {
ccf := wire.ConnectionCloseFrame{
ErrorCode: 0x1337,
ReasonPhrase: "foobar",
@ -224,23 +231,21 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames[0]).To(Equal(&ccf))
})
It("doesn't send any other frames when sending a ConnectionClose", func() {
ccf := wire.ConnectionCloseFrame{
It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() {
// expect no mockStreamFramer.PopStreamFrames
ccf := &wire.ConnectionCloseFrame{
ErrorCode: 0x1337,
ReasonPhrase: "foobar",
}
packer.controlFrames = []wire.Frame{&wire.MaxStreamDataFrame{StreamID: 37}}
streamFramer.AddFrameForRetransmission(&wire.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
})
p, err := packer.PackConnectionClose(&ccf)
p, err := packer.PackConnectionClose(ccf)
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&ccf))
Expect(p.frames).To(Equal([]wire.Frame{ccf}))
})
It("packs only control frames", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.QueueControlFrame(&wire.RstStreamFrame{})
packer.QueueControlFrame(&wire.MaxDataFrame{})
p, err := packer.PackPacket()
@ -251,6 +256,8 @@ var _ = Describe("Packet packer", func() {
})
It("increases the packet number", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.QueueControlFrame(&wire.RstStreamFrame{})
p1, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
@ -263,6 +270,8 @@ var _ = Describe("Packet packer", func() {
})
It("packs a STOP_WAITING frame first", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.packetNumberGenerator.next = 15
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(&wire.RstStreamFrame{})
@ -275,6 +284,8 @@ var _ = Describe("Packet packer", func() {
})
It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
packer.packetNumberGenerator.next = packetNumber
swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
@ -286,6 +297,8 @@ var _ = Describe("Packet packer", func() {
})
It("does not pack a packet containing only a STOP_WAITING frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(swf)
p, err := packer.PackPacket()
@ -294,6 +307,8 @@ var _ = Describe("Packet packer", func() {
})
It("packs a packet if it has queued control frames, but no new control frames", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
@ -301,6 +316,7 @@ var _ = Describe("Packet packer", func() {
})
It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
packer.hasSentPacket = false
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
p, err := packer.PackPacket()
@ -329,8 +345,7 @@ var _ = Describe("Packet packer", func() {
It("packs a lot of control frames into 2 packets if they don't fit into one", func() {
blockedFrame := &wire.BlockedFrame{}
minLength, _ := blockedFrame.MinLength(packer.version)
maxFramesPerPacket := int(maxFrameSize) / int(minLength)
maxFramesPerPacket := int(maxFrameSize) / int(blockedFrame.MinLength(packer.version))
var controlFrames []wire.Frame
for i := 0; i < maxFramesPerPacket+10; i++ {
controlFrames = append(controlFrames, blockedFrame)
@ -345,16 +360,17 @@ var _ = Describe("Packet packer", func() {
})
It("only increases the packet number when there is an actual packet to send", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
packer.packetNumberGenerator.nextToSkip = 1000
p, err := packer.PackPacket()
Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1)))
f := &wire.StreamFrame{
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{
StreamID: 5,
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
}
streamFramer.AddFrameForRetransmission(f)
}})
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
@ -362,320 +378,207 @@ var _ = Describe("Packet packer", func() {
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2)))
})
Context("STREAM Frame handling", func() {
It("adds a PING frame when it's supposed to send a retransmittable packet", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.QueueControlFrame(&wire.AckFrame{})
packer.QueueControlFrame(&wire.StopWaitingFrame{})
packer.MakeNextPacketRetransmittable()
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(3))
Expect(p.frames).To(ContainElement(&wire.PingFrame{}))
// make sure the next packet doesn't contain another PING
packer.QueueControlFrame(&wire.AckFrame{})
p, err = packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
})
It("waits until there's something to send before adding a PING frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.MakeNextPacketRetransmittable()
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
packer.QueueControlFrame(&wire.AckFrame{})
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(2))
Expect(p.frames).To(ContainElement(&wire.PingFrame{}))
})
It("doesn't send a PING if it already sent another retransmittable frame", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2)
packer.MakeNextPacketRetransmittable()
packer.QueueControlFrame(&wire.MaxDataFrame{})
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
packer.QueueControlFrame(&wire.AckFrame{})
p, err = packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
})
Context("STREAM frame handling", func() {
It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() {
f := &wire.StreamFrame{
Offset: 1,
StreamID: 5,
DataLenPresent: false,
}
minLength, _ := f.MinLength(packer.version)
maxStreamFrameDataLen := maxFrameSize - minLength
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen))
streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty())
})
It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() {
packer.version = versionIETFFrames
streamFramer.version = versionIETFFrames
f := &wire.StreamFrame{
Offset: 1,
StreamID: 5,
DataLenPresent: true,
}
minLength, _ := f.MinLength(packer.version)
// for IETF draft style STREAM frames, we don't know the size of the DataLen, because it is a variable length integer
// in the general case, we therefore use a STREAM frame that is 1 byte smaller than the maximum size
maxStreamFrameDataLen := maxFrameSize - minLength - 1
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen))
streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty())
})
It("correctly handles a STREAM frame with one byte less than maximum size", func() {
maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - 1
f1 := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)),
}
f2 := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1)))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
})
It("packs multiple small STREAM frames into single packet", func() {
f1 := &wire.StreamFrame{
StreamID: 5,
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
}
f2 := &wire.StreamFrame{
StreamID: 5,
Data: []byte{0xBE, 0xEF, 0x13, 0x37},
}
f3 := &wire.StreamFrame{
StreamID: 3,
Data: []byte{0xCA, 0xFE},
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
streamFramer.AddFrameForRetransmission(f3)
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
b := &bytes.Buffer{}
f1.Write(b, 0)
f2.Write(b, 0)
f3.Write(b, 0)
Expect(p.frames).To(HaveLen(3))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(p.raw).To(ContainSubstring(string(f1.Data)))
Expect(p.raw).To(ContainSubstring(string(f2.Data)))
Expect(p.raw).To(ContainSubstring(string(f3.Data)))
})
It("splits one STREAM frame larger than maximum size", func() {
f := &wire.StreamFrame{
StreamID: 7,
Offset: 1,
}
minLength, _ := f.MinLength(packer.version)
maxStreamFrameDataLen := maxFrameSize - minLength
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200)
streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(int(maxStreamFrameDataLen)))
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*wire.StreamFrame).Data).To(HaveLen(200))
Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty())
})
It("packs 2 STREAM frames that are too big for one packet correctly", func() {
maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2)
f1 := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100),
Offset: 1,
}
f2 := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+100),
Offset: 1,
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame {
f := &wire.StreamFrame{
Offset: 1,
StreamID: 5,
DataLenPresent: true,
}
f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version)))
return []*wire.StreamFrame{f}
})
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket()
Expect(p.frames).To(HaveLen(2))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket()
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
It("packs a packet that has the maximum packet size when given a large enough STREAM frame", func() {
f := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
}
minLength, _ := f.MinLength(packer.version)
f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header
streamFramer.AddFrameForRetransmission(f)
It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() {
packer.version = versionIETFFrames
mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2)
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame {
f := &wire.StreamFrame{
Offset: 1,
StreamID: 5,
DataLenPresent: true,
}
f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.MinLength(packer.version)))
return []*wire.StreamFrame{f}
})
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.frames).To(HaveLen(1))
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
p, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
It("splits a STREAM frame larger than the maximum size", func() {
f := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
It("packs multiple small STREAM frames into single packet", func() {
f1 := &wire.StreamFrame{
StreamID: 5,
Data: []byte("frame 1"),
DataLenPresent: true,
}
minLength, _ := f.MinLength(packer.version)
f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header
streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
f2 := &wire.StreamFrame{
StreamID: 5,
Data: []byte("frame 2"),
DataLenPresent: true,
}
f3 := &wire.StreamFrame{
StreamID: 3,
Data: []byte("frame 3"),
DataLenPresent: true,
}
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f1, f2, f3})
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1))
Expect(p.frames).To(HaveLen(3))
Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1")))
Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2")))
Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3")))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
})
It("refuses to send unencrypted stream data on a data stream", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
// don't expect a call to mockStreamFramer.PopStreamFrames
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
f := &wire.StreamFrame{
StreamID: 3,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil())
})
It("sends non forward-secure data as the client", func() {
packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
f := &wire.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f})
packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(p.frames[0]).To(Equal(f))
Expect(p.frames).To(Equal([]wire.Frame{f}))
})
It("does not send non forward-secure data as the server", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
// don't expect a call to mockStreamFramer.PopStreamFrames
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
f := &wire.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
It("sends unencrypted stream data on the crypto stream", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
cryptoStream.dataForWriting = []byte("foobar")
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
f := &wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"),
}))
}
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
})
It("sends encrypted stream data on the crypto stream", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure
cryptoStream.dataForWriting = []byte("foobar")
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
f := &wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"),
}))
})
It("does not pack stream frames if not allowed", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
packer.QueueControlFrame(&wire.AckFrame{})
streamFramer.AddFrameForRetransmission(&wire.StreamFrame{StreamID: 3, Data: []byte("foobar")})
}
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(func() { _ = p.frames[0].(*wire.AckFrame) }).NotTo(Panic())
Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
})
})
Context("BLOCKED frames", func() {
It("queues a BLOCKED frame", func() {
length := 100
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
f := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, length),
}
streamFramer.AddFrameForRetransmission(f)
_, err := packer.composeNextPacket(maxFrameSize, true)
It("does not pack STREAM frames if not allowed", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
// don't expect a call to mockStreamFramer.PopStreamFrames
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
ack := &wire.AckFrame{LargestAcked: 10}
packer.QueueControlFrame(ack)
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(packer.controlFrames[0]).To(Equal(&wire.StreamBlockedFrame{StreamID: 5}))
Expect(p.frames).To(Equal([]wire.Frame{ack}))
})
It("removes the dataLen attribute from the last StreamFrame, even if it queued a BLOCKED frame", func() {
length := 100
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
f := &wire.StreamFrame{
StreamID: 5,
Data: bytes.Repeat([]byte{'f'}, length),
}
streamFramer.AddFrameForRetransmission(f)
p, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(HaveLen(1))
Expect(p[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
})
It("packs a connection-level BlockedFrame", func() {
streamFramer.blockedFrameQueue = []wire.Frame{&wire.BlockedFrame{}}
f := &wire.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
}
streamFramer.AddFrameForRetransmission(f)
_, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred())
Expect(packer.controlFrames[0]).To(Equal(&wire.BlockedFrame{}))
})
})
It("returns nil if we only have a single STOP_WAITING", func() {
packer.QueueControlFrame(&wire.StopWaitingFrame{})
p, err := packer.PackPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil())
})
It("packs a single ACK", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
ack := &wire.AckFrame{LargestAcked: 42}
packer.QueueControlFrame(ack)
p, err := packer.PackPacket()
@ -685,6 +588,8 @@ var _ = Describe("Packet packer", func() {
})
It("does not return nil if we only have a single ACK but request it to be sent", func() {
mockStreamFramer.EXPECT().HasCryptoStreamData()
mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any())
ack := &wire.AckFrame{}
packer.QueueControlFrame(ack)
p, err := packer.PackPacket()
@ -692,15 +597,6 @@ var _ = Describe("Packet packer", func() {
Expect(p).ToNot(BeNil())
})
It("queues a control frame to be sent in the next packet", func() {
msd := &wire.MaxStreamDataFrame{StreamID: 5}
packer.QueueControlFrame(msd)
p, err := packer.PackPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(msd))
})
Context("retransmitting of handshake packets", func() {
swf := &wire.StopWaitingFrame{LeastUnacked: 1}
sf := &wire.StreamFrame{
@ -719,8 +615,19 @@ var _ = Describe("Packet packer", func() {
}
p, err := packer.PackHandshakeRetransmission(packet)
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(ContainElement(sf))
Expect(p.frames).To(ContainElement(swf))
Expect(p.frames).To(Equal([]wire.Frame{swf, sf}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
})
It("doesn't add a STOP_WAITING frame for IETF QUIC", func() {
packer.version = versionIETFFrames
packet := &ackhandler.Packet{
EncryptionLevel: protocol.EncryptionUnencrypted,
Frames: []wire.Frame{sf},
}
p, err := packer.PackHandshakeRetransmission(packet)
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(Equal([]wire.Frame{sf}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
})
@ -733,8 +640,7 @@ var _ = Describe("Packet packer", func() {
}
p, err := packer.PackHandshakeRetransmission(packet)
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(ContainElement(sf))
Expect(p.frames).To(ContainElement(swf))
Expect(p.frames).To(Equal([]wire.Frame{swf, sf}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
// a packet sent by the server with initial encryption contains the SHLO
// it needs to have a diversification nonce
@ -768,11 +674,16 @@ var _ = Describe("Packet packer", func() {
})
It("pads Initial packets to the required minimum packet size", func() {
f := &wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"),
}
mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true)
mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f)
packer.version = protocol.VersionTLS
packer.hasSentPacket = false
packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
cryptoStream.dataForWriting = []byte("foobar")
packet, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize))
@ -795,7 +706,7 @@ var _ = Describe("Packet packer", func() {
_, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.EncryptionSecure,
})
Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame"))
Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame"))
})
})
@ -807,7 +718,7 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames).To(Equal([]wire.Frame{&wire.AckFrame{DelayTime: math.MaxInt64}}))
})
It("packs ACK packets with SWFs", func() {
It("packs ACK packets with STOP_WAITING frames", func() {
packer.QueueControlFrame(&wire.AckFrame{})
packer.QueueControlFrame(&wire.StopWaitingFrame{})
p, err := packer.PackAckPacket()

View File

@ -107,21 +107,32 @@ func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wir
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x6:
// TODO(#964): remove STOP_WAITING frames
// TODO(#878): implement the MAX_STREAM_ID frame
frame, err = wire.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version)
frame, err = wire.ParseMaxStreamIDFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidStopWaitingData, err.Error())
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0x7:
frame, err = wire.ParsePingFrame(r, u.version)
case 0x8:
frame, err = wire.ParseBlockedFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0x9:
frame, err = wire.ParseStreamBlockedFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0xa:
frame, err = wire.ParseStreamIDBlockedFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xc:
frame, err = wire.ParseStopSendingFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xe:
frame, err = wire.ParseAckFrame(r, u.version)
if err != nil {

View File

@ -102,7 +102,7 @@ var _ = Describe("Packet unpacker", func() {
f := &wire.RstStreamFrame{
StreamID: 0xdeadbeef,
ByteOffset: 0xdecafbad11223344,
ErrorCode: 0x13371234,
ErrorCode: 0x1337,
}
err := f.Write(buf, versionGQUICFrames)
Expect(err).ToNot(HaveOccurred())
@ -342,8 +342,19 @@ var _ = Describe("Packet unpacker", func() {
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("unpacks MAX_STREAM_ID frames", func() {
f := &wire.MaxStreamIDFrame{StreamID: 0x1337}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("unpacks connection-level BLOCKED frames", func() {
f := &wire.BlockedFrame{}
f := &wire.BlockedFrame{Offset: 0x1234}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
@ -354,7 +365,32 @@ var _ = Describe("Packet unpacker", func() {
})
It("unpacks stream-level BLOCKED frames", func() {
f := &wire.StreamBlockedFrame{StreamID: 0xdeadbeef}
f := &wire.StreamBlockedFrame{
StreamID: 0xdeadbeef,
Offset: 0xdead,
}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("unpacks STREAM_ID_BLOCKED frames", func() {
f := &wire.StreamIDBlockedFrame{StreamID: 0x1234567}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("unpacks STOP_SENDING frames", func() {
f := &wire.StopSendingFrame{StreamID: 0x42}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
@ -392,9 +428,13 @@ var _ = Describe("Packet unpacker", func() {
0x02: qerr.InvalidConnectionCloseData,
0x04: qerr.InvalidWindowUpdateData,
0x05: qerr.InvalidWindowUpdateData,
0x06: qerr.InvalidFrameData,
0x08: qerr.InvalidBlockedData,
0x09: qerr.InvalidBlockedData,
0x0a: qerr.InvalidFrameData,
0x0c: qerr.InvalidFrameData,
0x0e: qerr.InvalidAckData,
0x10: qerr.InvalidStreamData,
0xe: qerr.InvalidAckData,
} {
setData([]byte{b})
_, err := unpacker.Unpack(hdrBin, hdr, data)

View File

@ -1,8 +1,8 @@
// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT
// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT.
package qerr
import "fmt"
import "strconv"
const (
_ErrorCode_name_0 = "InternalErrorStreamDataAfterTerminationInvalidPacketHeaderInvalidFrameDataInvalidFecDataInvalidRstStreamDataInvalidConnectionCloseDataInvalidGoawayDataInvalidAckDataInvalidVersionNegotiationPacketInvalidPublicRstPacketDecryptionFailureEncryptionFailurePacketTooLarge"
@ -19,7 +19,6 @@ var (
_ErrorCode_index_2 = [...]uint16{0, 15, 37, 57, 75, 96, 112, 127, 147, 167, 191, 226, 250, 279, 309, 340, 366, 385, 410, 425, 445, 457, 475, 505, 530, 547}
_ErrorCode_index_3 = [...]uint16{0, 14, 29, 50, 65, 90, 119, 158, 184, 208, 231, 249, 279, 301, 322, 340, 366, 390, 425}
_ErrorCode_index_4 = [...]uint16{0, 16, 45, 78, 97, 114, 144, 169, 192, 215, 238, 256, 276, 292, 308, 346, 379, 410, 448, 459, 477, 498, 532}
_ErrorCode_index_5 = [...]uint8{0, 34}
)
func (i ErrorCode) String() string {
@ -42,6 +41,6 @@ func (i ErrorCode) String() string {
case i == 97:
return _ErrorCode_name_5
default:
return fmt.Sprintf("ErrorCode(%d)", i)
return "ErrorCode(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View File

@ -0,0 +1,284 @@
package quic
import (
"fmt"
"io"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receiveStreamI interface {
ReceiveStream
handleStreamFrame(*wire.StreamFrame) error
handleRstStreamFrame(*wire.RstStreamFrame) error
closeForShutdown(error)
getWindowUpdate() protocol.ByteCount
}
type receiveStream struct {
mutex sync.Mutex
streamID protocol.StreamID
sender streamSender
frameQueue *streamFrameSorter
readPosInFrame int
readOffset protocol.ByteCount
closeForShutdownErr error
cancelReadErr error
resetRemotelyErr StreamError
closedForShutdown bool // set when CloseForShutdown() is called
finRead bool // set once we read a frame with a FinBit
canceledRead bool // set when CancelRead() is called
resetRemotely bool // set when HandleRstStreamFrame() is called
readChan chan struct{}
readDeadline time.Time
flowController flowcontrol.StreamFlowController
version protocol.VersionNumber
}
var _ ReceiveStream = &receiveStream{}
var _ receiveStreamI = &receiveStream{}
func newReceiveStream(
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
) *receiveStream {
return &receiveStream{
streamID: streamID,
sender: sender,
flowController: flowController,
frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1),
}
}
func (s *receiveStream) StreamID() protocol.StreamID {
return s.streamID
}
// Read implements io.Reader. It is not thread safe!
func (s *receiveStream) Read(p []byte) (int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.finRead {
return 0, io.EOF
}
if s.canceledRead {
return 0, s.cancelReadErr
}
if s.resetRemotely {
return 0, s.resetRemotelyErr
}
if s.closedForShutdown {
return 0, s.closeForShutdownErr
}
bytesRead := 0
for bytesRead < len(p) {
frame := s.frameQueue.Head()
if frame == nil && bytesRead > 0 {
return bytesRead, s.closeForShutdownErr
}
for {
// Stop waiting on errors
if s.closedForShutdown {
return bytesRead, s.closeForShutdownErr
}
if s.canceledRead {
return bytesRead, s.cancelReadErr
}
if s.resetRemotely {
return bytesRead, s.resetRemotelyErr
}
deadline := s.readDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
return bytesRead, errDeadline
}
if frame != nil {
s.readPosInFrame = int(s.readOffset - frame.Offset)
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.readChan
} else {
select {
case <-s.readChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
frame = s.frameQueue.Head()
}
if bytesRead > len(p) {
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > int(frame.DataLen()) {
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen())
}
s.mutex.Unlock()
copy(p[bytesRead:], frame.Data[s.readPosInFrame:])
m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame)
s.readPosInFrame += m
bytesRead += m
s.readOffset += protocol.ByteCount(m)
s.mutex.Lock()
// when a RST_STREAM was received, the was already informed about the final byteOffset for this stream
if !s.resetRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}
// this call triggers the flow controller to increase the flow control window, if necessary
if s.flowController.HasWindowUpdate() {
s.sender.onHasWindowUpdate(s.streamID)
}
if s.readPosInFrame >= int(frame.DataLen()) {
s.frameQueue.Pop()
s.finRead = frame.FinBit
if frame.FinBit {
s.sender.onStreamCompleted(s.streamID)
return bytesRead, io.EOF
}
}
}
return bytesRead, nil
}
func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.finRead {
return nil
}
if s.canceledRead {
return nil
}
s.canceledRead = true
s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode)
s.signalRead()
if s.version.UsesIETFFrameFormat() {
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
ErrorCode: errorCode,
})
}
return nil
}
func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil {
return err
}
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData {
return err
}
s.signalRead()
return nil
}
func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closedForShutdown {
return nil
}
if err := s.flowController.UpdateHighestReceived(frame.ByteOffset, true); err != nil {
return err
}
// In gQUIC, error code 0 has a special meaning.
// The peer will reliably continue transmitting, but is not interested in reading from the stream.
// We should therefore just continue reading from the stream, until we encounter the FIN bit.
if !s.version.UsesIETFFrameFormat() && frame.ErrorCode == 0 {
return nil
}
// ignore duplicate RST_STREAM frames for this stream (after checking their final offset)
if s.resetRemotely {
return nil
}
s.resetRemotely = true
s.resetRemotelyErr = streamCanceledError{
errorCode: frame.ErrorCode,
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
}
s.signalRead()
s.sender.onStreamCompleted(s.streamID)
return nil
}
func (s *receiveStream) CloseRemote(offset protocol.ByteCount) {
s.handleStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset})
}
func (s *receiveStream) onClose(offset protocol.ByteCount) {
if s.canceledRead && !s.version.UsesIETFFrameFormat() {
s.sender.queueControlFrame(&wire.RstStreamFrame{
StreamID: s.streamID,
ByteOffset: offset,
ErrorCode: 0,
})
}
}
func (s *receiveStream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.readDeadline
s.readDeadline = t
s.mutex.Unlock()
// if the new deadline is before the currently set deadline, wake up Read()
if t.Before(oldDeadline) {
s.signalRead()
}
return nil
}
// CloseForShutdown closes a stream abruptly.
// It makes Read unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *receiveStream) closeForShutdown(err error) {
s.mutex.Lock()
s.closedForShutdown = true
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalRead()
}
func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate()
}
// signalRead performs a non-blocking send on the readChan
func (s *receiveStream) signalRead() {
select {
case s.readChan <- struct{}{}:
default:
}
}

View File

@ -0,0 +1,648 @@
package quic
import (
"errors"
"io"
"runtime"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
var _ = Describe("Receive Stream", func() {
const streamID protocol.StreamID = 1337
var (
str *receiveStream
strWithTimeout io.Reader // str wrapped with gbytes.TimeoutReader
mockFC *mocks.MockStreamFlowController
mockSender *MockStreamSender
)
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
str = newReceiveStream(streamID, mockSender, mockFC)
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = gbytes.TimeoutReader(str, timeout)
})
It("gets stream id", func() {
Expect(str.StreamID()).To(Equal(protocol.StreamID(1337)))
})
Context("reading", func() {
It("reads a single STREAM frame", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().HasWindowUpdate()
frame := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
})
It("reads a single STREAM frame in multiple goes", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 2)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
Expect(b).To(Equal([]byte{0xDE, 0xAD}))
n, err = strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
Expect(b).To(Equal([]byte{0xBE, 0xEF}))
})
It("reads all data available", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
frame2 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00}))
})
It("assembles multiple STREAM frames", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
frame2 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
})
It("waits until data is available", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().HasWindowUpdate()
go func() {
defer GinkgoRecover()
frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}}
time.Sleep(10 * time.Millisecond)
err := str.handleStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
}()
b := make([]byte, 2)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
})
It("handles STREAM frames in wrong order", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
}
frame2 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
})
It("ignores duplicate STREAM frames", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
frame2 := wire.StreamFrame{
Offset: 0,
Data: []byte{0x13, 0x37},
}
frame3 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame3)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
})
It("doesn't rejects a STREAM frames with an overlapping data range", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte("foob"),
}
frame2 := wire.StreamFrame{
Offset: 2,
Data: []byte("obar"),
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})
It("passes on errors from the streamFrameSorter", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), false)
err := str.handleStreamFrame(&wire.StreamFrame{StreamID: streamID}) // STREAM frame without data
Expect(err).To(MatchError(errEmptyStreamData))
})
It("calls the onHasWindowUpdate callback, when the a MAX_STREAM_DATA should be sent", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
mockFC.EXPECT().HasWindowUpdate().Return(true)
mockSender.EXPECT().onHasWindowUpdate(streamID)
frame1 := wire.StreamFrame{
Offset: 0,
Data: []byte("foobar"),
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = strWithTimeout.Read(b)
Expect(err).ToNot(HaveOccurred())
})
Context("deadlines", func() {
It("the deadline error has the right net.Error properties", func() {
Expect(errDeadline.Temporary()).To(BeTrue())
Expect(errDeadline.Timeout()).To(BeTrue())
Expect(errDeadline).To(MatchError("deadline exceeded"))
})
It("returns an error when Read is called after the deadline", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes()
f := &wire.StreamFrame{Data: []byte("foobar")}
err := str.handleStreamFrame(f)
Expect(err).ToNot(HaveOccurred())
str.SetReadDeadline(time.Now().Add(-time.Second))
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
})
It("unblocks after the deadline", func() {
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetReadDeadline(deadline)
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(10*time.Millisecond)))
})
It("doesn't unblock if the deadline is changed before the first one expires", func() {
deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond))
str.SetReadDeadline(deadline1)
go func() {
defer GinkgoRecover()
time.Sleep(scaleDuration(20 * time.Millisecond))
str.SetReadDeadline(deadline2)
// make sure that this was actually execute before the deadline expires
Expect(time.Now()).To(BeTemporally("<", deadline1))
}()
runtime.Gosched()
b := make([]byte, 10)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond)))
})
It("unblocks earlier, when a new deadline is set", func() {
deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond))
go func() {
defer GinkgoRecover()
time.Sleep(scaleDuration(10 * time.Millisecond))
str.SetReadDeadline(deadline2)
// make sure that this was actually execute before the deadline expires
Expect(time.Now()).To(BeTemporally("<", deadline2))
}()
str.SetReadDeadline(deadline1)
runtime.Gosched()
b := make([]byte, 10)
_, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(errDeadline))
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond)))
})
})
Context("closing", func() {
Context("with FIN bit", func() {
It("returns EOFs", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().HasWindowUpdate()
str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
FinBit: true,
})
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
n, err = strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
It("handles out-of-order frames", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
mockFC.EXPECT().HasWindowUpdate().Times(2)
frame1 := wire.StreamFrame{
Offset: 2,
Data: []byte{0xBE, 0xEF},
FinBit: true,
}
frame2 := wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
}
err := str.handleStreamFrame(&frame1)
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
n, err = strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
It("returns EOFs with partial read", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().HasWindowUpdate()
err := str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{0xde, 0xad},
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(2))
Expect(b[:n]).To(Equal([]byte{0xde, 0xad}))
})
It("handles immediate FINs", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
err := str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
})
It("closes when CloseRemote is called", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
str.CloseRemote(0)
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 8)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
})
Context("closing for shutdown", func() {
testErr := errors.New("test error")
It("immediately returns all reads", func() {
done := make(chan struct{})
b := make([]byte, 4)
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
str.closeForShutdown(testErr)
Eventually(done).Should(BeClosed())
})
It("errors for all following reads", func() {
str.closeForShutdown(testErr)
b := make([]byte, 1)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
})
})
Context("stream cancelations", func() {
Context("canceling read", func() {
It("unblocks Read", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
err := str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
})
It("doesn't allow further calls to Read", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
err := str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234"))
})
It("does nothing when CancelRead is called twice", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
err := str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
err = str.CancelRead(2345)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234"))
})
It("doesn't send a RST_STREAM frame, if the FIN was already read", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
mockFC.EXPECT().HasWindowUpdate()
// no calls to mockSender.queueControlFrame
err := str.handleStreamFrame(&wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
_, err = strWithTimeout.Read(make([]byte, 100))
Expect(err).To(MatchError(io.EOF))
err = str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
})
It("queues a STOP_SENDING frame, for IETF QUIC", func() {
mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 1234,
})
err := str.CancelRead(1234)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't queue a STOP_SENDING frame, for gQUIC", func() {
})
})
Context("receiving RST_STREAM frames", func() {
rst := &wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 42,
ErrorCode: 1234,
}
It("unblocks Read", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Stream 1337 was reset with error code 1234"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234)))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleRstStreamFrame(rst)
Eventually(done).Should(BeClosed())
})
It("doesn't allow further calls to Read", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Stream 1337 was reset with error code 1234"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234)))
})
It("errors when receiving a RST_STREAM with an inconsistent offset", func() {
testErr := errors.New("already received a different final offset before")
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Return(testErr)
err := str.handleRstStreamFrame(rst)
Expect(err).To(MatchError(testErr))
})
It("ignores duplicate RST_STREAM frames", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
err = str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't do anyting when it was closed for shutdown", func() {
str.closeForShutdown(nil)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
})
Context("for gQUIC", func() {
BeforeEach(func() {
str.version = versionGQUICFrames
})
It("unblocks Read when receiving a RST_STREAM frame with non-zero error code", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
readReturned := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Stream 1337 was reset with error code 1234"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(1234)))
close(readReturned)
}()
Consistently(readReturned).ShouldNot(BeClosed())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
Eventually(readReturned).Should(BeClosed())
})
It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2)
gomock.InOrder(
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)),
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
mockSender.EXPECT().onStreamCompleted(streamID),
)
mockFC.EXPECT().HasWindowUpdate().Times(2)
readReturned := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Read(make([]byte, 4))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
n, err = strWithTimeout.Read(make([]byte, 4))
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(2))
close(readReturned)
}()
Consistently(readReturned).ShouldNot(BeClosed())
err := str.handleStreamFrame(&wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
err = str.handleRstStreamFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 6,
ErrorCode: 0,
})
Expect(err).ToNot(HaveOccurred())
Eventually(readReturned).Should(BeClosed())
})
})
})
})
Context("flow control", func() {
It("errors when a STREAM frame causes a flow control violation", func() {
testErr := errors.New("flow control violation")
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr)
frame := wire.StreamFrame{
Offset: 2,
Data: []byte("foobar"),
}
err := str.handleStreamFrame(&frame)
Expect(err).To(MatchError(testErr))
})
It("gets a window update", func() {
mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x100))
Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100)))
})
})
})

313
vendor/github.com/lucas-clemente/quic-go/send_stream.go generated vendored Normal file
View File

@ -0,0 +1,313 @@
package quic
import (
"context"
"fmt"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type sendStreamI interface {
SendStream
handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type sendStream struct {
mutex sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
streamID protocol.StreamID
sender streamSender
writeOffset protocol.ByteCount
cancelWriteErr error
closeForShutdownErr error
closedForShutdown bool // set when CloseForShutdown() is called
finishedWriting bool // set once Close() is called
canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
finSent bool // set when a STREAM_FRAME with FIN bit has b
dataForWriting []byte
writeChan chan struct{}
writeDeadline time.Time
flowController flowcontrol.StreamFlowController
version protocol.VersionNumber
}
var _ SendStream = &sendStream{}
var _ sendStreamI = &sendStream{}
func newSendStream(
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
version protocol.VersionNumber,
) *sendStream {
s := &sendStream{
streamID: streamID,
sender: sender,
flowController: flowController,
writeChan: make(chan struct{}, 1),
version: version,
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
}
func (s *sendStream) StreamID() protocol.StreamID {
return s.streamID // same for receiveStream and sendStream
}
func (s *sendStream) Write(p []byte) (int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.finishedWriting {
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if s.canceledWrite {
return 0, s.cancelWriteErr
}
if s.closeForShutdownErr != nil {
return 0, s.closeForShutdownErr
}
if !s.writeDeadline.IsZero() && !time.Now().Before(s.writeDeadline) {
return 0, errDeadline
}
if len(p) == 0 {
return 0, nil
}
s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p)
s.sender.onHasStreamData(s.streamID)
var bytesWritten int
var err error
for {
bytesWritten = len(p) - len(s.dataForWriting)
deadline := s.writeDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
s.dataForWriting = nil
err = errDeadline
break
}
if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.writeChan
} else {
select {
case <-s.writeChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
}
if s.closeForShutdownErr != nil {
err = s.closeForShutdownErr
} else if s.cancelWriteErr != nil {
err = s.cancelWriteErr
}
return bytesWritten, err
}
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
// maxBytes is the maximum length this frame (including frame header) will have.
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closeForShutdownErr != nil {
return nil, false
}
frame := &wire.StreamFrame{
StreamID: s.streamID,
Offset: s.writeOffset,
DataLenPresent: true,
}
frameLen := frame.MinLength(s.version)
if frameLen >= maxBytes { // a STREAM frame must have at least one byte of data
return nil, s.dataForWriting != nil
}
frame.Data, frame.FinBit = s.getDataForWriting(maxBytes - frameLen)
if len(frame.Data) == 0 && !frame.FinBit {
// this can happen if:
// - popStreamFrame is called but there's no data for writing
// - there's data for writing, but the stream is stream-level flow control blocked
// - there's data for writing, but the stream is connection-level flow control blocked
if s.dataForWriting == nil {
return nil, false
}
isBlocked, _ := s.flowController.IsBlocked()
return nil, !isBlocked
}
if frame.FinBit {
s.finSent = true
s.sender.onStreamCompleted(s.streamID)
} else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream
if isBlocked, offset := s.flowController.IsBlocked(); isBlocked {
s.sender.queueControlFrame(&wire.StreamBlockedFrame{
StreamID: s.streamID,
Offset: offset,
})
return frame, false
}
}
return frame, s.dataForWriting != nil
}
func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
if s.dataForWriting == nil {
return nil, s.finishedWriting && !s.finSent
}
// TODO(#657): Flow control for the crypto stream
if s.streamID != s.version.CryptoStreamID() {
maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize())
}
if maxBytes == 0 {
return nil, false
}
var ret []byte
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
ret = s.dataForWriting[:maxBytes]
s.dataForWriting = s.dataForWriting[maxBytes:]
} else {
ret = s.dataForWriting
s.dataForWriting = nil
s.signalWrite()
}
s.writeOffset += protocol.ByteCount(len(ret))
s.flowController.AddBytesSent(protocol.ByteCount(len(ret)))
return ret, s.finishedWriting && s.dataForWriting == nil && !s.finSent
}
func (s *sendStream) Close() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.canceledWrite {
return fmt.Errorf("Close called for canceled stream %d", s.streamID)
}
s.finishedWriting = true
s.sender.onHasStreamData(s.streamID) // need to send the FIN
s.ctxCancel()
return nil
}
func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) error {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode))
}
// must be called after locking the mutex
func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) error {
if s.canceledWrite {
return nil
}
if s.finishedWriting {
return fmt.Errorf("CancelWrite for closed stream %d", s.streamID)
}
s.canceledWrite = true
s.cancelWriteErr = writeErr
s.signalWrite()
s.sender.queueControlFrame(&wire.RstStreamFrame{
StreamID: s.streamID,
ByteOffset: s.writeOffset,
ErrorCode: errorCode,
})
// TODO(#991): cancel retransmissions for this stream
s.ctxCancel()
s.sender.onStreamCompleted(s.streamID)
return nil
}
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.handleStopSendingFrameImpl(frame)
}
func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) {
s.flowController.UpdateSendWindow(frame.ByteOffset)
s.mutex.Lock()
if s.dataForWriting != nil {
s.sender.onHasStreamData(s.streamID)
}
s.mutex.Unlock()
}
// must be called after locking the mutex
func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) {
writeErr := streamCanceledError{
errorCode: frame.ErrorCode,
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
}
errorCode := errorCodeStopping
if !s.version.UsesIETFFrameFormat() {
errorCode = errorCodeStoppingGQUIC
}
s.cancelWriteImpl(errorCode, writeErr)
}
func (s *sendStream) Context() context.Context {
return s.ctx
}
func (s *sendStream) SetWriteDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.writeDeadline
s.writeDeadline = t
s.mutex.Unlock()
if t.Before(oldDeadline) {
s.signalWrite()
}
return nil
}
// CloseForShutdown closes a stream abruptly.
// It makes Write unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock()
s.closedForShutdown = true
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalWrite()
s.ctxCancel()
}
func (s *sendStream) getWriteOffset() protocol.ByteCount {
return s.writeOffset
}
// signalWrite performs a non-blocking send on the writeChan
func (s *sendStream) signalWrite() {
select {
case s.writeChan <- struct{}{}:
default:
}
}

View File

@ -0,0 +1,636 @@
package quic
import (
"bytes"
"errors"
"io"
"runtime"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
var _ = Describe("Send Stream", func() {
const streamID protocol.StreamID = 1337
var (
str *sendStream
strWithTimeout io.Writer // str wrapped with gbytes.TimeoutWriter
mockFC *mocks.MockStreamFlowController
mockSender *MockStreamSender
)
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
str = newSendStream(streamID, mockSender, mockFC, protocol.VersionWhatever)
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = gbytes.TimeoutWriter(str, timeout)
})
waitForWrite := func() {
EventuallyWithOffset(0, func() []byte {
str.mutex.Lock()
data := str.dataForWriting
str.mutex.Unlock()
return data
}).ShouldNot(BeEmpty())
}
It("gets stream id", func() {
Expect(str.StreamID()).To(Equal(protocol.StreamID(1337)))
})
Context("writing", func() {
It("writes and gets all data at once", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
mockFC.EXPECT().IsBlocked()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
close(done)
}()
waitForWrite()
f, _ := str.popStreamFrame(1000)
Expect(f.Data).To(Equal([]byte("foobar")))
Expect(f.FinBit).To(BeFalse())
Expect(f.Offset).To(BeZero())
Expect(f.DataLenPresent).To(BeTrue())
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
Expect(str.dataForWriting).To(BeNil())
Eventually(done).Should(BeClosed())
})
It("writes and gets data in two turns", func() {
mockSender.EXPECT().onHasStreamData(streamID)
frameHeaderLen := protocol.ByteCount(4)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
mockFC.EXPECT().AddBytesSent(gomock.Any() /* protocol.ByteCount(3)*/).Times(2)
mockFC.EXPECT().IsBlocked().Times(2)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
close(done)
}()
waitForWrite()
f, _ := str.popStreamFrame(3 + frameHeaderLen)
Expect(f.Data).To(Equal([]byte("foo")))
Expect(f.FinBit).To(BeFalse())
Expect(f.Offset).To(BeZero())
Expect(f.DataLenPresent).To(BeTrue())
f, _ = str.popStreamFrame(100)
Expect(f.Data).To(Equal([]byte("bar")))
Expect(f.FinBit).To(BeFalse())
Expect(f.Offset).To(Equal(protocol.ByteCount(3)))
Expect(f.DataLenPresent).To(BeTrue())
Expect(str.popStreamFrame(1000)).To(BeNil())
Eventually(done).Should(BeClosed())
})
It("popStreamFrame returns nil if no data is available", func() {
frame, hasMoreData := str.popStreamFrame(1000)
Expect(frame).To(BeNil())
Expect(hasMoreData).To(BeFalse())
})
It("says if it has more data for writing", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2)
mockFC.EXPECT().IsBlocked().Times(2)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(100))
close(done)
}()
waitForWrite()
frame, hasMoreData := str.popStreamFrame(50)
Expect(hasMoreData).To(BeTrue())
frame, hasMoreData = str.popStreamFrame(1000)
Expect(frame).ToNot(BeNil())
Expect(hasMoreData).To(BeFalse())
frame, _ = str.popStreamFrame(1000)
Expect(frame).To(BeNil())
Eventually(done).Should(BeClosed())
})
It("copies the slice while writing", func() {
mockSender.EXPECT().onHasStreamData(streamID)
frameHeaderSize := protocol.ByteCount(4)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2))
mockFC.EXPECT().IsBlocked().Times(2)
s := []byte("foo")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write(s)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
close(done)
}()
waitForWrite()
frame, _ := str.popStreamFrame(frameHeaderSize + 1)
Expect(frame.Data).To(Equal([]byte("f")))
s[1] = 'e'
f, _ := str.popStreamFrame(100)
Expect(f).ToNot(BeNil())
Expect(f.Data).To(Equal([]byte("oo")))
Eventually(done).Should(BeClosed())
})
It("returns when given a nil input", func() {
n, err := strWithTimeout.Write(nil)
Expect(n).To(BeZero())
Expect(err).ToNot(HaveOccurred())
})
It("returns when given an empty slice", func() {
n, err := strWithTimeout.Write([]byte(""))
Expect(n).To(BeZero())
Expect(err).ToNot(HaveOccurred())
})
It("cancels the context when Close is called", func() {
mockSender.EXPECT().onHasStreamData(streamID)
Expect(str.Context().Done()).ToNot(BeClosed())
str.Close()
Expect(str.Context().Done()).To(BeClosed())
})
Context("flow control blocking", func() {
It("returns nil when it is blocked", func() {
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0))
mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10))
mockSender.EXPECT().onHasStreamData(streamID)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).To(BeNil())
Expect(hasMoreData).To(BeFalse())
// make the Write go routine return
str.closeForShutdown(nil)
Eventually(done).Should(BeClosed())
})
It("queues a BLOCKED frame if the stream is flow control blocked", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().queueControlFrame(&wire.StreamBlockedFrame{
StreamID: streamID,
Offset: 10,
})
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
// don't use offset 6 here, to make sure the BLOCKED frame contains the number returned by the flow controller
mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
Expect(hasMoreData).To(BeFalse())
Eventually(done).Should(BeClosed())
})
It("says that it doesn't have any more data, when it is flow control blocked", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked().Return(true, protocol.ByteCount(10))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
f, hasMoreData := str.popStreamFrame(50)
Expect(f).ToNot(BeNil())
Expect(hasMoreData).To(BeFalse())
// make the Write go routine return
str.closeForShutdown(nil)
Eventually(done).Should(BeClosed())
})
It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() {
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
// don't EXPECT a call to mockFC.IsBlocked
// don't EXPECT a call to mockSender.queueControlFrame
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
Expect(str.Close()).To(Succeed())
f, hasMoreData := str.popStreamFrame(1000)
Expect(hasMoreData).To(BeFalse())
Expect(f).ToNot(BeNil())
Expect(f.FinBit).To(BeTrue())
Eventually(done).Should(BeClosed())
})
})
Context("deadlines", func() {
It("returns an error when Write is called after the deadline", func() {
str.SetWriteDeadline(time.Now().Add(-time.Second))
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
})
It("unblocks after the deadline", func() {
mockSender.EXPECT().onHasStreamData(streamID)
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetWriteDeadline(deadline)
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond)))
})
It("returns the number of bytes written, when the deadline expires", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes()
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked()
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetWriteDeadline(deadline)
var n int
writeReturned := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).To(MatchError(errDeadline))
Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond)))
close(writeReturned)
}()
waitForWrite()
frame, hasMoreData := str.popStreamFrame(50)
Expect(frame).ToNot(BeNil())
Expect(hasMoreData).To(BeTrue())
Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed())
Expect(n).To(BeEquivalentTo(frame.DataLen()))
})
It("doesn't pop any data after the deadline expired", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes()
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked()
deadline := time.Now().Add(scaleDuration(50 * time.Millisecond))
str.SetWriteDeadline(deadline)
writeReturned := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).To(MatchError(errDeadline))
close(writeReturned)
}()
waitForWrite()
frame, hasMoreData := str.popStreamFrame(50)
Expect(frame).ToNot(BeNil())
Expect(hasMoreData).To(BeTrue())
Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed())
frame, hasMoreData = str.popStreamFrame(50)
Expect(frame).To(BeNil())
Expect(hasMoreData).To(BeFalse())
})
It("doesn't unblock if the deadline is changed before the first one expires", func() {
mockSender.EXPECT().onHasStreamData(streamID)
deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond))
str.SetWriteDeadline(deadline1)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
time.Sleep(scaleDuration(20 * time.Millisecond))
str.SetWriteDeadline(deadline2)
// make sure that this was actually execute before the deadline expires
Expect(time.Now()).To(BeTemporally("<", deadline1))
close(done)
}()
runtime.Gosched()
n, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError(errDeadline))
Expect(n).To(BeZero())
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond)))
Eventually(done).Should(BeClosed())
})
It("unblocks earlier, when a new deadline is set", func() {
mockSender.EXPECT().onHasStreamData(streamID)
deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond))
deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
time.Sleep(scaleDuration(10 * time.Millisecond))
str.SetWriteDeadline(deadline2)
// make sure that this was actually execute before the deadline expires
Expect(time.Now()).To(BeTemporally("<", deadline2))
close(done)
}()
str.SetWriteDeadline(deadline1)
runtime.Gosched()
_, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError(errDeadline))
Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond)))
Eventually(done).Should(BeClosed())
})
})
Context("closing", func() {
It("doesn't allow writes after it has been closed", func() {
mockSender.EXPECT().onHasStreamData(streamID)
str.Close()
_, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError("write on closed stream 1337"))
})
It("allows FIN", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
str.Close()
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
Expect(f.Data).To(BeEmpty())
Expect(f.FinBit).To(BeTrue())
Expect(hasMoreData).To(BeFalse())
})
It("doesn't send a FIN when there's still data", func() {
mockSender.EXPECT().onHasStreamData(streamID)
frameHeaderLen := protocol.ByteCount(4)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2)
mockFC.EXPECT().IsBlocked()
str.dataForWriting = []byte("foobar")
Expect(str.Close()).To(Succeed())
f, _ := str.popStreamFrame(3 + frameHeaderLen)
Expect(f).ToNot(BeNil())
Expect(f.Data).To(Equal([]byte("foo")))
Expect(f.FinBit).To(BeFalse())
mockSender.EXPECT().onStreamCompleted(streamID)
f, _ = str.popStreamFrame(100)
Expect(f.Data).To(Equal([]byte("bar")))
Expect(f.FinBit).To(BeTrue())
})
It("doesn't allow FIN after it is closed for shutdown", func() {
str.closeForShutdown(errors.New("test"))
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).To(BeNil())
Expect(hasMoreData).To(BeFalse())
})
It("doesn't allow FIN twice", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
str.Close()
f, _ := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
Expect(f.Data).To(BeEmpty())
Expect(f.FinBit).To(BeTrue())
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).To(BeNil())
Expect(hasMoreData).To(BeFalse())
})
})
Context("closing for shutdown", func() {
testErr := errors.New("test")
It("returns errors when the stream is cancelled", func() {
str.closeForShutdown(testErr)
n, err := strWithTimeout.Write([]byte("foo"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
It("doesn't get data for writing if an error occurred", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 500))
Expect(err).To(MatchError(testErr))
close(done)
}()
waitForWrite()
frame, hasMoreData := str.popStreamFrame(50) // get a STREAM frame containing some data, but not all
Expect(frame).ToNot(BeNil())
Expect(hasMoreData).To(BeTrue())
str.closeForShutdown(testErr)
frame, hasMoreData = str.popStreamFrame(1000)
Expect(frame).To(BeNil())
Expect(hasMoreData).To(BeFalse())
Eventually(done).Should(BeClosed())
})
It("cancels the context", func() {
Expect(str.Context().Done()).ToNot(BeClosed())
str.closeForShutdown(testErr)
Expect(str.Context().Done()).To(BeClosed())
})
})
})
Context("handling MAX_STREAM_DATA frames", func() {
It("informs the flow controller", func() {
mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337))
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: streamID,
ByteOffset: 0x1337,
})
})
It("says when it has data for sending", func() {
mockFC.EXPECT().UpdateSendWindow(gomock.Any())
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for Write, once for the MAX_STREAM_DATA frame
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
close(done)
}()
waitForWrite()
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: streamID,
ByteOffset: 42,
})
// make sure the Write go routine returns
str.closeForShutdown(nil)
Eventually(done).Should(BeClosed())
})
})
Context("stream cancelations", func() {
Context("canceling writing", func() {
It("queues a RST_STREAM frame", func() {
mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{
StreamID: streamID,
ByteOffset: 1234,
ErrorCode: 9876,
})
mockSender.EXPECT().onStreamCompleted(streamID)
str.writeOffset = 1234
err := str.CancelWrite(9876)
Expect(err).ToNot(HaveOccurred())
})
It("unblocks Write", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(gomock.Any())
mockFC.EXPECT().IsBlocked()
writeReturned := make(chan struct{})
var n int
go func() {
defer GinkgoRecover()
var err error
n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234"))
close(writeReturned)
}()
waitForWrite()
frame, _ := str.popStreamFrame(50)
Expect(frame).ToNot(BeNil())
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
Eventually(writeReturned).Should(BeClosed())
Expect(n).To(BeEquivalentTo(frame.DataLen()))
})
It("cancels the context", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.Context().Done()).ToNot(BeClosed())
str.CancelWrite(1234)
Expect(str.Context().Done()).To(BeClosed())
})
It("doesn't allow further calls to Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234"))
})
It("only cancels once", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
err = str.CancelWrite(4321)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't cancel when the stream was already closed", func() {
mockSender.EXPECT().onHasStreamData(streamID)
err := str.Close()
Expect(err).ToNot(HaveOccurred())
err = str.CancelWrite(123)
Expect(err).To(MatchError("CancelWrite for closed stream 1337"))
})
})
Context("receiving STOP_SENDING frames", func() {
It("queues a RST_STREAM frames with error code Stopping", func() {
mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{
StreamID: streamID,
ErrorCode: errorCodeStopping,
})
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 101,
})
})
It("unblocks Write", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().queueControlFrame(gomock.Any())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := str.Write([]byte("foobar"))
Expect(err).To(MatchError("Stream 1337 was reset with error code 123"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123)))
close(done)
}()
waitForWrite()
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
})
Eventually(done).Should(BeClosed())
})
It("doesn't allow further calls to Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
})
_, err := str.Write([]byte("foobar"))
Expect(err).To(MatchError("Stream 1337 was reset with error code 123"))
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{}))
Expect(err.(streamCanceledError).Canceled()).To(BeTrue())
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(protocol.ApplicationErrorCode(123)))
})
})
})
})

View File

@ -19,8 +19,8 @@ import (
// packetHandler handles packets
type packetHandler interface {
Session
getCryptoStream() cryptoStream
handshakeStatus() <-chan handshakeEvent
getCryptoStream() cryptoStreamI
handshakeStatus() <-chan error
handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber
run() error
@ -40,15 +40,17 @@ type server struct {
certChain crypto.CertChain
scfg *handshake.ServerConfig
sessions map[protocol.ConnectionID]packetHandler
sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration
sessionsMutex sync.RWMutex
sessions map[protocol.ConnectionID]packetHandler
closed bool
serverError error
sessionQueue chan Session
errorChan chan struct{}
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error)
// set as members, so they can be set in the tests
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error)
deleteClosedSessionsAfter time.Duration
}
var _ Listener = &server{}
@ -240,6 +242,12 @@ func (s *server) Accept() (Session, error) {
// Close the server
func (s *server) Close() error {
s.sessionsMutex.Lock()
if s.closed {
s.sessionsMutex.Unlock()
return nil
}
s.closed = true
var wg sync.WaitGroup
for _, session := range s.sessions {
if session != nil {
@ -254,10 +262,9 @@ func (s *server) Close() error {
s.sessionsMutex.Unlock()
wg.Wait()
if s.conn == nil {
return nil
}
return s.conn.Close()
err := s.conn.Close()
<-s.errorChan // wait for serve() to return
return err
}
// Addr returns the server's network address
@ -384,14 +391,8 @@ func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.C
}()
go func() {
for {
ev := <-session.handshakeStatus()
if ev.err != nil {
return
}
if ev.encLevel == protocol.EncryptionForwardSecure {
break
}
if err := <-session.handshakeStatus(); err != nil {
return
}
s.sessionQueue <- session
}()

View File

@ -22,14 +22,13 @@ import (
)
type mockSession struct {
connectionID protocol.ConnectionID
packetCount int
closed bool
closeReason error
closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan handshakeEvent
handshakeComplete chan error // for WaitUntilHandshakeComplete
connectionID protocol.ConnectionID
packetCount int
closed bool
closeReason error
closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan error
}
func (s *mockSession) handlePacket(*receivedPacket) {
@ -40,9 +39,6 @@ func (s *mockSession) run() error {
<-s.stopRunLoop
return s.closeReason
}
func (s *mockSession) WaitUntilHandshakeComplete() error {
return <-s.handshakeComplete
}
func (s *mockSession) Close(e error) error {
if s.closed {
return nil
@ -59,19 +55,19 @@ func (s *mockSession) closeRemote(e error) {
close(s.stopRunLoop)
}
func (s *mockSession) OpenStream() (Stream, error) {
return &stream{streamID: 1337}, nil
return &stream{}, nil
}
func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") }
func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") }
func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") }
func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") }
func (*mockSession) Context() context.Context { panic("not implemented") }
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan }
func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") }
func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") }
func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") }
func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") }
func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") }
func (*mockSession) Context() context.Context { panic("not implemented") }
func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") }
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
var _ Session = &mockSession{}
var _ NonFWSession = &mockSession{}
func newMockSession(
_ connection,
@ -82,10 +78,9 @@ func newMockSession(
_ *Config,
) (packetHandler, error) {
s := mockSession{
connectionID: connectionID,
handshakeChan: make(chan handshakeEvent),
handshakeComplete: make(chan error),
stopRunLoop: make(chan struct{}),
connectionID: connectionID,
handshakeChan: make(chan error),
stopRunLoop: make(chan struct{}),
}
return &s, nil
}
@ -155,9 +150,8 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[connID].(*mockSession)
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Consistently(func() Session { return acceptedSess }).Should(BeNil())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionForwardSecure}
close(sess.handshakeChan)
Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
close(done)
}, 0.5)
@ -173,7 +167,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[connID].(*mockSession)
sess.handshakeChan <- handshakeEvent{err: errors.New("handshake failed")}
sess.handshakeChan <- errors.New("handshake failed")
Consistently(func() bool { return accepted }).Should(BeFalse())
close(done)
})
@ -222,6 +216,7 @@ var _ = Describe("Server", func() {
})
It("closes sessions and the connection when Close is called", func() {
go serv.serve()
session, _ := newMockSession(nil, 0, 0, nil, nil, nil)
serv.sessions[1] = session
err := serv.Close()

View File

@ -12,6 +12,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type nullAEAD struct {
@ -98,6 +99,26 @@ func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.V
return tls, extHandler.GetPeerParams(), nil
}
func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error {
ccf := &wire.ConnectionCloseFrame{
ErrorCode: qerr.HandshakeFailed,
ReasonPhrase: closeErr.Error(),
}
replyHdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
ConnectionID: clientHdr.ConnectionID, // echo the client's connection ID
PacketNumber: 1, // random packet number
Version: clientHdr.Version,
}
data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer)
if err != nil {
return err
}
_, err = s.conn.WriteTo(data, remoteAddr)
return err
}
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) {
if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize {
return nil, errors.New("dropping too small Initial packet")
@ -110,19 +131,30 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
}
// unpack packet and check stream frame contents
version := hdr.Version
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version)
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, hdr.Version)
if err != nil {
return nil, err
}
frame, err := unpackInitialPacket(aead, hdr, data, version)
frame, err := unpackInitialPacket(aead, hdr, data, hdr.Version)
if err != nil {
utils.Debugf("Error unpacking initial packet: %s", err)
return nil, nil
}
sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
if err != nil {
if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil {
utils.Debugf("Error sending CONNECTION_CLOSE: ", ccerr)
}
return nil, err
}
return sess, nil
}
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) {
version := hdr.Version
bc := handshake.NewCryptoStreamConn(remoteAddr)
bc.AddDataForReading(frame.Data)
tls, paramsChan, err := s.newMintConn(bc, hdr.Version)
tls, paramsChan, err := s.newMintConn(bc, version)
if err != nil {
return nil, err
}
@ -176,7 +208,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
return nil, err
}
cs := sess.getCryptoStream()
cs.SetReadOffset(frame.DataLen())
cs.setReadOffset(frame.DataLen())
bc.SetStream(cs)
return sess, nil
}

View File

@ -4,15 +4,16 @@ import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@ -65,6 +66,18 @@ var _ = Describe("Stateless TLS handling", func() {
return hdr, data
}
unpackPacket := func(data []byte) (*wire.Header, []byte) {
r := bytes.NewReader(conn.dataWritten.Bytes())
hdr, err := wire.ParseHeaderSentByServer(r, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
hdr.Raw = data[:len(data)-r.Len()]
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.ConnectionID, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
payload, err := aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw)
Expect(err).ToNot(HaveOccurred())
return hdr, payload
}
It("sends a version negotiation packet if it doesn't support the version", func() {
server.HandleInitial(nil, &wire.Header{Version: 0x1337}, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize))
Expect(conn.dataWritten.Len()).ToNot(BeZero())
@ -124,4 +137,20 @@ var _ = Describe("Stateless TLS handling", func() {
Eventually(sessionChan).Should(Receive())
Eventually(done).Should(BeClosed())
})
It("sends a CONNECTION_CLOSE, if mint returns an error", func() {
mintTLS.EXPECT().Handshake().Return(mint.AlertAccessDenied)
extHandler.EXPECT().GetPeerParams()
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
server.HandleInitial(nil, hdr, data)
// the Handshake packet is written by the session
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
// unpack the packet to check that it actually contains a CONNECTION_CLOSE
hdr, data = unpackPacket(conn.dataWritten.Bytes())
Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake))
ccf, err := wire.ParseConnectionCloseFrame(bytes.NewReader(data), protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
Expect(ccf.ErrorCode).To(Equal(qerr.HandshakeFailed))
Expect(ccf.ReasonPhrase).To(Equal(mint.AlertAccessDenied.String()))
})
})

View File

@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"time"
@ -24,6 +23,23 @@ type unpacker interface {
Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error)
}
type streamGetter interface {
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
}
type streamManager interface {
GetOrOpenStream(protocol.StreamID) (streamI, error)
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
OpenStream() (Stream, error)
OpenStreamSync() (Stream, error)
AcceptStream() (Stream, error)
DeleteStream(protocol.StreamID) error
UpdateLimits(*handshake.TransportParameters)
CloseWithError(error)
}
type receivedPacket struct {
remoteAddr net.Addr
header *wire.Header
@ -36,11 +52,6 @@ var (
newCryptoSetupClient = handshake.NewCryptoSetupClient
)
type handshakeEvent struct {
encLevel protocol.EncryptionLevel
err error
}
type closeError struct {
err error
remote bool
@ -55,16 +66,16 @@ type session struct {
conn connection
streamsMap *streamsMap
cryptoStream cryptoStream
streamsMap streamManager
cryptoStream cryptoStreamI
rttStats *congestion.RTTStats
sentPacketHandler ackhandler.SentPacketHandler
receivedPacketHandler ackhandler.ReceivedPacketHandler
streamFramer *streamFramer
connFlowController flowcontrol.ConnectionFlowController
windowUpdateQueue *windowUpdateQueue
connFlowController flowcontrol.ConnectionFlowController
unpacker unpacker
packer *packetPacker
@ -87,17 +98,14 @@ type session struct {
// this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them
paramsChan <-chan handshake.TransportParameters
// this channel is passed to the CryptoSetup and receives the current encryption level
// it is closed as soon as the handshake is complete
aeadChanged <-chan protocol.EncryptionLevel
// the handshakeEvent channel is passed to the CryptoSetup.
// It receives when it makes sense to try decrypting undecryptable packets.
handshakeEvent <-chan struct{}
// handshakeChan is returned by handshakeStatus.
// It receives any error that might occur during the handshake.
// It is closed when the handshake is complete.
handshakeChan chan error
handshakeComplete bool
// will be closed as soon as the handshake completes, and receive any error that might occur until then
// it is used to block WaitUntilHandshakeComplete()
handshakeCompleteChan chan error
// handshakeChan receives handshake events and is closed as soon the handshake completes
// the receiving end of this channel is passed to the creator of the session
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error
handshakeChan chan handshakeEvent
lastRcvdPacketNumber protocol.PacketNumber
// Used to calculate the next packet number from the truncated wire
@ -116,6 +124,7 @@ type session struct {
}
var _ Session = &session{}
var _ streamSender = &session{}
// newSession makes a new session
func newSession(
@ -127,15 +136,15 @@ func newSession(
config *Config,
) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2)
handshakeEvent := make(chan struct{}, 1)
s := &session{
conn: conn,
connectionID: connectionID,
perspective: protocol.PerspectiveServer,
version: v,
config: config,
aeadChanged: aeadChanged,
paramsChan: paramsChan,
conn: conn,
connectionID: connectionID,
perspective: protocol.PerspectiveServer,
version: v,
config: config,
handshakeEvent: handshakeEvent,
paramsChan: paramsChan,
}
s.preSetup()
transportParams := &handshake.TransportParameters{
@ -154,7 +163,7 @@ func newSession(
s.config.Versions,
s.config.AcceptCookie,
paramsChan,
aeadChanged,
handshakeEvent,
)
if err != nil {
return nil, err
@ -175,15 +184,15 @@ var newClientSession = func(
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2)
handshakeEvent := make(chan struct{}, 1)
s := &session{
conn: conn,
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
config: config,
aeadChanged: aeadChanged,
paramsChan: paramsChan,
conn: conn,
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
config: config,
handshakeEvent: handshakeEvent,
paramsChan: paramsChan,
}
s.preSetup()
transportParams := &handshake.TransportParameters{
@ -201,7 +210,7 @@ var newClientSession = func(
tlsConf,
transportParams,
paramsChan,
aeadChanged,
handshakeEvent,
initialVersion,
negotiatedVersions,
)
@ -223,21 +232,21 @@ func newTLSServerSession(
peerParams *handshake.TransportParameters,
v protocol.VersionNumber,
) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2)
handshakeEvent := make(chan struct{}, 1)
s := &session{
conn: conn,
config: config,
connectionID: connectionID,
perspective: protocol.PerspectiveServer,
version: v,
aeadChanged: aeadChanged,
conn: conn,
config: config,
connectionID: connectionID,
perspective: protocol.PerspectiveServer,
version: v,
handshakeEvent: handshakeEvent,
}
s.preSetup()
s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
tls,
cryptoStreamConn,
nullAEAD,
aeadChanged,
handshakeEvent,
v,
)
if err := s.postSetup(initialPacketNumber); err != nil {
@ -260,15 +269,15 @@ var newTLSClientSession = func(
paramsChan <-chan handshake.TransportParameters,
initialPacketNumber protocol.PacketNumber,
) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2)
handshakeEvent := make(chan struct{}, 1)
s := &session{
conn: conn,
config: config,
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
aeadChanged: aeadChanged,
paramsChan: paramsChan,
conn: conn,
config: config,
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
handshakeEvent: handshakeEvent,
paramsChan: paramsChan,
}
s.preSetup()
tls.SetCryptoStream(s.cryptoStream)
@ -276,7 +285,7 @@ var newTLSClientSession = func(
s.cryptoStream,
s.connectionID,
hostname,
aeadChanged,
handshakeEvent,
tls,
v,
)
@ -294,12 +303,11 @@ func (s *session) preSetup() {
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats,
)
s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream)
s.cryptoStream = s.newCryptoStream()
}
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.handshakeChan = make(chan handshakeEvent, 3)
s.handshakeCompleteChan = make(chan error, 1)
s.handshakeChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
@ -314,9 +322,12 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController, s.version)
if s.version.UsesTLS() {
s.streamsMap = newStreamsMap(s.newStream, s.perspective)
} else {
s.streamsMap = newStreamsMapLegacy(s.newStream, s.perspective)
}
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
s.packer = newPacketPacker(s.connectionID,
initialPacketNumber,
s.cryptoSetup,
@ -324,6 +335,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.perspective,
s.version,
)
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return nil
}
@ -339,7 +351,7 @@ func (s *session) run() error {
}()
var closeErr closeError
aeadChanged := s.aeadChanged
handshakeEvent := s.handshakeEvent
runLoop:
for {
@ -377,16 +389,20 @@ runLoop:
putPacketBuffer(p.header.Raw)
case p := <-s.paramsChan:
s.processTransportParameters(&p)
case l, ok := <-aeadChanged:
case _, ok := <-handshakeEvent:
if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
s.handshakeComplete = true
aeadChanged = nil // prevent this case from ever being selected again
handshakeEvent = nil // prevent this case from ever being selected again
s.sentPacketHandler.SetHandshakeComplete()
if !s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient {
// In gQUIC, there's no equivalent to the Finished message in TLS
// The server knows that the handshake is complete when it receives the first forward-secure packet sent by the client.
// We need to make sure that the client actually sends such a packet.
s.packer.QueueControlFrame(&wire.PingFrame{})
}
close(s.handshakeChan)
close(s.handshakeCompleteChan)
} else {
s.tryDecryptingQueuedPackets()
s.handshakeChan <- handshakeEvent{encLevel: l}
}
}
@ -403,9 +419,26 @@ runLoop:
s.keepAlivePingSent = true
}
if err := s.sendPacket(); err != nil {
s.closeLocal(err)
sendingAllowed := s.sentPacketHandler.SendingAllowed()
if !sendingAllowed { // if congestion limited, at least try sending an ACK frame
if err := s.maybeSendAckOnlyPacket(); err != nil {
s.closeLocal(err)
}
} else {
// repeatedly try sending until we don't have any more data, or run out of the congestion window
for sendingAllowed {
sentPacket, err := s.sendPacket()
if err != nil {
s.closeLocal(err)
break
}
if !sentPacket {
break
}
sendingAllowed = s.sentPacketHandler.SendingAllowed()
}
}
if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 {
s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
}
@ -415,17 +448,12 @@ runLoop:
if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout {
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
}
if err := s.streamsMap.DeleteClosedStreams(); err != nil {
s.closeLocal(err)
}
}
// only send the error the handshakeChan when the handshake is not completed yet
// otherwise this chan will already be closed
if !s.handshakeComplete {
s.handshakeCompleteChan <- closeErr.err
s.handshakeChan <- handshakeEvent{err: closeErr.err}
s.handshakeChan <- closeErr.err
}
s.handleCloseError(closeErr)
return closeErr.err
@ -435,6 +463,10 @@ func (s *session) Context() context.Context {
return s.ctx
}
func (s *session) ConnectionState() ConnectionState {
return s.cryptoSetup.ConnectionState()
}
func (s *session) maybeResetTimer() {
var deadline time.Time
if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent {
@ -504,7 +536,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames)
if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, isRetransmittable); err != nil {
if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, p.rcvTime, isRetransmittable); err != nil {
return err
}
@ -524,8 +556,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
case *wire.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames")
case *wire.StopWaitingFrame:
s.receivedPacketHandler.IgnoreBelow(frame.LeastUnacked)
case *wire.StopWaitingFrame: // ignore STOP_WAITINGs
case *wire.RstStreamFrame:
err = s.handleRstStreamFrame(frame)
case *wire.MaxDataFrame:
@ -534,6 +565,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
err = s.handleMaxStreamDataFrame(frame)
case *wire.BlockedFrame:
case *wire.StreamBlockedFrame:
case *wire.StopSendingFrame:
err = s.handleStopSendingFrame(frame)
case *wire.PingFrame:
default:
return errors.New("Session BUG: unexpected frame type")
@ -563,9 +596,12 @@ func (s *session) handlePacket(p *receivedPacket) {
func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
if frame.StreamID == s.version.CryptoStreamID() {
return s.cryptoStream.AddStreamFrame(frame)
if frame.FinBit {
return errors.New("Received STREAM frame with FIN bit for the crypto stream")
}
return s.cryptoStream.handleStreamFrame(frame)
}
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
if err != nil {
return err
}
@ -574,7 +610,7 @@ func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
// ignore this StreamFrame
return nil
}
return str.AddStreamFrame(frame)
return str.handleStreamFrame(frame)
}
func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) {
@ -582,7 +618,11 @@ func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) {
}
func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error {
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
if frame.StreamID == s.version.CryptoStreamID() {
s.cryptoStream.handleMaxStreamDataFrame(frame)
return nil
}
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
if err != nil {
return err
}
@ -590,12 +630,15 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error
// stream is closed and already garbage collected
return nil
}
str.UpdateSendWindow(frame.ByteOffset)
str.handleMaxStreamDataFrame(frame)
return nil
}
func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
if frame.StreamID == s.version.CryptoStreamID() {
return errors.New("Received RST_STREAM frame for the crypto stream")
}
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
if err != nil {
return err
}
@ -603,11 +646,31 @@ func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
// stream is closed and already garbage collected
return nil
}
return str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode), frame.ByteOffset)
return str.handleRstStreamFrame(frame)
}
func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error {
if frame.StreamID == s.version.CryptoStreamID() {
return errors.New("Received a STOP_SENDING frame for the crypto stream")
}
str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
if err != nil {
return err
}
if str == nil {
// stream is closed and already garbage collected
return nil
}
str.handleStopSendingFrame(frame)
return nil
}
func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error {
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime)
if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil {
return err
}
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
return nil
}
func (s *session) closeLocal(e error) {
@ -647,7 +710,7 @@ func (s *session) handleCloseError(closeErr closeError) error {
utils.Errorf("Closing session with error: %s", closeErr.err.Error())
}
s.cryptoStream.Cancel(quicErr)
s.cryptoStream.closeForShutdown(quicErr)
s.streamsMap.CloseWithError(quicErr)
if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry {
@ -669,111 +732,105 @@ func (s *session) handleCloseError(closeErr closeError) error {
func (s *session) processTransportParameters(params *handshake.TransportParameters) {
s.peerParams = params
s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams)
s.streamsMap.UpdateLimits(params)
if params.OmitConnectionID {
s.packer.SetOmitConnectionID()
}
s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow)
s.streamsMap.Range(func(str streamI) {
str.UpdateSendWindow(params.StreamFlowControlWindow)
})
// the crypto stream is the only open stream at this moment
// so we don't need to update stream flow control windows
}
func (s *session) sendPacket() error {
func (s *session) maybeSendAckOnlyPacket() error {
ack := s.receivedPacketHandler.GetAckFrame()
if ack == nil {
return nil
}
s.packer.QueueControlFrame(ack)
if !s.version.UsesIETFFrameFormat() { // for gQUIC, maybe add a STOP_WAITING
if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil {
s.packer.QueueControlFrame(swf)
}
}
packet, err := s.packer.PackAckPacket()
if err != nil {
return err
}
return s.sendPackedPacket(packet)
}
func (s *session) sendPacket() (bool, error) {
s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
// Get MAX_DATA and MAX_STREAM_DATA frames
// this call triggers the flow controller to increase the flow control windows, if necessary
windowUpdates := s.getWindowUpdates()
for _, f := range windowUpdates {
s.packer.QueueControlFrame(f)
if offset := s.connFlowController.GetWindowUpdate(); offset != 0 {
s.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: offset})
}
if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked {
s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset})
}
s.windowUpdateQueue.QueueAll()
ack := s.receivedPacketHandler.GetAckFrame()
if ack != nil {
s.packer.QueueControlFrame(ack)
}
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
// check for retransmissions first
for {
if !s.sentPacketHandler.SendingAllowed() {
if ack == nil {
return nil
}
// If we aren't allowed to send, at least try sending an ACK frame
swf := s.sentPacketHandler.GetStopWaitingFrame(false)
if swf != nil {
s.packer.QueueControlFrame(swf)
}
packet, err := s.packer.PackAckPacket()
if err != nil {
return err
}
return s.sendPackedPacket(packet)
retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission()
if retransmitPacket == nil {
break
}
// check for retransmissions first
for {
retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission()
if retransmitPacket == nil {
break
}
if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure {
if s.handshakeComplete {
// Don't retransmit handshake packets when the handshake is complete
continue
}
utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber)
// retransmit handshake packets
if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure {
utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber)
if !s.version.UsesIETFFrameFormat() {
s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true))
packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket)
if err != nil {
return err
}
if err = s.sendPackedPacket(packet); err != nil {
return err
}
} else {
utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber)
// resend the frames that were in the packet
for _, frame := range retransmitPacket.GetFramesForRetransmission() {
// TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window
switch f := frame.(type) {
case *wire.StreamFrame:
s.streamFramer.AddFrameForRetransmission(f)
default:
s.packer.QueueControlFrame(frame)
}
}
}
packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket)
if err != nil {
return false, err
}
if err := s.sendPackedPacket(packet); err != nil {
return false, err
}
return true, nil
}
hasRetransmission := s.streamFramer.HasFramesForRetransmission()
if ack != nil || hasRetransmission {
swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission)
if swf != nil {
s.packer.QueueControlFrame(swf)
// queue all retransmittable frames sent in forward-secure packets
utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber)
// resend the frames that were in the packet
for _, frame := range retransmitPacket.GetFramesForRetransmission() {
// TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window
switch f := frame.(type) {
case *wire.StreamFrame:
s.streamFramer.AddFrameForRetransmission(f)
default:
s.packer.QueueControlFrame(frame)
}
}
// add a retransmittable frame
if s.sentPacketHandler.ShouldSendRetransmittablePacket() {
s.packer.QueueControlFrame(&wire.PingFrame{})
}
packet, err := s.packer.PackPacket()
if err != nil || packet == nil {
return err
}
if err = s.sendPackedPacket(packet); err != nil {
return err
}
// send every window update twice
for _, f := range windowUpdates {
s.packer.QueueControlFrame(f)
}
windowUpdates = nil
ack = nil
}
hasRetransmission := s.streamFramer.HasFramesForRetransmission()
if !s.version.UsesIETFFrameFormat() && (ack != nil || hasRetransmission) {
if swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission); swf != nil {
s.packer.QueueControlFrame(swf)
}
}
// add a retransmittable frame
if s.sentPacketHandler.ShouldSendRetransmittablePacket() {
s.packer.MakeNextPacketRetransmittable()
}
packet, err := s.packer.PackPacket()
if err != nil || packet == nil {
return false, err
}
if err := s.sendPackedPacket(packet); err != nil {
return false, err
}
return true, nil
}
func (s *session) sendPackedPacket(packet *packedPacket) error {
@ -824,7 +881,7 @@ func (s *session) GetOrOpenStream(id protocol.StreamID) (Stream, error) {
return str, err
}
// make sure to return an actual nil value here, not an Stream with value nil
return nil, err
return str, err
}
// AcceptStream returns the next stream openend by the peer
@ -841,18 +898,6 @@ func (s *session) OpenStreamSync() (Stream, error) {
return s.streamsMap.OpenStreamSync()
}
func (s *session) WaitUntilHandshakeComplete() error {
return <-s.handshakeCompleteChan
}
func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {
s.packer.QueueControlFrame(&wire.RstStreamFrame{
StreamID: id,
ByteOffset: offset,
})
s.scheduleSending()
}
func (s *session) newStream(id protocol.StreamID) streamI {
var initialSendWindow protocol.ByteCount
if s.peerParams != nil {
@ -867,7 +912,21 @@ func (s *session) newStream(id protocol.StreamID) streamI {
initialSendWindow,
s.rttStats,
)
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version)
return newStream(id, s, flowController, s.version)
}
func (s *session) newCryptoStream() cryptoStreamI {
id := s.version.CryptoStreamID()
flowController := flowcontrol.NewStreamFlowController(
id,
s.version.StreamContributesToConnectionFlowControl(id),
s.connFlowController,
protocol.ReceiveStreamFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow),
0,
s.rttStats,
)
return newCryptoStream(s, flowController, s.version)
}
func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {
@ -908,22 +967,25 @@ func (s *session) tryDecryptingQueuedPackets() {
s.undecryptablePackets = s.undecryptablePackets[:0]
}
func (s *session) getWindowUpdates() []wire.Frame {
var res []wire.Frame
s.streamsMap.Range(func(str streamI) {
if offset := str.GetWindowUpdate(); offset != 0 {
res = append(res, &wire.MaxStreamDataFrame{
StreamID: str.StreamID(),
ByteOffset: offset,
})
}
})
if offset := s.connFlowController.GetWindowUpdate(); offset != 0 {
res = append(res, &wire.MaxDataFrame{
ByteOffset: offset,
})
func (s *session) queueControlFrame(f wire.Frame) {
s.packer.QueueControlFrame(f)
s.scheduleSending()
}
func (s *session) onHasWindowUpdate(id protocol.StreamID) {
s.windowUpdateQueue.Add(id)
s.scheduleSending()
}
func (s *session) onHasStreamData(id protocol.StreamID) {
s.streamFramer.AddActiveStream(id)
s.scheduleSending()
}
func (s *session) onStreamCompleted(id protocol.StreamID) {
if err := s.streamsMap.DeleteStream(id); err != nil {
s.Close(err)
}
return res
}
func (s *session) LocalAddr() net.Addr {
@ -935,11 +997,11 @@ func (s *session) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func (s *session) handshakeStatus() <-chan handshakeEvent {
func (s *session) handshakeStatus() <-chan error {
return s.handshakeChan
}
func (s *session) getCryptoStream() cryptoStream {
func (s *session) getCryptoStream() cryptoStreamI {
return s.cryptoStream
}

File diff suppressed because it is too large Load Diff

View File

@ -1,88 +1,85 @@
package quic
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
const (
errorCodeStopping protocol.ApplicationErrorCode = 0
errorCodeStoppingGQUIC protocol.ApplicationErrorCode = 7
)
// The streamSender is notified by the stream about various events.
type streamSender interface {
queueControlFrame(wire.Frame)
onHasWindowUpdate(protocol.StreamID)
onHasStreamData(protocol.StreamID)
onStreamCompleted(protocol.StreamID)
}
// Each of the both stream halves gets its own uniStreamSender.
// This is necessary in order to keep track when both halves have been completed.
type uniStreamSender struct {
streamSender
onStreamCompletedImpl func()
}
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
s.streamSender.queueControlFrame(f)
}
func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) {
s.streamSender.onHasWindowUpdate(id)
}
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
s.streamSender.onHasStreamData(id)
}
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
s.onStreamCompletedImpl()
}
var _ streamSender = &uniStreamSender{}
type streamI interface {
Stream
AddStreamFrame(*wire.StreamFrame) error
RegisterRemoteError(error, protocol.ByteCount) error
HasDataForWriting() bool
GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool)
GetWriteOffset() protocol.ByteCount
Finished() bool
Cancel(error)
// methods needed for flow control
GetWindowUpdate() protocol.ByteCount
UpdateSendWindow(protocol.ByteCount)
IsFlowControlBlocked() bool
closeForShutdown(error)
// for receiving
handleStreamFrame(*wire.StreamFrame) error
handleRstStreamFrame(*wire.RstStreamFrame) error
getWindowUpdate() protocol.ByteCount
// for sending
handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStream interface {
streamI
SetReadOffset(protocol.ByteCount)
}
var _ receiveStreamI = (streamI)(nil)
var _ sendStreamI = (streamI)(nil)
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
//
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
type stream struct {
mutex sync.Mutex
receiveStream
sendStream
ctx context.Context
ctxCancel context.CancelFunc
completedMutex sync.Mutex
sender streamSender
receiveStreamCompleted bool
sendStreamCompleted bool
streamID protocol.StreamID
onData func()
// onReset is a callback that should send a RST_STREAM
onReset func(protocol.StreamID, protocol.ByteCount)
readPosInFrame int
writeOffset protocol.ByteCount
readOffset protocol.ByteCount
// Once set, the errors must not be changed!
err error
// cancelled is set when Cancel() is called
cancelled utils.AtomicBool
// finishedReading is set once we read a frame with a FinBit
finishedReading utils.AtomicBool
// finisedWriting is set once Close() is called
finishedWriting utils.AtomicBool
// resetLocally is set if Reset() is called
resetLocally utils.AtomicBool
// resetRemotely is set if RegisterRemoteError() is called
resetRemotely utils.AtomicBool
frameQueue *streamFrameSorter
readChan chan struct{}
readDeadline time.Time
dataForWriting []byte
finSent utils.AtomicBool
rstSent utils.AtomicBool
writeChan chan struct{}
writeDeadline time.Time
flowController flowcontrol.StreamFlowController
version protocol.VersionNumber
version protocol.VersionNumber
}
var _ Stream = &stream{}
var _ streamI = &stream{}
type deadlineError struct{}
@ -92,290 +89,58 @@ func (deadlineError) Timeout() bool { return true }
var errDeadline net.Error = &deadlineError{}
type streamCanceledError struct {
error
errorCode protocol.ApplicationErrorCode
}
func (streamCanceledError) Canceled() bool { return true }
func (e streamCanceledError) ErrorCode() protocol.ApplicationErrorCode { return e.errorCode }
var _ StreamError = &streamCanceledError{}
// newStream creates a new Stream
func newStream(StreamID protocol.StreamID,
onData func(),
onReset func(protocol.StreamID, protocol.ByteCount),
func newStream(streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
version protocol.VersionNumber,
) *stream {
s := &stream{
onData: onData,
onReset: onReset,
streamID: StreamID,
flowController: flowController,
frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
version: version,
s := &stream{sender: sender}
senderForSendStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {
s.completedMutex.Lock()
s.sendStreamCompleted = true
s.checkIfCompleted()
s.completedMutex.Unlock()
},
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version)
senderForReceiveStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {
s.completedMutex.Lock()
s.receiveStreamCompleted = true
s.checkIfCompleted()
s.completedMutex.Unlock()
},
}
s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController)
return s
}
// Read implements io.Reader. It is not thread safe!
func (s *stream) Read(p []byte) (int, error) {
s.mutex.Lock()
err := s.err
s.mutex.Unlock()
if s.cancelled.Get() || s.resetLocally.Get() {
return 0, err
}
if s.finishedReading.Get() {
return 0, io.EOF
}
bytesRead := 0
for bytesRead < len(p) {
s.mutex.Lock()
frame := s.frameQueue.Head()
if frame == nil && bytesRead > 0 {
err = s.err
s.mutex.Unlock()
return bytesRead, err
}
var err error
for {
// Stop waiting on errors
if s.resetLocally.Get() || s.cancelled.Get() {
err = s.err
break
}
deadline := s.readDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
err = errDeadline
break
}
if frame != nil {
s.readPosInFrame = int(s.readOffset - frame.Offset)
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.readChan
} else {
select {
case <-s.readChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
frame = s.frameQueue.Head()
}
s.mutex.Unlock()
if err != nil {
return bytesRead, err
}
m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame)
if bytesRead > len(p) {
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > int(frame.DataLen()) {
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen())
}
copy(p[bytesRead:], frame.Data[s.readPosInFrame:])
s.readPosInFrame += m
bytesRead += m
s.readOffset += protocol.ByteCount(m)
// when a RST_STREAM was received, the was already informed about the final byteOffset for this stream
if !s.resetRemotely.Get() {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}
s.onData() // so that a possible WINDOW_UPDATE is sent
if s.readPosInFrame >= int(frame.DataLen()) {
fin := frame.FinBit
s.mutex.Lock()
s.frameQueue.Pop()
s.mutex.Unlock()
if fin {
s.finishedReading.Set(true)
return bytesRead, io.EOF
}
}
}
return bytesRead, nil
// need to define StreamID() here, since both receiveStream and readStream have a StreamID()
func (s *stream) StreamID() protocol.StreamID {
// the result is same for receiveStream and sendStream
return s.sendStream.StreamID()
}
func (s *stream) Write(p []byte) (int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.resetLocally.Get() || s.err != nil {
return 0, s.err
}
if s.finishedWriting.Get() {
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if len(p) == 0 {
return 0, nil
}
s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p)
s.onData()
var err error
for {
deadline := s.writeDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
err = errDeadline
break
}
if s.dataForWriting == nil || s.err != nil {
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.writeChan
} else {
select {
case <-s.writeChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
}
if err != nil {
return 0, err
}
if s.err != nil {
return len(p) - len(s.dataForWriting), s.err
}
return len(p), nil
}
func (s *stream) GetWriteOffset() protocol.ByteCount {
return s.writeOffset
}
// HasDataForWriting says if there's stream available to be dequeued for writing
func (s *stream) HasDataForWriting() bool {
s.mutex.Lock()
hasData := s.err == nil && // nothing should be sent if an error occurred
(len(s.dataForWriting) > 0 || // there is data queued for sending
s.finishedWriting.Get() && !s.finSent.Get()) // if there is no data, but writing finished and the FIN hasn't been sent yet
s.mutex.Unlock()
return hasData
}
func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
data, shouldSendFin := s.getDataForWritingImpl(maxBytes)
if shouldSendFin {
s.finSent.Set(true)
}
return data, shouldSendFin
}
func (s *stream) getDataForWritingImpl(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.err != nil || s.dataForWriting == nil {
return nil, s.finishedWriting.Get() && !s.finSent.Get()
}
// TODO(#657): Flow control for the crypto stream
if s.streamID != s.version.CryptoStreamID() {
maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize())
}
if maxBytes == 0 {
return nil, false
}
var ret []byte
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
ret = s.dataForWriting[:maxBytes]
s.dataForWriting = s.dataForWriting[maxBytes:]
} else {
ret = s.dataForWriting
s.dataForWriting = nil
s.signalWrite()
}
s.writeOffset += protocol.ByteCount(len(ret))
s.flowController.AddBytesSent(protocol.ByteCount(len(ret)))
return ret, s.finishedWriting.Get() && s.dataForWriting == nil && !s.finSent.Get()
}
// Close implements io.Closer
func (s *stream) Close() error {
s.finishedWriting.Set(true)
s.ctxCancel()
s.onData()
return nil
}
func (s *stream) shouldSendReset() bool {
if s.rstSent.Get() {
return false
}
return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin()
}
// AddStreamFrame adds a new stream frame
func (s *stream) AddStreamFrame(frame *wire.StreamFrame) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil {
if err := s.sendStream.Close(); err != nil {
return err
}
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData {
return err
}
s.signalRead()
return nil
}
// signalRead performs a non-blocking send on the readChan
func (s *stream) signalRead() {
select {
case s.readChan <- struct{}{}:
default:
}
}
// signalRead performs a non-blocking send on the writeChan
func (s *stream) signalWrite() {
select {
case s.writeChan <- struct{}{}:
default:
}
}
func (s *stream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.readDeadline
s.readDeadline = t
s.mutex.Unlock()
// if the new deadline is before the currently set deadline, wake up Read()
if t.Before(oldDeadline) {
s.signalRead()
}
return nil
}
func (s *stream) SetWriteDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.writeDeadline
s.writeDeadline = t
s.mutex.Unlock()
if t.Before(oldDeadline) {
s.signalWrite()
}
// in gQUIC, we need to send a RST_STREAM with the final offset if CancelRead() was called
s.receiveStream.onClose(s.sendStream.getWriteOffset())
return nil
}
@ -385,107 +150,31 @@ func (s *stream) SetDeadline(t time.Time) error {
return nil
}
// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset
func (s *stream) CloseRemote(offset protocol.ByteCount) {
s.AddStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset})
// CloseForShutdown closes a stream abruptly.
// It makes Read and Write unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *stream) closeForShutdown(err error) {
s.sendStream.closeForShutdown(err)
s.receiveStream.closeForShutdown(err)
}
// Cancel is called by session to indicate that an error occurred
// The stream should will be closed immediately
func (s *stream) Cancel(err error) {
s.mutex.Lock()
s.cancelled.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
s.signalRead()
s.signalWrite()
}
s.mutex.Unlock()
}
// resets the stream locally
func (s *stream) Reset(err error) {
if s.resetLocally.Get() {
return
}
s.mutex.Lock()
s.resetLocally.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
s.signalRead()
s.signalWrite()
}
if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset)
s.rstSent.Set(true)
}
s.mutex.Unlock()
}
// resets the stream remotely
func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error {
if s.resetRemotely.Get() {
return nil
}
s.mutex.Lock()
s.resetRemotely.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
s.signalWrite()
}
if err := s.flowController.UpdateHighestReceived(offset, true); err != nil {
func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
if err := s.receiveStream.handleRstStreamFrame(frame); err != nil {
return err
}
if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset)
s.rstSent.Set(true)
if !s.version.UsesIETFFrameFormat() {
s.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: s.StreamID(),
ErrorCode: frame.ErrorCode,
})
}
s.mutex.Unlock()
return nil
}
func (s *stream) finishedWriteAndSentFin() bool {
return s.finishedWriting.Get() && s.finSent.Get()
}
func (s *stream) Finished() bool {
return s.cancelled.Get() ||
(s.finishedReading.Get() && s.finishedWriteAndSentFin()) ||
(s.resetRemotely.Get() && s.rstSent.Get()) ||
(s.finishedReading.Get() && s.rstSent.Get()) ||
(s.finishedWriteAndSentFin() && s.resetRemotely.Get())
}
func (s *stream) Context() context.Context {
return s.ctx
}
func (s *stream) StreamID() protocol.StreamID {
return s.streamID
}
func (s *stream) UpdateSendWindow(n protocol.ByteCount) {
s.flowController.UpdateSendWindow(n)
}
func (s *stream) IsFlowControlBlocked() bool {
return s.flowController.IsBlocked()
}
func (s *stream) GetWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate()
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *stream) SetReadOffset(offset protocol.ByteCount) {
s.readOffset = offset
s.frameQueue.readPosition = offset
// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed.
// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed.
func (s *stream) checkIfCompleted() {
if s.sendStreamCompleted && s.receiveStreamCompleted {
s.sender.onStreamCompleted(s.StreamID())
}
}

View File

@ -1,33 +1,35 @@
package quic
import (
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type streamFramer struct {
streamsMap *streamsMap
cryptoStream streamI
streamGetter streamGetter
cryptoStream cryptoStreamI
version protocol.VersionNumber
connFlowController flowcontrol.ConnectionFlowController
retransmissionQueue []*wire.StreamFrame
blockedFrameQueue []wire.Frame
streamQueueMutex sync.Mutex
activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID
hasCryptoStreamData bool
}
func newStreamFramer(
cryptoStream streamI,
streamsMap *streamsMap,
cfc flowcontrol.ConnectionFlowController,
cryptoStream cryptoStreamI,
streamGetter streamGetter,
v protocol.VersionNumber,
) *streamFramer {
return &streamFramer{
streamsMap: streamsMap,
cryptoStream: cryptoStream,
connFlowController: cfc,
version: v,
streamGetter: streamGetter,
cryptoStream: cryptoStream,
activeStreams: make(map[protocol.StreamID]struct{}),
version: v,
}
}
@ -35,114 +37,101 @@ func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) {
f.retransmissionQueue = append(f.retransmissionQueue, frame)
}
func (f *streamFramer) AddActiveStream(id protocol.StreamID) {
if id == f.version.CryptoStreamID() { // the crypto stream is handled separately
f.streamQueueMutex.Lock()
f.hasCryptoStreamData = true
f.streamQueueMutex.Unlock()
return
}
f.streamQueueMutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue = append(f.streamQueue, id)
f.activeStreams[id] = struct{}{}
}
f.streamQueueMutex.Unlock()
}
func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame {
fs, currentLen := f.maybePopFramesForRetransmission(maxLen)
return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...)
}
func (f *streamFramer) PopBlockedFrame() wire.Frame {
if len(f.blockedFrameQueue) == 0 {
return nil
}
frame := f.blockedFrameQueue[0]
f.blockedFrameQueue = f.blockedFrameQueue[1:]
return frame
}
func (f *streamFramer) HasFramesForRetransmission() bool {
return len(f.retransmissionQueue) > 0
}
func (f *streamFramer) HasCryptoStreamFrame() bool {
return f.cryptoStream.HasDataForWriting()
func (f *streamFramer) HasCryptoStreamData() bool {
f.streamQueueMutex.Lock()
hasCryptoStreamData := f.hasCryptoStreamData
f.streamQueueMutex.Unlock()
return hasCryptoStreamData
}
// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames.
func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame {
if !f.HasCryptoStreamFrame() {
return nil
}
frame := &wire.StreamFrame{
StreamID: f.cryptoStream.StreamID(),
Offset: f.cryptoStream.GetWriteOffset(),
}
frameHeaderBytes, _ := frame.MinLength(f.version) // can never error
frame.Data, frame.FinBit = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes)
f.streamQueueMutex.Lock()
frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen)
f.hasCryptoStreamData = hasMoreData
f.streamQueueMutex.Unlock()
return frame
}
func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) {
func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) {
for len(f.retransmissionQueue) > 0 {
frame := f.retransmissionQueue[0]
frame.DataLenPresent = true
frameHeaderLen, _ := frame.MinLength(f.version) // can never error
if currentLen+frameHeaderLen >= maxLen {
frameHeaderLen := frame.MinLength(f.version) // can never error
maxLen := maxTotalLen - currentLen
if frameHeaderLen+frame.DataLen() > maxLen && maxLen < protocol.MinStreamFrameSize {
break
}
currentLen += frameHeaderLen
splitFrame := maybeSplitOffFrame(frame, maxLen-currentLen)
splitFrame := maybeSplitOffFrame(frame, maxLen-frameHeaderLen)
if splitFrame != nil { // StreamFrame was split
res = append(res, splitFrame)
currentLen += splitFrame.DataLen()
currentLen += frameHeaderLen + splitFrame.DataLen()
break
}
f.retransmissionQueue = f.retransmissionQueue[1:]
res = append(res, frame)
currentLen += frame.DataLen()
currentLen += frameHeaderLen + frame.DataLen()
}
return
}
func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*wire.StreamFrame) {
frame := &wire.StreamFrame{DataLenPresent: true}
func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame {
var currentLen protocol.ByteCount
fn := func(s streamI) (bool, error) {
if s == nil {
return true, nil
var frames []*wire.StreamFrame
f.streamQueueMutex.Lock()
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
numActiveStreams := len(f.streamQueue)
for i := 0; i < numActiveStreams; i++ {
if maxTotalLen-currentLen < protocol.MinStreamFrameSize {
break
}
frame.StreamID = s.StreamID()
frame.Offset = s.GetWriteOffset()
// not perfect, but thread-safe since writeOffset is only written when getting data
frameHeaderBytes, _ := frame.MinLength(f.version) // can never error
if currentLen+frameHeaderBytes > maxBytes {
return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here
id := f.streamQueue[0]
f.streamQueue = f.streamQueue[1:]
str, err := f.streamGetter.GetOrOpenSendStream(id)
if err != nil { // can happen if the stream completed after it said it had data
delete(f.activeStreams, id)
continue
}
maxLen := maxBytes - currentLen - frameHeaderBytes
if s.HasDataForWriting() {
frame.Data, frame.FinBit = s.GetDataForWriting(maxLen)
frame, hasMoreData := str.popStreamFrame(maxTotalLen - currentLen)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue = append(f.streamQueue, id)
} else { // no more data to send. Stream is not active any more
delete(f.activeStreams, id)
}
if len(frame.Data) == 0 && !frame.FinBit {
return true, nil
if frame == nil { // can happen if the receiveStream was canceled after it said it had data
continue
}
// Finally, check if we are now FC blocked and should queue a BLOCKED frame
if !frame.FinBit && s.IsFlowControlBlocked() {
f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.StreamBlockedFrame{StreamID: s.StreamID()})
}
if f.connFlowController.IsBlocked() {
f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{})
}
res = append(res, frame)
currentLen += frameHeaderBytes + frame.DataLen()
if currentLen == maxBytes {
return false, nil
}
frame = &wire.StreamFrame{DataLenPresent: true}
return true, nil
frames = append(frames, frame)
currentLen += frame.MinLength(f.version) + frame.DataLen()
}
f.streamsMap.RoundRobinIterate(fn)
return
f.streamQueueMutex.Unlock()
return frames
}
// maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified.

View File

@ -2,9 +2,9 @@ package quic
import (
"bytes"
"errors"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
@ -21,12 +21,13 @@ var _ = Describe("Stream Framer", func() {
var (
retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame
framer *streamFramer
streamsMap *streamsMap
stream1, stream2 *mocks.MockStreamI
connFC *mocks.MockConnectionFlowController
cryptoStream *MockCryptoStream
stream1, stream2 *MockSendStreamI
streamGetter *MockStreamGetter
)
BeforeEach(func() {
streamGetter = NewMockStreamGetter(mockCtrl)
retransmittedFrame1 = &wire.StreamFrame{
StreamID: 5,
Data: []byte{0x13, 0x37},
@ -36,25 +37,14 @@ var _ = Describe("Stream Framer", func() {
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
}
stream1 = mocks.NewMockStreamI(mockCtrl)
stream1 = NewMockSendStreamI(mockCtrl)
stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes()
stream2 = mocks.NewMockStreamI(mockCtrl)
stream2 = NewMockSendStreamI(mockCtrl)
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames)
streamsMap.putStream(stream1)
streamsMap.putStream(stream2)
connFC = mocks.NewMockConnectionFlowController(mockCtrl)
framer = newStreamFramer(nil, streamsMap, connFC, versionGQUICFrames)
cryptoStream = NewMockCryptoStream(mockCtrl)
framer = newStreamFramer(cryptoStream, streamGetter, versionGQUICFrames)
})
setNoData := func(str *mocks.MockStreamI) {
str.EXPECT().HasDataForWriting().Return(false).AnyTimes()
str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, false).AnyTimes()
str.EXPECT().GetWriteOffset().AnyTimes()
}
It("says if it has retransmissions", func() {
Expect(framer.HasFramesForRetransmission()).To(BeFalse())
framer.AddFrameForRetransmission(retransmittedFrame1)
@ -62,119 +52,220 @@ var _ = Describe("Stream Framer", func() {
})
It("sets the DataLenPresent for dequeued retransmitted frames", func() {
setNoData(stream1)
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(protocol.MaxByteCount)
Expect(fs).To(HaveLen(1))
Expect(fs[0].DataLenPresent).To(BeTrue())
})
It("sets the DataLenPresent for dequeued normal frames", func() {
connFC.EXPECT().IsBlocked()
setNoData(stream2)
stream1.EXPECT().GetWriteOffset()
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
stream1.EXPECT().IsFlowControlBlocked()
fs := framer.PopStreamFrames(protocol.MaxByteCount)
Expect(fs).To(HaveLen(1))
Expect(fs[0].DataLenPresent).To(BeTrue())
Context("handling the crypto stream", func() {
It("says if it has crypto stream data", func() {
Expect(framer.HasCryptoStreamData()).To(BeFalse())
framer.AddActiveStream(framer.version.CryptoStreamID())
Expect(framer.HasCryptoStreamData()).To(BeTrue())
})
It("says that it doesn't have crypto stream data after popping all data", func() {
streamID := framer.version.CryptoStreamID()
f := &wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
}
cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, false)
framer.AddActiveStream(streamID)
Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f))
Expect(framer.HasCryptoStreamData()).To(BeFalse())
})
It("says that it has more crypto stream data if not all data was popped", func() {
streamID := framer.version.CryptoStreamID()
f := &wire.StreamFrame{
StreamID: streamID,
Data: []byte("foobar"),
}
cryptoStream.EXPECT().popStreamFrame(protocol.ByteCount(1000)).Return(f, true)
framer.AddActiveStream(streamID)
Expect(framer.PopCryptoStreamFrame(1000)).To(Equal(f))
Expect(framer.HasCryptoStreamData()).To(BeTrue())
})
})
Context("Popping", func() {
BeforeEach(func() {
// nothing is blocked here
connFC.EXPECT().IsBlocked().AnyTimes()
stream1.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes()
stream2.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes()
})
It("returns nil when popping an empty framer", func() {
setNoData(stream1)
setNoData(stream2)
Expect(framer.PopStreamFrames(1000)).To(BeEmpty())
})
It("pops frames for retransmission", func() {
setNoData(stream1)
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame1)
framer.AddFrameForRetransmission(retransmittedFrame2)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(2))
Expect(fs[0]).To(Equal(retransmittedFrame1))
Expect(fs[1]).To(Equal(retransmittedFrame2))
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, retransmittedFrame2}))
// make sure the frames are actually removed, and not returned a second time
Expect(framer.PopStreamFrames(1000)).To(BeEmpty())
})
It("returns normal frames", func() {
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
setNoData(stream2)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(1))
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
Expect(fs[0].Data).To(Equal([]byte("foobar")))
Expect(fs[0].FinBit).To(BeFalse())
})
It("returns multiple normal frames", func() {
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz"), false)
stream2.EXPECT().HasDataForWriting().Return(true)
stream2.EXPECT().GetWriteOffset()
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(2))
// Swap if we dequeued in other order
if fs[0].StreamID != stream1.StreamID() {
fs[0], fs[1] = fs[1], fs[0]
}
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
Expect(fs[0].Data).To(Equal([]byte("foobar")))
Expect(fs[1].StreamID).To(Equal(stream2.StreamID()))
Expect(fs[1].Data).To(Equal([]byte("foobaz")))
})
It("returns retransmission frames before normal frames", func() {
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(2))
Expect(fs[0]).To(Equal(retransmittedFrame1))
Expect(fs[1].StreamID).To(Equal(stream1.StreamID()))
})
It("does not pop empty frames", func() {
stream1.EXPECT().HasDataForWriting().Return(false)
stream1.EXPECT().GetWriteOffset()
setNoData(stream2)
fs := framer.PopStreamFrames(5)
It("doesn't pop frames for retransmission, if the size would be smaller than the minimum STREAM frame size", func() {
framer.AddFrameForRetransmission(&wire.StreamFrame{
StreamID: id1,
Data: bytes.Repeat([]byte{'a'}, int(protocol.MinStreamFrameSize)),
})
fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
Expect(fs).To(BeEmpty())
})
It("uses the round-robin scheduling", func() {
streamFrameHeaderLen := protocol.ByteCount(4)
stream1.EXPECT().GetDataForWriting(10-streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen)), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10-streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen)), false)
stream2.EXPECT().HasDataForWriting().Return(true)
stream2.EXPECT().GetWriteOffset()
fs := framer.PopStreamFrames(10)
Expect(fs).To(HaveLen(1))
// it doesn't matter here if this data is from stream1 or from stream2...
firstStreamID := fs[0].StreamID
fs = framer.PopStreamFrames(10)
Expect(fs).To(HaveLen(1))
// ... but the data popped this time has to be from the other stream
Expect(fs[0].StreamID).ToNot(Equal(firstStreamID))
It("pops frames for retransmission, even if the remaining space in the packet is too small, if the frame doesn't need to be split", func() {
framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1}))
})
It("pops frames for retransmission, if the remaining size is the miniumum STREAM frame size", func() {
framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(protocol.MinStreamFrameSize)
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1}))
})
It("returns normal frames", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{
StreamID: id1,
Data: []byte("foobar"),
Offset: 42,
}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
framer.AddActiveStream(id1)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(Equal([]*wire.StreamFrame{f}))
})
It("skips a stream that was reported active, but was completed shortly after", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(nil, errors.New("stream was already deleted"))
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
}
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f}))
})
It("skips a stream that was reported active, but doesn't have any data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(nil, false)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f}))
})
It("pops from a stream multiple times, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true)
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
framer.AddActiveStream(id1) // only add it once
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f1}))
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2}))
// no further calls to popStreamFrame, after popStreamFrame said there's no more data
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(BeNil())
})
It("re-queues a stream at the end, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f11, true)
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f12, false)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
framer.AddActiveStream(id1) // only add it once
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f11})) // first a frame from stream 1
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) // then a frame from stream 2
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f12})) // then another frame from stream 1
})
It("only dequeues data from each stream once per packet", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
// both streams have more data, and will be re-queued
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, true)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f1, f2}))
})
It("returns multiple normal frames in the order they were reported active", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{Data: []byte("foobar")}
f2 := &wire.StreamFrame{Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
framer.AddActiveStream(id2)
framer.AddActiveStream(id1)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f2, f1}))
})
It("only asks a stream for data once, even if it was reported active multiple times", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function
framer.AddActiveStream(id1)
framer.AddActiveStream(id1)
Expect(framer.PopStreamFrames(1000)).To(HaveLen(1))
})
It("returns retransmission frames before normal frames", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
framer.AddActiveStream(id1)
f1 := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, f1}))
})
It("does not pop empty frames", func() {
fs := framer.PopStreamFrames(500)
Expect(fs).To(BeEmpty())
})
It("pops frames that have the minimum size", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}, false)
framer.AddActiveStream(id1)
framer.PopStreamFrames(protocol.MinStreamFrameSize)
})
It("does not pop frames smaller than the mimimum size", func() {
// don't expect a call to PopStreamFrame()
framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
})
It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() {
streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil)
// pop a frame such that the remaining size is one byte less than the minimum STREAM frame size
f := &wire.StreamFrame{
StreamID: id1,
Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)),
}
stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f, false)
framer.AddActiveStream(id1)
fs := framer.PopStreamFrames(500)
Expect(fs).To(Equal([]*wire.StreamFrame{f}))
})
Context("splitting of frames", func() {
@ -212,139 +303,28 @@ var _ = Describe("Stream Framer", func() {
})
It("splits a frame", func() {
setNoData(stream1)
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame2)
origlen := retransmittedFrame2.DataLen()
fs := framer.PopStreamFrames(6)
frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, 600)}
framer.AddFrameForRetransmission(frame)
fs := framer.PopStreamFrames(500)
Expect(fs).To(HaveLen(1))
minLength, _ := fs[0].MinLength(framer.version)
Expect(minLength + fs[0].DataLen()).To(Equal(protocol.ByteCount(6)))
Expect(framer.retransmissionQueue[0].Data).To(HaveLen(int(origlen - fs[0].DataLen())))
minLength := fs[0].MinLength(framer.version)
Expect(minLength + fs[0].DataLen()).To(Equal(protocol.ByteCount(500)))
Expect(framer.retransmissionQueue[0].Data).To(HaveLen(int(600 - fs[0].DataLen())))
Expect(framer.retransmissionQueue[0].Offset).To(Equal(fs[0].DataLen()))
})
It("never returns an empty stream frame", func() {
// this one frame will be split off from again and again in this test. Therefore, it has to be large enough (checked again at the end)
origFrame := &wire.StreamFrame{
StreamID: 5,
Offset: 1,
FinBit: false,
Data: bytes.Repeat([]byte{'f'}, 30*30),
}
framer.AddFrameForRetransmission(origFrame)
minFrameDataLen := protocol.MaxPacketSize
for i := 0; i < 30; i++ {
frames, currentLen := framer.maybePopFramesForRetransmission(protocol.ByteCount(i))
if len(frames) == 0 {
Expect(currentLen).To(BeZero())
} else {
Expect(frames).To(HaveLen(1))
Expect(currentLen).ToNot(BeZero())
dataLen := frames[0].DataLen()
Expect(dataLen).ToNot(BeZero())
if dataLen < minFrameDataLen {
minFrameDataLen = dataLen
}
}
}
Expect(framer.retransmissionQueue).To(HaveLen(1)) // check that origFrame was large enough for this test and didn't get used up completely
Expect(minFrameDataLen).To(Equal(protocol.ByteCount(1)))
})
It("only removes a frame from the framer after returning all split parts", func() {
setNoData(stream1)
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame2)
fs := framer.PopStreamFrames(6)
frameHeaderLen := protocol.ByteCount(4)
frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, int(501-frameHeaderLen))}
framer.AddFrameForRetransmission(frame)
fs := framer.PopStreamFrames(500)
Expect(fs).To(HaveLen(1))
Expect(framer.retransmissionQueue).ToNot(BeEmpty())
fs = framer.PopStreamFrames(1000)
fs = framer.PopStreamFrames(500)
Expect(fs).To(HaveLen(1))
Expect(fs[0].DataLen()).To(BeEquivalentTo(1))
Expect(framer.retransmissionQueue).To(BeEmpty())
})
})
Context("sending FINs", func() {
It("sends FINs when streams are closed", func() {
offset := protocol.ByteCount(42)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, true)
stream1.EXPECT().GetWriteOffset().Return(offset)
setNoData(stream2)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(1))
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
Expect(fs[0].Offset).To(Equal(offset))
Expect(fs[0].FinBit).To(BeTrue())
Expect(fs[0].Data).To(BeEmpty())
})
It("bundles FINs with data", func() {
offset := protocol.ByteCount(42)
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), true)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset().Return(offset)
setNoData(stream2)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(1))
Expect(fs[0].StreamID).To(Equal(stream1.StreamID()))
Expect(fs[0].Data).To(Equal([]byte("foobar")))
Expect(fs[0].FinBit).To(BeTrue())
})
})
})
Context("BLOCKED frames", func() {
It("Pop returns nil if no frame is queued", func() {
Expect(framer.PopBlockedFrame()).To(BeNil())
})
It("queues and pops BLOCKED frames for individually blocked streams", func() {
connFC.EXPECT().IsBlocked()
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
stream1.EXPECT().IsFlowControlBlocked().Return(true)
setNoData(stream2)
frames := framer.PopStreamFrames(1000)
Expect(frames).To(HaveLen(1))
f := framer.PopBlockedFrame()
Expect(f).To(BeAssignableToTypeOf(&wire.StreamBlockedFrame{}))
bf := f.(*wire.StreamBlockedFrame)
Expect(bf.StreamID).To(Equal(stream1.StreamID()))
Expect(framer.PopBlockedFrame()).To(BeNil())
})
It("does not queue a stream-level BLOCKED frame after sending the FinBit frame", func() {
connFC.EXPECT().IsBlocked()
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), true)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
setNoData(stream2)
frames := framer.PopStreamFrames(1000)
Expect(frames).To(HaveLen(1))
Expect(frames[0].FinBit).To(BeTrue())
Expect(frames[0].DataLen()).To(Equal(protocol.ByteCount(3)))
blockedFrame := framer.PopBlockedFrame()
Expect(blockedFrame).To(BeNil())
})
It("queues and pops BLOCKED frames for connection blocked streams", func() {
connFC.EXPECT().IsBlocked().Return(true)
stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), false)
stream1.EXPECT().HasDataForWriting().Return(true)
stream1.EXPECT().GetWriteOffset()
stream1.EXPECT().IsFlowControlBlocked().Return(false)
setNoData(stream2)
framer.PopStreamFrames(1000)
f := framer.PopBlockedFrame()
Expect(f).To(BeAssignableToTypeOf(&wire.BlockedFrame{}))
Expect(framer.PopBlockedFrame()).To(BeNil())
})
})
})

File diff suppressed because it is too large Load Diff

View File

@ -5,8 +5,9 @@ import (
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
@ -16,11 +17,8 @@ type streamsMap struct {
perspective protocol.Perspective
streams map[protocol.StreamID]streamI
// needed for round-robin scheduling
openStreams []protocol.StreamID
roundRobinIndex int
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond
openStreamOrErrCond sync.Cond
@ -29,47 +27,32 @@ type streamsMap struct {
nextStreamToAccept protocol.StreamID
newStream newStreamLambda
numOutgoingStreams uint32
numIncomingStreams uint32
maxIncomingStreams uint32
maxOutgoingStreams uint32
}
type streamLambda func(streamI) (bool, error)
var _ streamManager = &streamsMap{}
type newStreamLambda func(protocol.StreamID) streamI
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap {
// add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32(
maxStreams+protocol.MaxStreamsMinimumIncrement,
uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
)
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) streamManager {
sm := streamsMap{
perspective: pers,
streams: make(map[protocol.StreamID]streamI),
openStreams: make([]protocol.StreamID, 0),
newStream: newStream,
maxIncomingStreams: maxIncomingStreams,
perspective: pers,
streams: make(map[protocol.StreamID]streamI),
newStream: newStream,
}
sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamOrErrCond.L = &sm.mutex
nextOddStream := protocol.StreamID(1)
if ver.CryptoStreamID() == protocol.StreamID(1) {
nextOddStream = 3
}
if pers == protocol.PerspectiveClient {
sm.nextStream = nextOddStream
sm.nextStreamToAccept = 2
nextClientInitiatedStream := protocol.StreamID(1)
nextServerInitiatedStream := protocol.StreamID(2)
if pers == protocol.PerspectiveServer {
sm.nextStreamToOpen = nextServerInitiatedStream
sm.nextStreamToAccept = nextClientInitiatedStream
} else {
sm.nextStream = 2
sm.nextStreamToAccept = nextOddStream
sm.nextStreamToOpen = nextClientInitiatedStream
sm.nextStreamToAccept = nextServerInitiatedStream
}
return &sm
}
@ -81,6 +64,23 @@ func (m *streamsMap) streamInitiatedBy(id protocol.StreamID) protocol.Perspectiv
return protocol.PerspectiveClient
}
func (m *streamsMap) nextStreamID(id protocol.StreamID) protocol.StreamID {
if m.perspective == protocol.PerspectiveServer && id == 0 {
return 1
}
return id + 2
}
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
// every bidirectional stream is also a receive stream
return m.GetOrOpenStream(id)
}
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
// every bidirectional stream is also a send stream
return m.GetOrOpenStream(id)
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
@ -88,7 +88,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil // s may be nil
return s, nil
}
// ... we don't have an existing stream
@ -101,7 +101,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
}
if m.perspective == m.streamInitiatedBy(id) {
if id <= m.nextStream { // this is a stream opened by us. Must have been closed already
if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already
return nil, nil
}
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
@ -110,14 +110,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
return nil, nil
}
// sid is the next stream that will be opened
sid := m.highestStreamOpenedByPeer + 2
// if there is no stream opened yet, and this is the server, stream 1 should be openend
if sid == 2 && m.perspective == protocol.PerspectiveServer {
sid = 1
}
for ; sid <= id; sid += 2 {
for sid := m.nextStreamID(m.highestStreamOpenedByPeer); sid <= id; sid = m.nextStreamID(sid) {
if _, err := m.openRemoteStream(sid); err != nil {
return nil, err
}
@ -128,38 +121,26 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
}
func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) {
if m.numIncomingStreams >= m.maxIncomingStreams {
return nil, qerr.TooManyOpenStreams
}
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
}
m.numIncomingStreams++
if id > m.highestStreamOpenedByPeer {
m.highestStreamOpenedByPeer = id
}
s := m.newStream(id)
m.putStream(s)
return s, nil
}
func (m *streamsMap) openStreamImpl() (streamI, error) {
id := m.nextStream
if m.numOutgoingStreams >= m.maxOutgoingStreams {
return nil, qerr.TooManyOpenStreams
}
m.numOutgoingStreams++
m.nextStream += 2
s := m.newStream(id)
s := m.newStream(m.nextStreamToOpen)
m.putStream(s)
m.nextStreamToOpen = m.nextStreamID(m.nextStreamToOpen)
return s, nil
}
// OpenStream opens the next available stream
func (m *streamsMap) OpenStream() (streamI, error) {
func (m *streamsMap) OpenStream() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -169,7 +150,7 @@ func (m *streamsMap) OpenStream() (streamI, error) {
return m.openStreamImpl()
}
func (m *streamsMap) OpenStreamSync() (streamI, error) {
func (m *streamsMap) OpenStreamSync() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -190,7 +171,7 @@ func (m *streamsMap) OpenStreamSync() (streamI, error) {
// AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened
func (m *streamsMap) AcceptStream() (streamI, error) {
func (m *streamsMap) AcceptStream() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var str streamI
@ -209,104 +190,24 @@ func (m *streamsMap) AcceptStream() (streamI, error) {
return str, nil
}
func (m *streamsMap) DeleteClosedStreams() error {
func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
var numDeletedStreams int
// for every closed stream, the streamID is replaced by 0 in the openStreams slice
for i, streamID := range m.openStreams {
str, ok := m.streams[streamID]
if !ok {
return errMapAccess
}
if !str.Finished() {
continue
}
numDeletedStreams++
m.openStreams[i] = 0
if m.streamInitiatedBy(streamID) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
delete(m.streams, streamID)
_, ok := m.streams[id]
if !ok {
return errMapAccess
}
if numDeletedStreams == 0 {
return nil
}
// remove all 0s (representing closed streams) from the openStreams slice
// and adjust the roundRobinIndex
var j int
for i, id := range m.openStreams {
if i != j {
m.openStreams[j] = m.openStreams[i]
}
if id != 0 {
j++
} else if j < m.roundRobinIndex {
m.roundRobinIndex--
}
}
m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams]
delete(m.streams, id)
m.openStreamOrErrCond.Signal()
return nil
}
// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false
// It uses a round-robin-like scheduling to ensure that every stream is considered fairly
// It prioritizes the the header-stream (StreamID 3)
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
m.mutex.Lock()
defer m.mutex.Unlock()
numStreams := len(m.streams)
startIndex := m.roundRobinIndex
for i := 0; i < numStreams; i++ {
streamID := m.openStreams[(i+startIndex)%numStreams]
cont, err := m.iterateFunc(streamID, fn)
if err != nil {
return err
}
m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams
if !cont {
break
}
}
return nil
}
// Range executes a callback for all streams, in pseudo-random order
func (m *streamsMap) Range(cb func(s streamI)) {
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, s := range m.streams {
if s != nil {
cb(s)
}
}
}
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
str, ok := m.streams[streamID]
if !ok {
return true, errMapAccess
}
return fn(str)
}
func (m *streamsMap) putStream(s streamI) error {
id := s.StreamID()
if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id)
}
m.streams[id] = s
m.openStreams = append(m.openStreams, id)
return nil
}
@ -316,14 +217,20 @@ func (m *streamsMap) CloseWithError(err error) {
m.closeErr = err
m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast()
for _, s := range m.openStreams {
m.streams[s].Cancel(err)
for _, s := range m.streams {
s.closeForShutdown(err)
}
}
func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) {
// TODO(#952): this won't be needed when gQUIC supports stateless handshakes
func (m *streamsMap) UpdateLimits(params *handshake.TransportParameters) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.maxOutgoingStreams = limit
for id, str := range m.streams {
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: id,
ByteOffset: params.StreamFlowControlWindow,
})
}
m.mutex.Unlock()
m.openStreamOrErrCond.Broadcast()
}

View File

@ -0,0 +1,257 @@
package quic
import (
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type streamsMapLegacy struct {
mutex sync.RWMutex
perspective protocol.Perspective
streams map[protocol.StreamID]streamI
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond
openStreamOrErrCond sync.Cond
closeErr error
nextStreamToAccept protocol.StreamID
newStream newStreamLambda
numOutgoingStreams uint32
numIncomingStreams uint32
maxIncomingStreams uint32
maxOutgoingStreams uint32
}
var _ streamManager = &streamsMapLegacy{}
func newStreamsMapLegacy(newStream newStreamLambda, pers protocol.Perspective) streamManager {
// add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32(
maxStreams+protocol.MaxStreamsMinimumIncrement,
uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
)
sm := streamsMapLegacy{
perspective: pers,
streams: make(map[protocol.StreamID]streamI),
newStream: newStream,
maxIncomingStreams: maxIncomingStreams,
}
sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamOrErrCond.L = &sm.mutex
nextServerInitiatedStream := protocol.StreamID(2)
nextClientInitiatedStream := protocol.StreamID(3)
if pers == protocol.PerspectiveServer {
sm.highestStreamOpenedByPeer = 1
}
if pers == protocol.PerspectiveServer {
sm.nextStreamToOpen = nextServerInitiatedStream
sm.nextStreamToAccept = nextClientInitiatedStream
} else {
sm.nextStreamToOpen = nextClientInitiatedStream
sm.nextStreamToAccept = nextServerInitiatedStream
}
return &sm
}
// getStreamPerspective says which side should initiate a stream
func (m *streamsMapLegacy) streamInitiatedBy(id protocol.StreamID) protocol.Perspective {
if id%2 == 0 {
return protocol.PerspectiveServer
}
return protocol.PerspectiveClient
}
func (m *streamsMapLegacy) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
// every bidirectional stream is also a receive stream
return m.GetOrOpenStream(id)
}
func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
// every bidirectional stream is also a send stream
return m.GetOrOpenStream(id)
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
func (m *streamsMapLegacy) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
m.mutex.RLock()
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil
}
// ... we don't have an existing stream
m.mutex.Lock()
defer m.mutex.Unlock()
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
s, ok = m.streams[id]
if ok {
return s, nil
}
if m.perspective == m.streamInitiatedBy(id) {
if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already
return nil, nil
}
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
}
if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already
return nil, nil
}
for sid := m.highestStreamOpenedByPeer + 2; sid <= id; sid += 2 {
if _, err := m.openRemoteStream(sid); err != nil {
return nil, err
}
}
m.nextStreamOrErrCond.Broadcast()
return m.streams[id], nil
}
func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, error) {
if m.numIncomingStreams >= m.maxIncomingStreams {
return nil, qerr.TooManyOpenStreams
}
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
}
m.numIncomingStreams++
if id > m.highestStreamOpenedByPeer {
m.highestStreamOpenedByPeer = id
}
s := m.newStream(id)
m.putStream(s)
return s, nil
}
func (m *streamsMapLegacy) openStreamImpl() (streamI, error) {
if m.numOutgoingStreams >= m.maxOutgoingStreams {
return nil, qerr.TooManyOpenStreams
}
m.numOutgoingStreams++
s := m.newStream(m.nextStreamToOpen)
m.putStream(s)
m.nextStreamToOpen += 2
return s, nil
}
// OpenStream opens the next available stream
func (m *streamsMapLegacy) OpenStream() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.closeErr != nil {
return nil, m.closeErr
}
return m.openStreamImpl()
}
func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
for {
if m.closeErr != nil {
return nil, m.closeErr
}
str, err := m.openStreamImpl()
if err == nil {
return str, err
}
if err != nil && err != qerr.TooManyOpenStreams {
return nil, err
}
m.openStreamOrErrCond.Wait()
}
}
// AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened
func (m *streamsMapLegacy) AcceptStream() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var str streamI
for {
var ok bool
if m.closeErr != nil {
return nil, m.closeErr
}
str, ok = m.streams[m.nextStreamToAccept]
if ok {
break
}
m.nextStreamOrErrCond.Wait()
}
m.nextStreamToAccept += 2
return str, nil
}
func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
_, ok := m.streams[id]
if !ok {
return errMapAccess
}
delete(m.streams, id)
if m.streamInitiatedBy(id) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
m.openStreamOrErrCond.Signal()
return nil
}
func (m *streamsMapLegacy) putStream(s streamI) error {
id := s.StreamID()
if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id)
}
m.streams[id] = s
return nil
}
func (m *streamsMapLegacy) CloseWithError(err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.closeErr = err
m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast()
for _, s := range m.streams {
s.closeForShutdown(err)
}
}
// TODO(#952): this won't be needed when gQUIC supports stateless handshakes
func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) {
m.mutex.Lock()
m.maxOutgoingStreams = params.MaxStreams
for id, str := range m.streams {
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: id,
ByteOffset: params.StreamFlowControlWindow,
})
}
m.mutex.Unlock()
m.openStreamOrErrCond.Broadcast()
}

Some files were not shown because too many files have changed in this diff Show More