202 lines
5.5 KiB
Go
202 lines
5.5 KiB
Go
|
package h2quic
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
|
||
|
"golang.org/x/net/http2"
|
||
|
"golang.org/x/net/http2/hpack"
|
||
|
"golang.org/x/net/lex/httplex"
|
||
|
|
||
|
quic "github.com/lucas-clemente/quic-go"
|
||
|
"github.com/lucas-clemente/quic-go/protocol"
|
||
|
"github.com/lucas-clemente/quic-go/utils"
|
||
|
)
|
||
|
|
||
|
type requestWriter struct {
|
||
|
mutex sync.Mutex
|
||
|
headerStream quic.Stream
|
||
|
|
||
|
henc *hpack.Encoder
|
||
|
hbuf bytes.Buffer // HPACK encoder writes into this
|
||
|
}
|
||
|
|
||
|
// defaultUserAgent is sent as the "user-agent" header when the request does
// not mention a User-Agent header at all.
const defaultUserAgent = "quic-go"
|
||
|
|
||
|
func newRequestWriter(headerStream quic.Stream) *requestWriter {
|
||
|
rw := &requestWriter{
|
||
|
headerStream: headerStream,
|
||
|
}
|
||
|
rw.henc = hpack.NewEncoder(&rw.hbuf)
|
||
|
return rw
|
||
|
}
|
||
|
|
||
|
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
|
||
|
// TODO: add support for trailers
|
||
|
// TODO: add support for gzip compression
|
||
|
// TODO: write continuation frames, if the header frame is too long
|
||
|
|
||
|
w.mutex.Lock()
|
||
|
defer w.mutex.Unlock()
|
||
|
|
||
|
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
|
||
|
h2framer := http2.NewFramer(w.headerStream, nil)
|
||
|
return h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||
|
StreamID: uint32(dataStreamID),
|
||
|
EndHeaders: true,
|
||
|
EndStream: endStream,
|
||
|
BlockFragment: w.hbuf.Bytes(),
|
||
|
Priority: http2.PriorityParam{Weight: 0xff},
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// the rest of this file is copied from http2.Transport
|
||
|
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
|
||
|
w.hbuf.Reset()
|
||
|
|
||
|
host := req.Host
|
||
|
if host == "" {
|
||
|
host = req.URL.Host
|
||
|
}
|
||
|
host, err := httplex.PunycodeHostPort(host)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
var path string
|
||
|
if req.Method != "CONNECT" {
|
||
|
path = req.URL.RequestURI()
|
||
|
if !validPseudoPath(path) {
|
||
|
orig := path
|
||
|
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
||
|
if !validPseudoPath(path) {
|
||
|
if req.URL.Opaque != "" {
|
||
|
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
||
|
} else {
|
||
|
return nil, fmt.Errorf("invalid request :path %q", orig)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Check for any invalid headers and return an error before we
|
||
|
// potentially pollute our hpack state. (We want to be able to
|
||
|
// continue to reuse the hpack encoder for future requests)
|
||
|
for k, vv := range req.Header {
|
||
|
if !httplex.ValidHeaderFieldName(k) {
|
||
|
return nil, fmt.Errorf("invalid HTTP header name %q", k)
|
||
|
}
|
||
|
for _, v := range vv {
|
||
|
if !httplex.ValidHeaderFieldValue(v) {
|
||
|
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// 8.1.2.3 Request Pseudo-Header Fields
|
||
|
// The :path pseudo-header field includes the path and query parts of the
|
||
|
// target URI (the path-absolute production and optionally a '?' character
|
||
|
// followed by the query production (see Sections 3.3 and 3.4 of
|
||
|
// [RFC3986]).
|
||
|
w.writeHeader(":authority", host)
|
||
|
w.writeHeader(":method", req.Method)
|
||
|
if req.Method != "CONNECT" {
|
||
|
w.writeHeader(":path", path)
|
||
|
w.writeHeader(":scheme", req.URL.Scheme)
|
||
|
}
|
||
|
if trailers != "" {
|
||
|
w.writeHeader("trailer", trailers)
|
||
|
}
|
||
|
|
||
|
var didUA bool
|
||
|
for k, vv := range req.Header {
|
||
|
lowKey := strings.ToLower(k)
|
||
|
switch lowKey {
|
||
|
case "host", "content-length":
|
||
|
// Host is :authority, already sent.
|
||
|
// Content-Length is automatic, set below.
|
||
|
continue
|
||
|
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
|
||
|
// Per 8.1.2.2 Connection-Specific Header
|
||
|
// Fields, don't send connection-specific
|
||
|
// fields. We have already checked if any
|
||
|
// are error-worthy so just ignore the rest.
|
||
|
continue
|
||
|
case "user-agent":
|
||
|
// Match Go's http1 behavior: at most one
|
||
|
// User-Agent. If set to nil or empty string,
|
||
|
// then omit it. Otherwise if not mentioned,
|
||
|
// include the default (below).
|
||
|
didUA = true
|
||
|
if len(vv) < 1 {
|
||
|
continue
|
||
|
}
|
||
|
vv = vv[:1]
|
||
|
if vv[0] == "" {
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
for _, v := range vv {
|
||
|
w.writeHeader(lowKey, v)
|
||
|
}
|
||
|
}
|
||
|
if shouldSendReqContentLength(req.Method, contentLength) {
|
||
|
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
|
||
|
}
|
||
|
if addGzipHeader {
|
||
|
w.writeHeader("accept-encoding", "gzip")
|
||
|
}
|
||
|
if !didUA {
|
||
|
w.writeHeader("user-agent", defaultUserAgent)
|
||
|
}
|
||
|
return w.hbuf.Bytes(), nil
|
||
|
}
|
||
|
|
||
|
func (w *requestWriter) writeHeader(name, value string) {
|
||
|
utils.Debugf("http2: Transport encoding header %q = %q", name, value)
|
||
|
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
||
|
}
|
||
|
|
||
|
// shouldSendReqContentLength reports whether a "content-length" request
// header should be sent. The logic mirrors net/http's
// transferWriter.shouldSendContentLength.
//
// contentLength is the corrected value: 0 really means an empty body, and
// -1 (any negative value) means the length is unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
	switch {
	case contentLength > 0:
		return true
	case contentLength < 0:
		return false
	}
	// Empty body: only methods that normally carry a body announce a zero
	// length. With END_STREAM it hardly matters for HTTP/2 either way.
	return method == "POST" || method == "PUT" || method == "PATCH"
}
|
||
|
|
||
|
// validPseudoPath reports whether v may be used as an HTTP/2 :path value:
// either the literal "*", or a string starting with a single '/' (a leading
// "//" is rejected).
func validPseudoPath(v string) bool {
	switch {
	case v == "*":
		return true
	case len(v) == 0 || v[0] != '/':
		return false
	case len(v) > 1 && v[1] == '/':
		return false
	default:
		return true
	}
}
|
||
|
|
||
|
// actualContentLength normalizes req.ContentLength so that an empty body can
// be told apart from an unknown length. A nil Body yields 0. A non-nil Body
// whose ContentLength is 0 yields -1 (unknown). Any other ContentLength is
// returned unchanged.
func actualContentLength(req *http.Request) int64 {
	switch {
	case req.Body == nil:
		return 0
	case req.ContentLength == 0:
		return -1
	default:
		return req.ContentLength
	}
}
|