route/vendor/github.com/lucas-clemente/quic-go/h2quic/client_test.go

507 lines
18 KiB
Go

package h2quic
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
"net/http"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// Specs for the h2quic client.
//
// The outer BeforeEach builds a client for "quic.clemente.io:1337" whose
// session, header stream (mock stream 3) and request writer are mocks, and
// records the package-level dialAddr so that AfterEach can restore it after
// specs that replace it.
//
// Specs that run client.RoundTrip in a background goroutine report back over
// a channel instead of polling plain variables: closing (or sending on) a
// channel happens-before the matching receive, so anything the goroutine wrote
// before that point is safe to read once the channel has been observed.
var _ = Describe("Client", func() {
	var (
		client       *client
		session      *mockSession
		headerStream *mockStream
		req          *http.Request
		origDialAddr = dialAddr
	)

	BeforeEach(func() {
		origDialAddr = dialAddr
		hostname := "quic.clemente.io:1337"
		client = newClient(hostname, nil, &roundTripperOpts{}, nil)
		Expect(client.hostname).To(Equal(hostname))
		session = &mockSession{}
		session.ctx, session.ctxCancel = context.WithCancel(context.Background())
		client.session = session

		headerStream = newMockStream(3)
		client.headerStream = headerStream
		client.requestWriter = newRequestWriter(headerStream)
		var err error
		req, err = http.NewRequest("GET", "https://localhost:1337", nil)
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		// undo any dialAddr override made by the spec
		dialAddr = origDialAddr
	})

	It("saves the TLS config", func() {
		tlsConf := &tls.Config{InsecureSkipVerify: true}
		client = newClient("", tlsConf, &roundTripperOpts{}, nil)
		Expect(client.tlsConf).To(Equal(tlsConf))
	})

	It("saves the QUIC config", func() {
		quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
		client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf)
		Expect(client.config).To(Equal(quicConf))
	})

	It("uses the default QUIC config if none is give", func() {
		client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil)
		Expect(client.config).ToNot(BeNil())
		Expect(client.config).To(Equal(defaultQuicConfig))
	})

	It("adds the port to the hostname, if none is given", func() {
		client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
		Expect(client.hostname).To(Equal("quic.clemente.io:443"))
	})

	It("dials", func(done Done) {
		client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
		session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
		dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
			return session, nil
		}
		close(headerStream.unblockRead)
		go client.RoundTrip(req)
		Eventually(func() quic.Session { return client.session }).Should(Equal(session))
		close(done)
	}, 2)

	It("errors when dialing fails", func() {
		testErr := errors.New("handshake error")
		client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
		dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
			return nil, testErr
		}
		_, err := client.RoundTrip(req)
		Expect(err).To(MatchError(testErr))
	})

	It("errors if it can't open a stream", func() {
		testErr := errors.New("you shall not pass")
		client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
		session.streamOpenErr = testErr
		dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
			return session, nil
		}
		_, err := client.RoundTrip(req)
		Expect(err).To(MatchError(testErr))
	})

	It("returns a request when dial fails", func() {
		testErr := errors.New("dial error")
		dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
			return nil, testErr
		}
		request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
		Expect(err).ToNot(HaveOccurred())

		// A concurrent RoundTrip must observe the same dial error. The buffered
		// channel hands the error over without a data race.
		errChan := make(chan error, 1)
		go func() {
			_, rtErr := client.RoundTrip(request)
			errChan <- rtErr
		}()
		_, err = client.RoundTrip(request)
		Expect(err).To(MatchError(testErr))
		Eventually(errChan).Should(Receive(MatchError(testErr)))
	})

	Context("Doing requests", func() {
		var request *http.Request
		var dataStream *mockStream

		// getRequest parses the first HTTP/2 frame in data as a HEADERS frame
		// and HPACK-decodes its header block.
		getRequest := func(data []byte) *http2.MetaHeadersFrame {
			r := bytes.NewReader(data)
			decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
			h2framer := http2.NewFramer(nil, r)
			frame, err := h2framer.ReadFrame()
			Expect(err).ToNot(HaveOccurred())
			mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)}
			mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment())
			Expect(err).ToNot(HaveOccurred())
			return mhframe
		}

		// getHeaderFields flattens the decoded fields into a name -> value map
		// (a later field with the same name overwrites an earlier one).
		getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string {
			fields := make(map[string]string)
			for _, hf := range f.Fields {
				fields[hf.Name] = hf.Value
			}
			return fields
		}

		BeforeEach(func() {
			var err error
			dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
				return session, nil
			}
			dataStream = newMockStream(5)
			session.streamsToOpen = []quic.Stream{headerStream, dataStream}
			request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
			Expect(err).ToNot(HaveOccurred())
		})

		It("does a request", func(done Done) {
			var doRsp *http.Response
			var doErr error
			returned := make(chan struct{})
			go func() {
				doRsp, doErr = client.RoundTrip(request)
				close(returned)
			}()

			Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
			Eventually(func() map[protocol.StreamID]chan *http.Response { return client.responses }).Should(HaveKey(protocol.StreamID(5)))
			rsp := &http.Response{
				Status:     "418 I'm a teapot",
				StatusCode: 418,
			}
			Expect(client.responses[5]).ToNot(BeClosed())
			Expect(client.headerErrored).ToNot(BeClosed())
			client.responses[5] <- rsp
			// once returned is closed, doRsp and doErr are safe to read
			Eventually(returned).Should(BeClosed())
			Expect(doErr).ToNot(HaveOccurred())
			Expect(doRsp).To(Equal(rsp))
			Expect(doRsp.Body).To(Equal(dataStream))
			Expect(doRsp.ContentLength).To(BeEquivalentTo(-1))
			Expect(doRsp.Request).To(Equal(request))
			close(done)
		})

		It("closes the quic client when encountering an error on the header stream", func() {
			headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				rsp, err := client.RoundTrip(request)
				Expect(err).To(MatchError(client.headerErr))
				Expect(rsp).To(BeNil())
				close(done)
			}()

			Eventually(done).Should(BeClosed())
			Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
			Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
		})

		It("returns subsequent request if there was an error on the header stream before", func() {
			session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
			headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
			_, err := client.RoundTrip(request)
			Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
			Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
			// now that the first request failed due to an error on the header stream, try another request
			_, nextErr := client.RoundTrip(request)
			Expect(nextErr).To(MatchError(err))
		})

		It("blocks if no stream is available", func() {
			session.streamsToOpen = []quic.Stream{headerStream}
			session.blockOpenStreamSync = true
			returned := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				_, err := client.RoundTrip(request)
				Expect(err).ToNot(HaveOccurred())
				close(returned)
			}()
			go client.handleHeaderStream()

			Consistently(returned).ShouldNot(BeClosed())
		})

		Context("validating the address", func() {
			It("refuses to do requests for the wrong host", func() {
				req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
				Expect(err).ToNot(HaveOccurred())
				_, err = client.RoundTrip(req)
				Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
			})

			It("refuses to do plain HTTP requests", func() {
				req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
				Expect(err).ToNot(HaveOccurred())
				_, err = client.RoundTrip(req)
				Expect(err).To(MatchError("quic http2: unsupported scheme"))
			})

			It("adds the port for request URLs without one", func(done Done) {
				client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
				req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
				Expect(err).ToNot(HaveOccurred())

				// Without the default port, RoundTrip would return the
				// wrong-client error right away (see the spec above). Once the
				// port is added the request is accepted, and RoundTrip is
				// expected to keep blocking because no response is delivered.
				returned := make(chan struct{})
				go func() {
					defer GinkgoRecover()
					_, err := client.RoundTrip(req)
					Expect(err).ToNot(HaveOccurred())
					close(returned)
				}()

				// Poll the channel, not a bool captured by value, so the
				// assertion is actually re-evaluated.
				Consistently(returned).ShouldNot(BeClosed())
				close(done)
			})
		})

		It("sets the EndStream header for requests without a body", func() {
			// only the written HEADERS frame is inspected; the result is not needed
			go func() { client.RoundTrip(request) }()
			Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
			mhf := getRequest(headerStream.dataWritten.Bytes())
			Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue())
		})

		It("sets the EndStream header to false for requests with a body", func() {
			request.Body = &mockBody{}
			// only the written HEADERS frame is inspected; the result is not needed
			go func() { client.RoundTrip(request) }()
			Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
			mhf := getRequest(headerStream.dataWritten.Bytes())
			Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse())
		})

		Context("requests containing a Body", func() {
			var requestBody []byte
			var response *http.Response

			BeforeEach(func() {
				requestBody = []byte("request body")
				body := &mockBody{}
				body.SetData(requestBody)
				request.Body = body
				response = &http.Response{
					StatusCode: 200,
					Header:     http.Header{"Content-Length": []string{"1000"}},
				}
				// fake a handshake
				client.dialOnce.Do(func() {})
				session.streamsToOpen = []quic.Stream{dataStream}
			})

			It("sends a request", func() {
				var doRsp *http.Response
				returned := make(chan struct{})
				go func() {
					defer GinkgoRecover()
					var err error
					doRsp, err = client.RoundTrip(request)
					Expect(err).ToNot(HaveOccurred())
					close(returned)
				}()
				Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
				client.responses[5] <- response
				Eventually(returned).Should(BeClosed())
				Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
				Expect(dataStream.closed).To(BeTrue())
				Expect(request.Body.(*mockBody).closed).To(BeTrue())
				Expect(doRsp).To(Equal(response))
			})

			It("returns the error that occurred when reading the body", func() {
				testErr := errors.New("testErr")
				request.Body.(*mockBody).readErr = testErr

				var doRsp *http.Response
				var doErr error
				returned := make(chan struct{})
				go func() {
					doRsp, doErr = client.RoundTrip(request)
					close(returned)
				}()
				Eventually(returned).Should(BeClosed())
				Expect(doErr).To(MatchError(testErr))
				Expect(doRsp).To(BeNil())
				Expect(request.Body.(*mockBody).closed).To(BeTrue())
			})

			It("returns the error that occurred when closing the body", func() {
				testErr := errors.New("testErr")
				request.Body.(*mockBody).closeErr = testErr

				var doRsp *http.Response
				var doErr error
				returned := make(chan struct{})
				go func() {
					doRsp, doErr = client.RoundTrip(request)
					close(returned)
				}()
				Eventually(returned).Should(BeClosed())
				Expect(doErr).To(MatchError(testErr))
				Expect(doRsp).To(BeNil())
				Expect(request.Body.(*mockBody).closed).To(BeTrue())
			})
		})

		Context("gzip compression", func() {
			var gzippedData []byte // a gzipped foobar
			var response *http.Response

			BeforeEach(func() {
				var b bytes.Buffer
				w := gzip.NewWriter(&b)
				w.Write([]byte("foobar"))
				w.Close()
				gzippedData = b.Bytes()
				response = &http.Response{
					StatusCode: 200,
					Header:     http.Header{"Content-Length": []string{"1000"}},
				}
			})

			It("adds the gzip header to requests", func(done Done) {
				var doRsp *http.Response
				var doErr error
				returned := make(chan struct{})
				go func() {
					doRsp, doErr = client.RoundTrip(request)
					close(returned)
				}()
				Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
				dataStream.dataToRead.Write(gzippedData)
				response.Header.Add("Content-Encoding", "gzip")
				client.responses[5] <- response
				Eventually(returned).Should(BeClosed())
				Expect(doErr).ToNot(HaveOccurred())
				Expect(doRsp).ToNot(BeNil())
				headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
				Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
				Expect(doRsp.ContentLength).To(BeEquivalentTo(-1))
				Expect(doRsp.Header.Get("Content-Encoding")).To(BeEmpty())
				Expect(doRsp.Header.Get("Content-Length")).To(BeEmpty())
				close(dataStream.unblockRead)
				data := make([]byte, 6)
				_, err := io.ReadFull(doRsp.Body, data)
				Expect(err).ToNot(HaveOccurred())
				Expect(data).To(Equal([]byte("foobar")))
				close(done)
			}, 2)

			It("doesn't add gzip if the header disable it", func() {
				client.opts.DisableCompression = true
				// only the written HEADERS frame is inspected; the result is not needed
				go func() { client.RoundTrip(request) }()
				Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
				Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
				headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
				Expect(headers).ToNot(HaveKey("accept-encoding"))
			})

			It("only decompresses the response if the response contains the right content-encoding header", func() {
				var doRsp *http.Response
				var doErr error
				returned := make(chan struct{})
				go func() {
					doRsp, doErr = client.RoundTrip(request)
					close(returned)
				}()
				Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
				dataStream.dataToRead.Write([]byte("not gzipped"))
				client.responses[5] <- response
				Eventually(returned).Should(BeClosed())
				Expect(doErr).ToNot(HaveOccurred())
				Expect(doRsp).ToNot(BeNil())
				headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
				Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
				data := make([]byte, 11)
				_, err := io.ReadFull(doRsp.Body, data)
				Expect(err).ToNot(HaveOccurred())
				Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1))
				Expect(data).To(Equal([]byte("not gzipped")))
			})

			It("doesn't add the gzip header for requests that have the accept-enconding set", func() {
				request.Header.Add("accept-encoding", "gzip")
				var doRsp *http.Response
				var doErr error
				returned := make(chan struct{})
				go func() {
					doRsp, doErr = client.RoundTrip(request)
					close(returned)
				}()
				Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
				dataStream.dataToRead.Write([]byte("gzipped data"))
				client.responses[5] <- response
				Eventually(returned).Should(BeClosed())
				Expect(doErr).ToNot(HaveOccurred())
				Expect(doRsp).ToNot(BeNil())
				headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
				Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
				data := make([]byte, 12)
				_, err := io.ReadFull(doRsp.Body, data)
				Expect(err).ToNot(HaveOccurred())
				Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1))
				Expect(data).To(Equal([]byte("gzipped data")))
			})
		})

		Context("handling the header stream", func() {
			var h2framer *http2.Framer

			BeforeEach(func() {
				// frames written here become the data the client reads from its header stream
				h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
				client.responses[23] = make(chan *http.Response)
			})

			It("reads header values from a response", func() {
				// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
				data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d}
				// 9-byte HTTP/2 frame header: 24-bit length, type 0x1 (HEADERS),
				// flags 0x5 (END_STREAM|END_HEADERS), stream ID 23
				headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23})
				headerStream.dataToRead.Write(data)
				go client.handleHeaderStream()
				var rsp *http.Response
				Eventually(client.responses[23]).Should(Receive(&rsp))
				Expect(rsp).ToNot(BeNil())
				Expect(rsp.Proto).To(Equal("HTTP/2.0"))
				Expect(rsp.ProtoMajor).To(BeEquivalentTo(2))
				Expect(rsp.StatusCode).To(BeEquivalentTo(302))
				Expect(rsp.Status).To(Equal("302 Found"))
				Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"}))
				Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"}))
			})

			It("errors if the H2 frame is not a HeadersFrame", func() {
				err := h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
				Expect(err).ToNot(HaveOccurred())
				client.handleHeaderStream()
				Eventually(client.headerErrored).Should(BeClosed())
				Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
			})

			It("errors if it can't read the HPACK encoded header fields", func() {
				err := h2framer.WriteHeaders(http2.HeadersFrameParam{
					StreamID:      23,
					EndHeaders:    true,
					BlockFragment: []byte("invalid HPACK data"),
				})
				Expect(err).ToNot(HaveOccurred())
				client.handleHeaderStream()
				Eventually(client.headerErrored).Should(BeClosed())
				Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
				Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
			})

			It("errors if the stream cannot be found", func() {
				var headers bytes.Buffer
				enc := hpack.NewEncoder(&headers)
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				err := h2framer.WriteHeaders(http2.HeadersFrameParam{
					StreamID:      1337,
					EndHeaders:    true,
					BlockFragment: headers.Bytes(),
				})
				Expect(err).ToNot(HaveOccurred())
				client.handleHeaderStream()
				Eventually(client.headerErrored).Should(BeClosed())
				Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
				Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
			})
		})
	})
})