507 lines
18 KiB
Go
507 lines
18 KiB
Go
package h2quic
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
|
|
"golang.org/x/net/http2"
|
|
"golang.org/x/net/http2/hpack"
|
|
|
|
quic "github.com/lucas-clemente/quic-go"
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/qerr"
|
|
|
|
"time"
|
|
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
var _ = Describe("Client", func() {
|
|
var (
|
|
client *client
|
|
session *mockSession
|
|
headerStream *mockStream
|
|
req *http.Request
|
|
origDialAddr = dialAddr
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
origDialAddr = dialAddr
|
|
hostname := "quic.clemente.io:1337"
|
|
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
|
|
Expect(client.hostname).To(Equal(hostname))
|
|
session = &mockSession{}
|
|
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
|
client.session = session
|
|
|
|
headerStream = newMockStream(3)
|
|
client.headerStream = headerStream
|
|
client.requestWriter = newRequestWriter(headerStream)
|
|
var err error
|
|
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
AfterEach(func() {
|
|
dialAddr = origDialAddr
|
|
})
|
|
|
|
// The following specs cover newClient: it keeps the TLS and QUIC
// configurations it is given, falls back to defaultQuicConfig when no QUIC
// config is passed, and appends a port to a hostname that lacks one.
It("saves the TLS config", func() {
	tlsConf := &tls.Config{InsecureSkipVerify: true}
	client = newClient("", tlsConf, &roundTripperOpts{}, nil)
	Expect(client.tlsConf).To(Equal(tlsConf))
})

It("saves the QUIC config", func() {
	quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
	client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf)
	Expect(client.config).To(Equal(quicConf))
})

It("uses the default QUIC config if none is given", func() {
	client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil)
	Expect(client.config).ToNot(BeNil())
	Expect(client.config).To(Equal(defaultQuicConfig))
})

It("adds the port to the hostname, if none is given", func() {
	// 443 is the default HTTPS port.
	client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
	Expect(client.hostname).To(Equal("quic.clemente.io:443"))
})
|
|
|
|
It("dials", func(done Done) {
|
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
|
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
return session, nil
|
|
}
|
|
close(headerStream.unblockRead)
|
|
go client.RoundTrip(req)
|
|
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
|
close(done)
|
|
}, 2)
|
|
|
|
It("errors when dialing fails", func() {
|
|
testErr := errors.New("handshake error")
|
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
return nil, testErr
|
|
}
|
|
_, err := client.RoundTrip(req)
|
|
Expect(err).To(MatchError(testErr))
|
|
})
|
|
|
|
It("errors if it can't open a stream", func() {
|
|
testErr := errors.New("you shall not pass")
|
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
|
session.streamOpenErr = testErr
|
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
return session, nil
|
|
}
|
|
_, err := client.RoundTrip(req)
|
|
Expect(err).To(MatchError(testErr))
|
|
})
|
|
|
|
It("returns a request when dial fails", func() {
|
|
testErr := errors.New("dial error")
|
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
return nil, testErr
|
|
}
|
|
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
var doErr error
|
|
go func() {
|
|
_, doErr = client.RoundTrip(request)
|
|
}()
|
|
_, err = client.RoundTrip(request)
|
|
Expect(err).To(MatchError(testErr))
|
|
Eventually(func() error { return doErr }).Should(MatchError(testErr))
|
|
})
|
|
|
|
Context("Doing requests", func() {
|
|
var request *http.Request
|
|
var dataStream *mockStream
|
|
|
|
getRequest := func(data []byte) *http2.MetaHeadersFrame {
|
|
r := bytes.NewReader(data)
|
|
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
h2framer := http2.NewFramer(nil, r)
|
|
frame, err := h2framer.ReadFrame()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)}
|
|
mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment())
|
|
Expect(err).ToNot(HaveOccurred())
|
|
return mhframe
|
|
}
|
|
|
|
getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string {
|
|
fields := make(map[string]string)
|
|
for _, hf := range f.Fields {
|
|
fields[hf.Name] = hf.Value
|
|
}
|
|
return fields
|
|
}
|
|
|
|
BeforeEach(func() {
|
|
var err error
|
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
return session, nil
|
|
}
|
|
dataStream = newMockStream(5)
|
|
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
|
|
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
It("does a request", func(done Done) {
	// RoundTrip's return values travel over a channel instead of shared
	// variables, so the spec does not race with the goroutine.
	type roundTripResult struct {
		rsp *http.Response
		err error
	}
	results := make(chan roundTripResult, 1)
	go func() {
		rsp, err := client.RoundTrip(request)
		results <- roundTripResult{rsp: rsp, err: err}
	}()

	// Wait until the request headers were written and the client registered
	// a response channel for data stream 5.
	Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
	Eventually(func() map[protocol.StreamID]chan *http.Response { return client.responses }).Should(HaveKey(protocol.StreamID(5)))
	rsp := &http.Response{
		Status:     "418 I'm a teapot",
		StatusCode: 418,
	}
	Expect(client.responses[5]).ToNot(BeClosed())
	Expect(client.headerErrored).ToNot(BeClosed())
	client.responses[5] <- rsp

	var res roundTripResult
	Eventually(results).Should(Receive(&res))
	Expect(res.err).ToNot(HaveOccurred())
	Expect(res.rsp).To(Equal(rsp))
	Expect(res.rsp.Body).To(Equal(dataStream))
	Expect(res.rsp.ContentLength).To(BeEquivalentTo(-1))
	Expect(res.rsp.Request).To(Equal(request))
	close(done)
})

It("closes the quic client when encountering an error on the header stream", func() {
	// Garbage on the header stream must fail the request and close the
	// session with an InvalidHeadersStreamData error.
	headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
	done := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		rsp, err := client.RoundTrip(request)
		Expect(err).To(MatchError(client.headerErr))
		Expect(rsp).To(BeNil())
		close(done)
	}()

	Eventually(done).Should(BeClosed())
	Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
	Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
})

It("returns subsequent request if there was an error on the header stream before", func() {
	session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
	headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
	_, err := client.RoundTrip(request)
	Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
	Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
	// now that the first request failed due to an error on the header stream, try another request
	_, nextErr := client.RoundTrip(request)
	Expect(nextErr).To(MatchError(err))
})

It("blocks if no stream is available", func() {
	session.streamsToOpen = []quic.Stream{headerStream}
	session.blockOpenStreamSync = true
	// returned is closed if RoundTrip ever comes back; a channel (rather than
	// a shared bool) keeps the check free of data races.
	returned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		_, err := client.RoundTrip(request)
		Expect(err).ToNot(HaveOccurred())
		close(returned)
	}()
	go client.handleHeaderStream()

	Consistently(returned).ShouldNot(BeClosed())
})
|
|
|
|
Context("validating the address", func() {
|
|
It("refuses to do requests for the wrong host", func() {
|
|
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = client.RoundTrip(req)
|
|
Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
|
})
|
|
|
|
It("refuses to do plain HTTP requests", func() {
|
|
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = client.RoundTrip(req)
|
|
Expect(err).To(MatchError("quic http2: unsupported scheme"))
|
|
})
|
|
|
|
It("adds the port for request URLs without one", func(done Done) {
|
|
var err error
|
|
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
|
|
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
var doErr error
|
|
var doReturned bool
|
|
// the client.RoundTrip will block, because the encryption level is still set to Unencrypted
|
|
go func() {
|
|
_, doErr = client.RoundTrip(req)
|
|
doReturned = true
|
|
}()
|
|
|
|
Consistently(doReturned).Should(BeFalse())
|
|
Expect(doErr).ToNot(HaveOccurred())
|
|
close(done)
|
|
})
|
|
})
|
|
|
|
// The HEADERS frame written for a request without a body must carry
// END_STREAM; with a body it must not, since DATA frames still follow.
It("sets the EndStream header for requests without a body", func() {
	go func() { client.RoundTrip(request) }()
	// BeEmpty (rather than BeNil) also waits correctly if Bytes() returns an
	// empty non-nil slice.
	Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
	mhf := getRequest(headerStream.dataWritten.Bytes())
	Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue())
})

It("sets the EndStream header to false for requests with a body", func() {
	request.Body = &mockBody{}
	go func() { client.RoundTrip(request) }()
	Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
	mhf := getRequest(headerStream.dataWritten.Bytes())
	Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse())
})
|
|
|
|
Context("requests containing a Body", func() {
	var requestBody []byte
	var response *http.Response

	// roundTripResult carries RoundTrip's return values from the goroutine
	// that runs it back to the spec.
	type roundTripResult struct {
		rsp *http.Response
		err error
	}

	// startRoundTrip runs client.RoundTrip(r) in a goroutine and returns a
	// channel that receives its result exactly once. Using a channel instead
	// of shared variables keeps the specs free of data races.
	startRoundTrip := func(r *http.Request) <-chan roundTripResult {
		results := make(chan roundTripResult, 1)
		go func() {
			rsp, err := client.RoundTrip(r)
			results <- roundTripResult{rsp: rsp, err: err}
		}()
		return results
	}

	BeforeEach(func() {
		requestBody = []byte("request body")
		body := &mockBody{}
		body.SetData(requestBody)
		request.Body = body
		response = &http.Response{
			StatusCode: 200,
			Header:     http.Header{"Content-Length": []string{"1000"}},
		}
		// fake a handshake: dialOnce has already fired, so RoundTrip skips
		// dialing and only needs to open the data stream.
		client.dialOnce.Do(func() {})
		session.streamsToOpen = []quic.Stream{dataStream}
	})

	It("sends a request", func() {
		results := startRoundTrip(request)
		Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
		client.responses[5] <- response

		var res roundTripResult
		Eventually(results).Should(Receive(&res))
		Expect(res.err).ToNot(HaveOccurred())
		// The body was written to the data stream, and both the data stream
		// and the request body were closed.
		Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
		Expect(dataStream.closed).To(BeTrue())
		Expect(request.Body.(*mockBody).closed).To(BeTrue())
		Expect(res.rsp).To(Equal(response))
	})

	It("returns the error that occurred when reading the body", func() {
		testErr := errors.New("testErr")
		request.Body.(*mockBody).readErr = testErr

		results := startRoundTrip(request)
		var res roundTripResult
		Eventually(results).Should(Receive(&res))
		Expect(res.err).To(MatchError(testErr))
		Expect(res.rsp).To(BeNil())
		Expect(request.Body.(*mockBody).closed).To(BeTrue())
	})

	It("returns the error that occurred when closing the body", func() {
		testErr := errors.New("testErr")
		request.Body.(*mockBody).closeErr = testErr

		results := startRoundTrip(request)
		var res roundTripResult
		Eventually(results).Should(Receive(&res))
		Expect(res.err).To(MatchError(testErr))
		Expect(res.rsp).To(BeNil())
		Expect(request.Body.(*mockBody).closed).To(BeTrue())
	})
})
|
|
|
|
Context("gzip compression", func() {
	var gzippedData []byte // a gzipped foobar
	var response *http.Response

	// roundTripResult carries RoundTrip's return values from the goroutine
	// that runs it back to the spec.
	type roundTripResult struct {
		rsp *http.Response
		err error
	}

	// startRoundTrip runs client.RoundTrip(r) in a goroutine and returns a
	// channel that receives its result exactly once. Using a channel instead
	// of shared variables keeps the specs free of data races.
	startRoundTrip := func(r *http.Request) <-chan roundTripResult {
		results := make(chan roundTripResult, 1)
		go func() {
			rsp, err := client.RoundTrip(r)
			results <- roundTripResult{rsp: rsp, err: err}
		}()
		return results
	}

	BeforeEach(func() {
		var b bytes.Buffer
		w := gzip.NewWriter(&b)
		_, err := w.Write([]byte("foobar"))
		Expect(err).ToNot(HaveOccurred())
		Expect(w.Close()).ToNot(HaveOccurred())
		gzippedData = b.Bytes()
		response = &http.Response{
			StatusCode: 200,
			Header:     http.Header{"Content-Length": []string{"1000"}},
		}
	})

	It("adds the gzip header to requests", func(done Done) {
		results := startRoundTrip(request)

		Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
		dataStream.dataToRead.Write(gzippedData)
		response.Header.Add("Content-Encoding", "gzip")
		client.responses[5] <- response

		var res roundTripResult
		Eventually(results).Should(Receive(&res))
		Expect(res.err).ToNot(HaveOccurred())
		Expect(res.rsp).ToNot(BeNil())
		headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
		Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
		// Transparent decompression hides the encoding and the (compressed)
		// length from the caller.
		Expect(res.rsp.ContentLength).To(BeEquivalentTo(-1))
		Expect(res.rsp.Header.Get("Content-Encoding")).To(BeEmpty())
		Expect(res.rsp.Header.Get("Content-Length")).To(BeEmpty())
		// Presumably lets the mock stream return EOF instead of blocking once
		// its buffered data is consumed by the gzip reader — confirm.
		close(dataStream.unblockRead)
		data := make([]byte, 6)
		_, err := io.ReadFull(res.rsp.Body, data)
		Expect(err).ToNot(HaveOccurred())
		Expect(data).To(Equal([]byte("foobar")))
		close(done)
	}, 2)

	It("doesn't add gzip if compression is disabled", func() {
		client.opts.DisableCompression = true
		results := startRoundTrip(request)

		Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
		// No response is ever delivered, so RoundTrip must still be pending —
		// in particular it must not have returned an error.
		Expect(results).ToNot(Receive())
		Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
		headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
		Expect(headers).ToNot(HaveKey("accept-encoding"))
	})

	It("only decompresses the response if the response contains the right content-encoding header", func() {
		results := startRoundTrip(request)

		Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
		dataStream.dataToRead.Write([]byte("not gzipped"))
		client.responses[5] <- response

		var res roundTripResult
		Eventually(results).Should(Receive(&res))
		Expect(res.err).ToNot(HaveOccurred())
		Expect(res.rsp).ToNot(BeNil())
		headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
		Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
		data := make([]byte, 11)
		_, err := io.ReadFull(res.rsp.Body, data)
		Expect(err).ToNot(HaveOccurred())
		Expect(res.rsp.ContentLength).ToNot(BeEquivalentTo(-1))
		Expect(data).To(Equal([]byte("not gzipped")))
	})

	It("doesn't add the gzip header for requests that have the accept-encoding set", func() {
		// The caller asked for gzip itself, so the client must pass the body
		// through untouched.
		request.Header.Add("accept-encoding", "gzip")
		results := startRoundTrip(request)

		Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
		dataStream.dataToRead.Write([]byte("gzipped data"))
		client.responses[5] <- response

		var res roundTripResult
		Eventually(results).Should(Receive(&res))
		Expect(res.err).ToNot(HaveOccurred())
		Expect(res.rsp).ToNot(BeNil())
		headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
		Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
		data := make([]byte, 12)
		_, err := io.ReadFull(res.rsp.Body, data)
		Expect(err).ToNot(HaveOccurred())
		Expect(res.rsp.ContentLength).ToNot(BeEquivalentTo(-1))
		Expect(data).To(Equal([]byte("gzipped data")))
	})
})
|
|
|
|
Context("handling the header stream", func() {
	var h2framer *http2.Framer

	BeforeEach(func() {
		// Frames written with h2framer land in the buffer that the client's
		// header stream reads from.
		h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
		client.responses[23] = make(chan *http.Response)
	})

	It("reads header values from a response", func() {
		// HPACK block taken from RFC 7541, Appendix C.5.1 (response examples
		// without Huffman coding, first response):
		// https://http2.github.io/http2-spec/compression.html#response.examples.without.huffman.coding
		data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d}
		// Hand-built 9-byte frame header: 24-bit length, type 0x1 (HEADERS),
		// flags 0x5 (END_STREAM|END_HEADERS), stream ID 23.
		headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23})
		headerStream.dataToRead.Write(data)
		go client.handleHeaderStream()
		var rsp *http.Response
		Eventually(client.responses[23]).Should(Receive(&rsp))
		Expect(rsp).ToNot(BeNil())
		Expect(rsp.Proto).To(Equal("HTTP/2.0"))
		Expect(rsp.ProtoMajor).To(BeEquivalentTo(2))
		Expect(rsp.StatusCode).To(BeEquivalentTo(302))
		Expect(rsp.Status).To(Equal("302 Found"))
		Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"}))
		Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"}))
	})

	It("errors if the H2 frame is not a HeadersFrame", func() {
		err := h2framer.WritePing(true, [8]byte{})
		Expect(err).ToNot(HaveOccurred())
		client.handleHeaderStream()
		Eventually(client.headerErrored).Should(BeClosed())
		Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
	})

	It("errors if it can't read the HPACK encoded header fields", func() {
		err := h2framer.WriteHeaders(http2.HeadersFrameParam{
			StreamID:      23,
			EndHeaders:    true,
			BlockFragment: []byte("invalid HPACK data"),
		})
		Expect(err).ToNot(HaveOccurred())
		client.handleHeaderStream()
		Eventually(client.headerErrored).Should(BeClosed())
		Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
		Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
	})

	It("errors if the stream cannot be found", func() {
		var headers bytes.Buffer
		enc := hpack.NewEncoder(&headers)
		err := enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
		Expect(err).ToNot(HaveOccurred())
		// Stream 1337 has no registered response channel (only 23 does).
		err = h2framer.WriteHeaders(http2.HeadersFrameParam{
			StreamID:      1337,
			EndHeaders:    true,
			BlockFragment: headers.Bytes(),
		})
		Expect(err).ToNot(HaveOccurred())
		client.handleHeaderStream()
		Eventually(client.headerErrored).Should(BeClosed())
		Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
		Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
	})
})
|
|
})
|
|
})
|