From b9a78fa3bfb03eb6d8ad6a0bd22dbf22a0b459ce Mon Sep 17 00:00:00 2001 From: Christine Dodrill Date: Sun, 21 Jan 2018 09:07:21 -0800 Subject: [PATCH] cmd/routed: fix header getting --- cmd/routed/common.go | 3 ++- cmd/routed/server.go | 2 +- internal/middleware/headers.go | 35 ++++++++++++++++++++++++++++++++++ internal/middleware/twirp.go | 28 --------------------------- 4 files changed, 38 insertions(+), 30 deletions(-) create mode 100644 internal/middleware/headers.go delete mode 100644 internal/middleware/twirp.go diff --git a/cmd/routed/common.go b/cmd/routed/common.go index 63dd424..30379bc 100644 --- a/cmd/routed/common.go +++ b/cmd/routed/common.go @@ -7,6 +7,7 @@ import ( "time" "git.xeserv.us/xena/route/internal/database" + "git.xeserv.us/xena/route/internal/middleware" "github.com/Xe/ln" "github.com/twitchtv/twirp" "golang.org/x/net/trace" @@ -44,7 +45,7 @@ func (s *Server) makeTwirpHooks() *twirp.ServerHooks { "twirp_service": svc, }) - hdr, ok := twirp.HTTPRequestHeaders(ctx) + hdr, ok := middleware.GetHeaders(ctx) if !ok { return ctx, errors.New("can't get request headers") } diff --git a/cmd/routed/server.go b/cmd/routed/server.go index ff116f9..d440dd5 100644 --- a/cmd/routed/server.go +++ b/cmd/routed/server.go @@ -160,7 +160,7 @@ func New(cfg Config) (*Server, error) { hs := &http.Server{ TLSConfig: tc, Addr: cfg.GRPCAddr, - Handler: middleware.Twirp(middleware.Trace("twirp-https")(mux)), + Handler: middleware.SaveHeaders(middleware.Trace("twirp-https")(mux)), } go hs.ListenAndServeTLS("", "") diff --git a/internal/middleware/headers.go b/internal/middleware/headers.go new file mode 100644 index 0000000..5465835 --- /dev/null +++ b/internal/middleware/headers.go @@ -0,0 +1,35 @@ +package middleware + +import ( + "context" + "net/http" +) + +type headerKey int + +const hdrKey headerKey = iota + +// GetHeaders fetches http headers from the request context. +func GetHeaders(ctx context.Context) (http.Header, bool) { + h, ok := ctx.Value(hdrKey).(http.Header) + if !ok { + return http.Header{}, false + } + + return h, true +} + +// SaveHeaders adds the needed values to the request context for twirp services. +func SaveHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if ctx == nil { + panic("context is nil") + } + + ctx = context.WithValue(ctx, hdrKey, r.Header) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/internal/middleware/twirp.go b/internal/middleware/twirp.go deleted file mode 100644 index d6a3169..0000000 --- a/internal/middleware/twirp.go +++ /dev/null @@ -1,28 +0,0 @@ -package middleware - -import ( - "context" - "net/http" - - "github.com/Xe/ln" - "github.com/twitchtv/twirp" -) - -// Twirp adds the needed values to the request context for twirp services. -func Twirp(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - if ctx == nil { - panic("context is nil") - } - - ctx, err := twirp.WithHTTPRequestHeaders(ctx, r.Header) - if err != nil { - ln.Error(context.Background(), err, ln.Action("can't get request headers")) - http.Error(w, err.Error(), http.StatusInternalServerError) - } - - next.ServeHTTP(w, r.WithContext(ctx)) - }) -}