identicond/vendor/gopkg.in/chi.v3/mux_test.go

1663 lines
45 KiB
Go

package chi
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
)
// TestMuxBasic registers a set of routes and middlewares on a single Mux and
// checks, through a live httptest server, that each request reaches the
// expected handler: static paths, URL parameters, HEAD/POST variants, a
// catch-all registered with HandleFunc, and an unknown HTTP method.
func TestMuxBasic(t *testing.T) {
	// count is bumped by countermw on every request; it is never asserted on.
	var count uint64
	countermw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			count++
			next.ServeHTTP(w, r)
		})
	}

	// usermw stores the user name that cxindex later reads from the context.
	usermw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			ctx = context.WithValue(ctx, ctxKey{"user"}, "peter")
			r = r.WithContext(ctx)
			next.ServeHTTP(w, r)
		})
	}

	exmw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := context.WithValue(r.Context(), ctxKey{"ex"}, "a")
			r = r.WithContext(ctx)
			next.ServeHTTP(w, r)
		})
	}

	// logmw appends logmsg to logbuf once per request.
	logbuf := bytes.NewBufferString("")
	logmsg := "logmw test"
	logmw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			logbuf.WriteString(logmsg)
			next.ServeHTTP(w, r)
		})
	}

	cxindex := func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		user := ctx.Value(ctxKey{"user"}).(string)
		w.WriteHeader(200)
		w.Write([]byte(fmt.Sprintf("hi %s", user)))
	}

	ping := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("."))
	}

	headPing := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Ping", "1")
		w.WriteHeader(200)
	}

	createPing := func(w http.ResponseWriter, r *http.Request) {
		// create ....
		w.WriteHeader(201)
	}

	pingAll := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("ping all"))
	}

	pingAll2 := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("ping all2"))
	}

	pingOne := func(w http.ResponseWriter, r *http.Request) {
		idParam := URLParam(r, "id")
		w.WriteHeader(200)
		w.Write([]byte(fmt.Sprintf("ping one id: %s", idParam)))
	}

	pingWoop := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("woop." + URLParam(r, "iidd")))
	}

	catchAll := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("catchall"))
	}

	m := NewRouter()
	m.Use(countermw)
	m.Use(usermw)
	m.Use(exmw)
	m.Use(logmw)
	m.Get("/", cxindex)
	m.Method("GET", "/ping", http.HandlerFunc(ping))
	m.MethodFunc("GET", "/pingall", pingAll)
	m.MethodFunc("get", "/ping/all", pingAll)
	m.Get("/ping/all2", pingAll2)

	m.Head("/ping", headPing)
	m.Post("/ping", createPing)
	m.Get("/ping/{id}", pingWoop)
	m.Get("/ping/{id}", pingOne) // expected to overwrite to pingOne handler
	m.Get("/ping/{iidd}/woop", pingWoop)
	m.HandleFunc("/admin/*", catchAll)
	// m.Post("/admin/*", catchAll)

	ts := httptest.NewServer(m)
	defer ts.Close()

	// GET /
	if _, body := testRequest(t, ts, "GET", "/", nil); body != "hi peter" {
		t.Fatalf("GET /: expected body %q, got %q", "hi peter", body)
	}
	// The buffer contains no NUL byte, so ReadString returns everything
	// written so far together with io.EOF; the error is expected and ignored.
	tlogmsg, _ := logbuf.ReadString(0)
	if tlogmsg != logmsg {
		t.Error("expecting log message from middleware:", logmsg)
	}

	// Plain GET routes, checked in registration-independent order.
	getCases := []struct {
		path string
		want string
	}{
		{"/ping", "."},
		{"/pingall", "ping all"},
		{"/ping/all", "ping all"},
		{"/ping/all2", "ping all2"},
		{"/ping/123", "ping one id: 123"},
		{"/ping/allan", "ping one id: allan"},
		{"/ping/1/woop", "woop.1"},
	}
	for _, tc := range getCases {
		if _, body := testRequest(t, ts, "GET", tc.path, nil); body != tc.want {
			t.Fatalf("GET %s: expected body %q, got %q", tc.path, tc.want, body)
		}
	}

	// HEAD /ping
	resp, err := http.Head(ts.URL + "/ping")
	if err != nil {
		t.Fatal(err)
	}
	// Only status and headers are inspected; release the connection now.
	resp.Body.Close()
	if resp.StatusCode != 200 {
		t.Error("head failed, should be 200")
	}
	if resp.Header.Get("X-Ping") == "" {
		t.Error("expecting X-Ping header")
	}

	// GET /admin/catch-this
	if _, body := testRequest(t, ts, "GET", "/admin/catch-thazzzzz", nil); body != "catchall" {
		t.Fatalf("GET /admin/catch-thazzzzz: expected body %q, got %q", "catchall", body)
	}

	// POST /admin/catch-this
	resp, err = http.Post(ts.URL+"/admin/casdfsadfs", "text/plain", bytes.NewReader([]byte{}))
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	if resp.StatusCode != 200 {
		t.Error("POST failed, should be 200")
	}

	if string(body) != "catchall" {
		t.Error("expecting response body: 'catchall'")
	}

	// Custom http method DIE /ping/1/woop
	if resp, body := testRequest(t, ts, "DIE", "/ping/1/woop", nil); body != "" || resp.StatusCode != 405 {
		t.Fatalf("expecting 405 status and empty body, got %d '%s'", resp.StatusCode, body)
	}
}
// TestMuxMounts mounts a router that uses URL parameters (including a
// parameterised Route prefix) under "/sharing" and verifies that the
// parameters are still resolved by the handlers of the mounted router.
func TestMuxMounts(t *testing.T) {
	r := NewRouter()

	r.Get("/{hash}", func(w http.ResponseWriter, r *http.Request) {
		v := URLParam(r, "hash")
		w.Write([]byte(fmt.Sprintf("/%s", v)))
	})

	r.Route("/{hash}/share", func(r Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			v := URLParam(r, "hash")
			w.Write([]byte(fmt.Sprintf("/%s/share", v)))
		})
		r.Get("/{network}", func(w http.ResponseWriter, r *http.Request) {
			v := URLParam(r, "hash")
			n := URLParam(r, "network")
			w.Write([]byte(fmt.Sprintf("/%s/share/%s", v, n)))
		})
	})

	m := NewRouter()
	m.Mount("/sharing", r)

	ts := httptest.NewServer(m)
	defer ts.Close()

	// Each handler echoes the path it matched, without the mount prefix.
	cases := []struct {
		path string
		want string
	}{
		{"/sharing/aBc", "/aBc"},
		{"/sharing/aBc/share", "/aBc/share"},
		{"/sharing/aBc/share/twitter", "/aBc/share/twitter"},
	}
	for _, tc := range cases {
		if _, body := testRequest(t, ts, "GET", tc.path, nil); body != tc.want {
			t.Fatalf("GET %s: expected body %q, got %q", tc.path, tc.want, body)
		}
	}
}
// TestMuxPlain checks a router with a single GET route and a custom NotFound
// handler: the known path is served by its handler and an unknown path is
// served by the NotFound handler.
func TestMuxPlain(t *testing.T) {
	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("bye"))
	})
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	cases := []struct {
		path string
		want string
	}{
		{"/hi", "bye"},
		{"/nothing-here", "nothing here"},
	}
	for _, tc := range cases {
		if _, body := testRequest(t, ts, "GET", tc.path, nil); body != tc.want {
			t.Fatalf("GET %s: expected body %q, got %q", tc.path, tc.want, body)
		}
	}
}
// TestMuxEmptyRoutes attaches a router that has no routes of its own under
// "/api*" and expects requests routed into it to panic with chi's
// "no handlers" message, while paths outside of it get the default 404 body.
func TestMuxEmptyRoutes(t *testing.T) {
	mux := NewRouter()

	apiRouter := NewRouter()
	// oops, we forgot to declare any route handlers

	mux.Handle("/api*", apiRouter)

	if _, body := testHandler(t, mux, "GET", "/", nil); body != "404 page not found\n" {
		t.Fatalf("GET /: expected default 404 body, got %q", body)
	}

	const emptyMuxPanic = `chi: attempting to route to a mux with no handlers.`

	// expectPanic issues a GET for path and fails the test unless serving it
	// panics with emptyMuxPanic.
	expectPanic := func(path string) {
		defer func() {
			if r := recover(); r != nil {
				if r != emptyMuxPanic {
					t.Fatalf("expecting empty route panic, got: %v", r)
				}
			}
		}()

		_, body := testHandler(t, mux, "GET", path, nil)
		// Only reached when the handler returned normally.
		t.Fatalf("oops, we are expecting a panic instead of getting resp: %s", body)
	}

	expectPanic("/api")
	expectPanic("/api/abc")
}
// Test a mux that routes a trailing slash, see also middleware/strip_test.go
// for an example of using a middleware to handle trailing slashes.
//
// The same handler is reachable both through a sub-router mounted at
// "/accounts/{accountID}" and through an explicit "/accounts/{accountID}/"
// route; both must see the accountID URL parameter.
func TestMuxTrailingSlash(t *testing.T) {
	r := NewRouter()
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	subRoutes := NewRouter()
	indexHandler := func(w http.ResponseWriter, r *http.Request) {
		accountID := URLParam(r, "accountID")
		w.Write([]byte(accountID))
	}
	subRoutes.Get("/", indexHandler)

	r.Mount("/accounts/{accountID}", subRoutes)
	r.Get("/accounts/{accountID}/", indexHandler)

	ts := httptest.NewServer(r)
	defer ts.Close()

	cases := []struct {
		path string
		want string
	}{
		{"/accounts/admin", "admin"},
		{"/accounts/admin/", "admin"},
		{"/nothing-here", "nothing here"},
	}
	for _, tc := range cases {
		if _, body := testRequest(t, ts, "GET", tc.path, nil); body != tc.want {
			t.Fatalf("GET %s: expected body %q, got %q", tc.path, tc.want, body)
		}
	}
}
// TestMuxNestedNotFound verifies which NotFound handler answers unknown
// paths when routers are nested: a mounted router with its own NotFound
// handler (declared inside a Group with extra middleware) answers for its
// subtree, while a mounted router without one falls back to the root's
// handler. Each NotFound handler echoes the context values set by the
// middlewares that are expected to have run before it.
func TestMuxNestedNotFound(t *testing.T) {
	r := NewRouter()

	r.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw"}, "mw"))
			next.ServeHTTP(w, r)
		})
	})

	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("bye"))
	})

	// Root NotFound handler, registered through With so that it carries an
	// additional inline middleware.
	r.With(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"with"}, "with"))
			next.ServeHTTP(w, r)
		})
	}).NotFound(func(w http.ResponseWriter, r *http.Request) {
		chkMw := r.Context().Value(ctxKey{"mw"}).(string)
		chkWith := r.Context().Value(ctxKey{"with"}).(string)
		w.WriteHeader(404)
		w.Write([]byte(fmt.Sprintf("root 404 %s %s", chkMw, chkWith)))
	})

	sr1 := NewRouter()

	sr1.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("sub"))
	})
	sr1.Group(func(sr1 Router) {
		sr1.Use(func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw2"}, "mw2"))
				next.ServeHTTP(w, r)
			})
		})
		sr1.NotFound(func(w http.ResponseWriter, r *http.Request) {
			chkMw2 := r.Context().Value(ctxKey{"mw2"}).(string)
			w.WriteHeader(404)
			w.Write([]byte(fmt.Sprintf("sub 404 %s", chkMw2)))
		})
	})

	sr2 := NewRouter()
	sr2.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("sub2"))
	})

	r.Mount("/admin1", sr1)
	r.Mount("/admin2", sr2)

	ts := httptest.NewServer(r)
	defer ts.Close()

	cases := []struct {
		path string
		want string
	}{
		{"/hi", "bye"},
		{"/nothing-here", "root 404 mw with"},
		{"/admin1/sub", "sub"},
		{"/admin1/nope", "sub 404 mw2"},
		{"/admin2/sub", "sub2"},
		// Not found pages should bubble up to the root.
		{"/admin2/nope", "root 404 mw with"},
	}
	for _, tc := range cases {
		if _, body := testRequest(t, ts, "GET", tc.path, nil); body != tc.want {
			t.Fatalf("GET %s: expected body %q, got %q", tc.path, tc.want, body)
		}
	}
}
// TestMuxNestedMethodNotAllowed verifies which MethodNotAllowed handler
// answers a request whose path matches but whose method does not: a mounted
// router with its own handler uses it, and a mounted router without one is
// expected to answer with the root router's handler.
func TestMuxNestedMethodNotAllowed(t *testing.T) {
	r := NewRouter()
	r.Get("/root", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("root"))
	})
	r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(405)
		w.Write([]byte("root 405"))
	})

	sr1 := NewRouter()
	sr1.Get("/sub1", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("sub1"))
	})
	sr1.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(405)
		w.Write([]byte("sub1 405"))
	})

	// sr2 deliberately registers no MethodNotAllowed handler.
	sr2 := NewRouter()
	sr2.Get("/sub2", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("sub2"))
	})

	r.Mount("/prefix1", sr1)
	r.Mount("/prefix2", sr2)

	ts := httptest.NewServer(r)
	defer ts.Close()

	cases := []struct {
		method string
		path   string
		want   string
	}{
		{"GET", "/root", "root"},
		{"PUT", "/root", "root 405"},
		{"GET", "/prefix1/sub1", "sub1"},
		{"PUT", "/prefix1/sub1", "sub1 405"},
		{"GET", "/prefix2/sub2", "sub2"},
		{"PUT", "/prefix2/sub2", "root 405"},
	}
	for _, tc := range cases {
		if _, body := testRequest(t, ts, tc.method, tc.path, nil); body != tc.want {
			t.Fatalf("%s %s: expected body %q, got %q", tc.method, tc.path, tc.want, body)
		}
	}
}
// TestMuxComplicatedNotFound registers a custom NotFound handler on the root
// router only, after routes, a Route group and a mounted sub-router have
// been declared, and expects that handler to answer unknown paths at every
// nesting level while the declared routes keep working.
func TestMuxComplicatedNotFound(t *testing.T) {
	// sub router with groups
	sub := NewRouter()
	sub.Route("/resource", func(r Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("private get"))
		})
	})

	// Root router with groups
	r := NewRouter()
	r.Get("/auth", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("auth get"))
	})
	r.Route("/public", func(r Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("public get"))
		})
	})
	r.Mount("/private", sub)
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("custom not-found"))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	cases := []struct {
		path string
		want string
	}{
		// check that we didn't break correct routes
		{"/auth", "auth get"},
		{"/public", "public get"},
		{"/public/", "public get"},
		{"/private/resource", "private get"},
		// check custom not-found on all levels
		{"/nope", "custom not-found"},
		{"/public/nope", "custom not-found"},
		{"/private/nope", "custom not-found"},
		{"/private/resource/nope", "custom not-found"},
		// check custom not-found on trailing slash routes
		{"/auth/", "custom not-found"},
	}
	for _, tc := range cases {
		if _, body := testRequest(t, ts, "GET", tc.path, nil); body != tc.want {
			t.Fatalf("GET %s: expected body %q, got %q", tc.path, tc.want, body)
		}
	}
}
// TestMuxWith chains two inline middlewares onto a single route via With and
// checks that both ran for that route (their context values are visible in
// the handler) and that each was initialised once and invoked once.
func TestMuxWith(t *testing.T) {
	var cmwInit1, cmwHandler1 uint64
	var cmwInit2, cmwHandler2 uint64
	mw1 := func(next http.Handler) http.Handler {
		cmwInit1++
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			cmwHandler1++
			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline1"}, "yes"))
			next.ServeHTTP(w, r)
		})
	}
	mw2 := func(next http.Handler) http.Handler {
		cmwInit2++
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			cmwHandler2++
			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline2"}, "yes"))
			next.ServeHTTP(w, r)
		})
	}

	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("bye"))
	})
	r.With(mw1).With(mw2).Get("/inline", func(w http.ResponseWriter, r *http.Request) {
		v1 := r.Context().Value(ctxKey{"inline1"}).(string)
		v2 := r.Context().Value(ctxKey{"inline2"}).(string)
		w.Write([]byte(fmt.Sprintf("inline %s %s", v1, v2)))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// "/hi" is requested first; the handler counters below are still
	// expected to be 1, i.e. only "/inline" goes through mw1 and mw2.
	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
		t.Fatalf("GET /hi: expected body %q, got %q", "bye", body)
	}
	if _, body := testRequest(t, ts, "GET", "/inline", nil); body != "inline yes yes" {
		t.Fatalf("GET /inline: expected body %q, got %q", "inline yes yes", body)
	}
	if cmwInit1 != 1 {
		t.Fatalf("expecting cmwInit1 to be 1, got %d", cmwInit1)
	}
	if cmwHandler1 != 1 {
		t.Fatalf("expecting cmwHandler1 to be 1, got %d", cmwHandler1)
	}
	if cmwInit2 != 1 {
		t.Fatalf("expecting cmwInit2 to be 1, got %d", cmwInit2)
	}
	if cmwHandler2 != 1 {
		t.Fatalf("expecting cmwHandler2 to be 1, got %d", cmwHandler2)
	}
}
// TestRouterFromMuxWith serves requests directly from the Router returned by
// With (rather than from the router it was derived from) and only requires
// that doing so does not panic.
func TestRouterFromMuxWith(t *testing.T) {
	t.Parallel()

	// passthrough is a middleware that forwards the request untouched.
	passthrough := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			next.ServeHTTP(w, req)
		})
	}
	noop := func(http.ResponseWriter, *http.Request) {}

	inline := NewRouter().With(passthrough)
	inline.Get("/with_middleware", noop)

	srv := httptest.NewServer(inline)
	defer srv.Close()

	// Without the fix this test was committed with, this causes a panic.
	testRequest(t, srv, http.MethodGet, "/with_middleware", nil)
}
// TestMuxMiddlewareStack stacks router-level middlewares (a counter, a
// context-writing counter and a short-circuiting "/ping" responder) plus an
// inline middleware on "/", then checks the init/request counters reported
// by the "/" handler after three requests and the short-circuit response.
func TestMuxMiddlewareStack(t *testing.T) {
	var (
		stdmwInit, stdmwHandler     uint64
		ctxmwInit, ctxmwHandler     uint64
		inCtxmwInit, inCtxmwHandler uint64
		handlerCount                uint64
	)

	stdmw := func(next http.Handler) http.Handler {
		stdmwInit++
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			stdmwHandler++
			next.ServeHTTP(w, req)
		})
	}

	// ctxmw exposes its own invocation count to downstream handlers.
	ctxmw := func(next http.Handler) http.Handler {
		ctxmwInit++
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			ctxmwHandler++
			withCount := context.WithValue(req.Context(), ctxKey{"count.ctxmwHandler"}, ctxmwHandler)
			next.ServeHTTP(w, req.WithContext(withCount))
		})
	}

	inCtxmw := func(next http.Handler) http.Handler {
		inCtxmwInit++
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			inCtxmwHandler++
			next.ServeHTTP(w, req)
		})
	}

	// pingmw answers "/ping" itself and never reaches the router's routes.
	pingmw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			if req.URL.Path != "/ping" {
				next.ServeHTTP(w, req)
				return
			}
			w.Write([]byte("pong"))
		})
	}

	mux := NewRouter()
	mux.Use(stdmw)
	mux.Use(ctxmw)
	mux.Use(pingmw)

	mux.With(inCtxmw).Get("/", func(w http.ResponseWriter, req *http.Request) {
		handlerCount++
		seen := req.Context().Value(ctxKey{"count.ctxmwHandler"}).(uint64)
		fmt.Fprintf(w, "inits:%d reqs:%d ctxValue:%d", ctxmwInit, handlerCount, seen)
	})

	mux.Get("/hi", func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("wooot"))
	})

	srv := httptest.NewServer(mux)
	defer srv.Close()

	// Hit "/" three times; only the last response is inspected.
	var got string
	for i := 0; i < 3; i++ {
		_, got = testRequest(t, srv, "GET", "/", nil)
	}
	if got != "inits:1 reqs:3 ctxValue:3" {
		t.Fatalf("got: '%s'", got)
	}

	if _, got = testRequest(t, srv, "GET", "/ping", nil); got != "pong" {
		t.Fatalf("got: '%s'", got)
	}
}
// TestMuxRouteGroups declares two Groups on one router, each with its own
// counting middleware and a single route, and checks the response of each
// route together with the counters of the middleware attached to its group.
func TestMuxRouteGroups(t *testing.T) {
	// counting builds a middleware that increments *inits whenever it wraps
	// a handler and *hits whenever the wrapped handler serves a request.
	counting := func(inits, hits *uint64) func(http.Handler) http.Handler {
		return func(next http.Handler) http.Handler {
			*inits++
			return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
				*hits++
				next.ServeHTTP(w, req)
			})
		}
	}

	var (
		stdmwInit, stdmwHandler   uint64
		stdmwInit2, stdmwHandler2 uint64
	)
	stdmw := counting(&stdmwInit, &stdmwHandler)
	stdmw2 := counting(&stdmwInit2, &stdmwHandler2)

	mux := NewRouter()
	mux.Group(func(g Router) {
		g.Use(stdmw)
		g.Get("/group", func(w http.ResponseWriter, req *http.Request) {
			w.Write([]byte("root group"))
		})
	})
	mux.Group(func(g Router) {
		g.Use(stdmw2)
		g.Get("/group2", func(w http.ResponseWriter, req *http.Request) {
			w.Write([]byte("root group2"))
		})
	})

	srv := httptest.NewServer(mux)
	defer srv.Close()

	// GET /group
	_, got := testRequest(t, srv, "GET", "/group", nil)
	if got != "root group" {
		t.Fatalf("got: '%s'", got)
	}
	// NOTE(review): a mismatch here is only logged, whereas the equivalent
	// check for the second group below fails the test — confirm whether
	// that asymmetry is intentional.
	if !(stdmwInit == 1 && stdmwHandler == 1) {
		t.Logf("stdmw counters failed, should be 1:1, got %d:%d", stdmwInit, stdmwHandler)
	}

	// GET /group2
	_, got = testRequest(t, srv, "GET", "/group2", nil)
	if got != "root group2" {
		t.Fatalf("got: '%s'", got)
	}
	if !(stdmwInit2 == 1 && stdmwHandler2 == 1) {
		t.Fatalf("stdmw2 counters failed, should be 1:1, got %d:%d", stdmwInit2, stdmwHandler2)
	}
}
// TestMuxBig serves the router built by bigMux and walks through its routes
// in a fixed order, comparing each response body with the expected text.
func TestMuxBig(t *testing.T) {
	srv := httptest.NewServer(bigMux())
	defer srv.Close()

	type probe struct {
		method string
		path   string
		want   string
	}
	// verify issues every probe in order and stops at the first mismatch.
	verify := func(probes []probe) {
		for _, p := range probes {
			if _, got := testRequest(t, srv, p.method, p.path, nil); got != p.want {
				t.Fatalf("got '%s'", got)
			}
		}
	}

	verify([]probe{
		{"GET", "/favicon.ico", "fav"},
		{"GET", "/hubs/4/view", "/hubs/4/view reqid:1 session:anonymous"},
		{"GET", "/hubs/4/view/index.html", "/hubs/4/view/index.html reqid:1 session:anonymous"},
		{"POST", "/hubs/ethereumhub/view/index.html", "/hubs/ethereumhub/view/index.html reqid:1 session:anonymous"},
		{"GET", "/", "/ reqid:1 session:elvis"},
		{"GET", "/suggestions", "/suggestions reqid:1 session:elvis"},
		{"GET", "/woot/444/hiiii", "/woot/444/hiiii"},
	})

	// This check reports both the expected and the actual body on failure.
	const wantHub = "/hubs/123 reqid:1 session:elvis"
	if _, got := testRequest(t, srv, "GET", "/hubs/123", nil); got != wantHub {
		t.Fatalf("expected:%s got:%s", wantHub, got)
	}

	verify([]probe{
		{"GET", "/hubs/123/touch", "/hubs/123/touch reqid:1 session:elvis"},
		{"GET", "/hubs/123/webhooks", "/hubs/123/webhooks reqid:1 session:elvis"},
		{"GET", "/hubs/123/posts", "/hubs/123/posts reqid:1 session:elvis"},
		{"GET", "/folders", "404 page not found\n"},
		{"GET", "/folders/", "/folders/ reqid:1 session:elvis"},
		{"GET", "/folders/public", "/folders/public reqid:1 session:elvis"},
		{"GET", "/folders/nothing", "404 page not found\n"},
	})
}
// bigMux builds the routing tree used by TestMuxBig. Every request gets a
// "requestID" context value of "1"; routes in the first Group additionally
// get a "session.user" of "anonymous" and routes in the second Group a
// "session.user" of "elvis". Most handlers echo the matched path followed by
// "reqid:<requestID> session:<session.user>".
func bigMux() Router {
	// sessionInfo renders the "reqid:… session:…" suffix from the values the
	// middlewares below stored in the request context.
	sessionInfo := func(r *http.Request) string {
		ctx := r.Context()
		return fmt.Sprintf("reqid:%s session:%s",
			ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
	}

	r := NewRouter()
	r.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := context.WithValue(r.Context(), ctxKey{"requestID"}, "1")
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	})
	r.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r)
		})
	})

	// Anonymous-session routes.
	r.Group(func(r Router) {
		r.Use(func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "anonymous")
				next.ServeHTTP(w, r.WithContext(ctx))
			})
		})
		r.Get("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("fav"))
		})
		r.Get("/hubs/{hubID}/view", func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "/hubs/%s/view %s", URLParam(r, "hubID"), sessionInfo(r))
		})
		r.Get("/hubs/{hubID}/view/*", func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "/hubs/%s/view/%s %s", URLParamFromCtx(r.Context(), "hubID"),
				URLParam(r, "*"), sessionInfo(r))
		})
		r.Post("/hubs/{hubSlug}/view/*", func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "/hubs/%s/view/%s %s", URLParamFromCtx(r.Context(), "hubSlug"),
				URLParam(r, "*"), sessionInfo(r))
		})
	})

	// Routes served with the "elvis" session.
	r.Group(func(r Router) {
		r.Use(func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "elvis")
				next.ServeHTTP(w, r.WithContext(ctx))
			})
		})
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "/ %s", sessionInfo(r))
		})
		r.Get("/suggestions", func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "/suggestions %s", sessionInfo(r))
		})

		r.Get("/woot/{wootID}/*", func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "/woot/%s/%s", URLParam(r, "wootID"), URLParam(r, "*"))
		})

		r.Route("/hubs", func(r Router) {
			r.Route("/{hubID}", func(r Router) {
				r.Get("/", func(w http.ResponseWriter, r *http.Request) {
					fmt.Fprintf(w, "/hubs/%s %s", URLParam(r, "hubID"), sessionInfo(r))
				})
				r.Get("/touch", func(w http.ResponseWriter, r *http.Request) {
					fmt.Fprintf(w, "/hubs/%s/touch %s", URLParam(r, "hubID"), sessionInfo(r))
				})

				// The webhooks router is built separately and mounted behind
				// a middleware chain instead of being declared with Route.
				sr3 := NewRouter()
				sr3.Get("/", func(w http.ResponseWriter, r *http.Request) {
					fmt.Fprintf(w, "/hubs/%s/webhooks %s", URLParam(r, "hubID"), sessionInfo(r))
				})
				sr3.Route("/{webhookID}", func(r Router) {
					r.Get("/", func(w http.ResponseWriter, r *http.Request) {
						fmt.Fprintf(w, "/hubs/%s/webhooks/%s %s", URLParam(r, "hubID"),
							URLParam(r, "webhookID"), sessionInfo(r))
					})
				})

				r.Mount("/webhooks", Chain(func(next http.Handler) http.Handler {
					return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
						next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ctxKey{"hook"}, true)))
					})
				}).Handler(sr3))

				r.Route("/posts", func(r Router) {
					r.Get("/", func(w http.ResponseWriter, r *http.Request) {
						fmt.Fprintf(w, "/hubs/%s/posts %s", URLParam(r, "hubID"), sessionInfo(r))
					})
				})
			})
		})

		r.Route("/folders/", func(r Router) {
			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
				fmt.Fprintf(w, "/folders/ %s", sessionInfo(r))
			})
			r.Get("/public", func(w http.ResponseWriter, r *http.Request) {
				fmt.Fprintf(w, "/folders/public %s", sessionInfo(r))
			})
		})
	})

	return r
}
// TestMuxSubroutesBasic declares nested sub-routers with Route ("/articles"
// and, below it, "/{id}") and checks that each level's handlers are reached,
// including with the id URL parameter.
func TestMuxSubroutesBasic(t *testing.T) {
	hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("index"))
	})
	hArticlesList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("articles-list"))
	})
	hSearchArticles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("search-articles"))
	})
	hGetArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(fmt.Sprintf("get-article:%s", URLParam(r, "id"))))
	})
	hSyncArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(fmt.Sprintf("sync-article:%s", URLParam(r, "id"))))
	})

	// Debugging hint: the sub-routers passed to the Route callbacks can be
	// captured with r.(*Mux) and their trees dumped with
	// debugPrintTree(0, 0, <mux>.tree, 0).
	r := NewRouter()
	r.Get("/", hIndex)
	r.Route("/articles", func(r Router) {
		r.Get("/", hArticlesList)
		r.Get("/search", hSearchArticles)
		r.Route("/{id}", func(r Router) {
			r.Get("/", hGetArticle)
			r.Get("/sync", hSyncArticle)
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	cases := []struct {
		path     string
		expected string
	}{
		{"/", "index"},
		{"/articles", "articles-list"},
		{"/articles/search", "search-articles"},
		{"/articles/123", "get-article:123"},
		{"/articles/123/sync", "sync-article:123"},
	}
	for _, tc := range cases {
		if _, body := testRequest(t, ts, "GET", tc.path, nil); body != tc.expected {
			t.Fatalf("GET %s: expected:%s got:%s", tc.path, tc.expected, body)
		}
	}
}
// TestMuxSubroutes combines plain routes, a sub-router mounted on a
// parameterised path, an explicit route that overrides the mount's
// trailing-slash path, and a router mounted at "/" inside a Route. After the
// HTTP checks it serves one request directly with a caller-supplied route
// context and inspects the routing patterns recorded in it.
func TestMuxSubroutes(t *testing.T) {
	hHubView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hub1"))
	})
	hHubView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hub2"))
	})
	hHubView3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hub3"))
	})
	hAccountView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("account1"))
	})
	hAccountView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("account2"))
	})

	r := NewRouter()
	r.Get("/hubs/{hubID}/view", hHubView1)
	r.Get("/hubs/{hubID}/view/*", hHubView2)

	sr := NewRouter()
	sr.Get("/", hHubView3)
	r.Mount("/hubs/{hubID}/users", sr)
	r.Get("/hubs/{hubID}/users/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hub3 override"))
	})

	sr3 := NewRouter()
	sr3.Get("/", hAccountView1)
	sr3.Get("/hi", hAccountView2)

	r.Route("/accounts/{accountID}", func(r Router) {
		// r.Get("/", hAccountView1)
		r.Mount("/", sr3)
	})

	// This is the same as the r.Route() call above:
	// sr2 := NewRouter()
	// sr2.Mount("/", sr3)
	// r.Mount("/accounts/{accountID}", sr2)

	ts := httptest.NewServer(r)
	defer ts.Close()

	cases := []struct {
		path     string
		expected string
	}{
		{"/hubs/123/view", "hub1"},
		{"/hubs/123/view/index.html", "hub2"},
		{"/hubs/123/users", "hub3"},
		{"/hubs/123/users/", "hub3 override"},
		{"/accounts/44", "account1"},
		{"/accounts/44/hi", "account2"},
	}
	for _, tc := range cases {
		if _, body := testRequest(t, ts, "GET", tc.path, nil); body != tc.expected {
			t.Fatalf("request:%s expected:%s got:%s", "GET "+tc.path, tc.expected, body)
		}
	}

	// Test that we're building the routingPatterns properly
	req, err := http.NewRequest("GET", "/accounts/44/hi", nil)
	if err != nil {
		t.Fatal(err)
	}

	// Supply our own route context so it can be inspected after serving.
	rctx := NewRouteContext()
	req = req.WithContext(context.WithValue(req.Context(), RouteCtxKey, rctx))

	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if body, expected := w.Body.String(), "account2"; body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}

	wantPatterns := []string{"/accounts/{accountID}/*", "/*", "/hi"}
	if len(rctx.RoutePatterns) != len(wantPatterns) {
		t.Fatalf("expected %d routing patterns, got:%d", len(wantPatterns), len(rctx.RoutePatterns))
	}
	for i, expected := range wantPatterns {
		if got := rctx.RoutePatterns[i]; got != expected {
			t.Fatalf("routePattern[%d], expected:%s got:%s", i, expected, got)
		}
	}
}
// TestSingleHandler invokes a handler directly, without any router, after
// placing a hand-built route context carrying a "name" URL parameter in the
// request context, and checks that URLParam returns that parameter.
func TestSingleHandler(t *testing.T) {
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		name := URLParam(r, "name")
		w.Write([]byte("hi " + name))
	})

	r, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatal(err)
	}
	rctx := NewRouteContext()
	r = r.WithContext(context.WithValue(r.Context(), RouteCtxKey, rctx))
	// The parameter is added after the context is attached; the handler
	// still observes it when it runs.
	rctx.URLParams.Add("name", "joe")

	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)

	body := w.Body.String()
	expected := "hi joe"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
}
// TODO: add a test for a Router wrapper type that embeds *Mux, for example:
//
// type ACLMux struct {
// *Mux
// XX string
// }
//
// func NewACLMux() *ACLMux {
// return &ACLMux{Mux: NewRouter(), XX: "hihi"}
// }
//
// // TODO: this should be supported...
// func TestWoot(t *testing.T) {
// var r Router = NewRouter()
//
// var r2 Router = NewACLMux() //NewRouter()
// r2.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
// w.Write([]byte("hi"))
// })
//
// r.Mount("/", r2)
// }
func TestServeHTTPExistingContext(t *testing.T) {
r := NewRouter()
r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
w.Write([]byte(s))
})
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
w.WriteHeader(404)
w.Write([]byte(s))
})
testcases := []struct {
Method string
Path string
Ctx context.Context
ExpectedStatus int
ExpectedBody string
}{
{
Method: "GET",
Path: "/hi",
Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "hi ctx"),
ExpectedStatus: 200,
ExpectedBody: "hi ctx",
},
{
Method: "GET",
Path: "/hello",
Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "nothing here ctx"),
ExpectedStatus: 404,
ExpectedBody: "nothing here ctx",
},
}
for _, tc := range testcases {
resp := httptest.NewRecorder()
req, err := http.NewRequest(tc.Method, tc.Path, nil)
if err != nil {
t.Fatalf("%v", err)
}
req = req.WithContext(tc.Ctx)
r.ServeHTTP(resp, req)
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("%v", err)
}
if resp.Code != tc.ExpectedStatus {
t.Fatalf("%v != %v", tc.ExpectedStatus, resp.Code)
}
if string(b) != tc.ExpectedBody {
t.Fatalf("%s != %s", tc.ExpectedBody, b)
}
}
}
// TestNestedGroups registers routes at several nesting depths of Group, Route
// and With, each behind a different number of counter-incrementing
// middlewares, and checks that every route responds with the number of
// middlewares that ran for it.
func TestNestedGroups(t *testing.T) {
	counterKey := ctxKey{"counter"}

	// printCounter writes the counter found in the request context (0 when
	// it is absent).
	printCounter := func(w http.ResponseWriter, req *http.Request) {
		n, _ := req.Context().Value(counterKey).(int)
		fmt.Fprintf(w, "%v", n)
	}

	// bump is a middleware that adds one to the context counter before
	// handing the request on.
	bump := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			n, _ := req.Context().Value(counterKey).(int)
			ctx := context.WithValue(req.Context(), counterKey, n+1)
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	}

	// Each route path is the counter value it is expected to print.
	mux := NewRouter() // counter == 0
	mux.Get("/0", printCounter)
	mux.Group(func(g Router) {
		g.Use(bump) // counter == 1
		g.Get("/1", printCounter)

		// g.Handle(GET, "/2", Chain(bump).HandlerFunc(printCounter))
		g.With(bump).Get("/2", printCounter)

		g.Group(func(inner Router) {
			inner.Use(bump, bump) // counter == 3
			inner.Get("/3", printCounter)
		})

		g.Route("/", func(sub Router) {
			sub.Use(bump, bump) // counter == 3

			// sub.Handle(GET, "/4", Chain(bump).HandlerFunc(printCounter))
			sub.With(bump).Get("/4", printCounter)

			sub.Group(func(deep Router) {
				deep.Use(bump, bump) // counter == 5
				deep.Get("/5", printCounter)

				// deep.Handle(GET, "/6", Chain(bump).HandlerFunc(printCounter))
				deep.With(bump).Get("/6", printCounter)
			})
		})
	})

	ts := httptest.NewServer(mux)
	defer ts.Close()

	for _, want := range []string{"0", "1", "2", "3", "4", "5", "6"} {
		_, got := testRequest(t, ts, "GET", "/"+want, nil)
		if got != want {
			t.Errorf("expected %v, got %v", want, got)
		}
	}
}
// TestMiddlewarePanicOnLateUse expects a panic when middleware is added with
// Use after a route has already been registered on the router.
func TestMiddlewarePanicOnLateUse(t *testing.T) {
	hello := func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("hello\n"))
	}

	// passthrough is a middleware that only forwards to the next handler.
	passthrough := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			next.ServeHTTP(w, req)
		})
	}

	// The test fails unless something below panics.
	defer func() {
		if p := recover(); p == nil {
			t.Error("expected panic()")
		}
	}()

	mux := NewRouter()
	mux.Get("/", hello)
	mux.Use(passthrough) // Registered after a route: a panic() is expected here.
}
// TestMountingExistingPath expects a panic when the same pattern is mounted
// twice on one router.
func TestMountingExistingPath(t *testing.T) {
	// The test fails unless something below panics.
	defer func() {
		if p := recover(); p == nil {
			t.Error("expected panic()")
		}
	}()

	noop := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})

	mux := NewRouter()
	mux.Get("/", noop)
	mux.Mount("/hi", noop)
	mux.Mount("/hi", noop) // Presumably the call that panics: "/hi" is already mounted.
}
// TestMountingSimilarPattern mounts two sub-routers whose patterns share a
// prefix ("/foobar" and "/foo") and checks that an unrelated route on the
// parent router ("/hi") is still served by its own handler.
func TestMountingSimilarPattern(t *testing.T) {
	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("bye"))
	})

	r2 := NewRouter()
	r2.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("foobar"))
	})

	r3 := NewRouter()
	r3.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("foo"))
	})

	r.Mount("/foobar", r2)
	r.Mount("/foo", r3)

	ts := httptest.NewServer(r)
	defer ts.Close()

	expected := "bye"
	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != expected {
		// The body is passed as an argument, never as the format string, so
		// any '%' it contains cannot garble the failure message.
		t.Fatalf("expected:%s got:%s", expected, body)
	}
}
// TestMuxEmptyParams checks that a route with three consecutive URL params
// matches both when every segment is filled in and when the first two
// segments are empty, in which case those params are empty strings.
func TestMuxEmptyParams(t *testing.T) {
	r := NewRouter()
	r.Get(`/users/{x}/{y}/{z}`, func(w http.ResponseWriter, r *http.Request) {
		x := URLParam(r, "x")
		y := URLParam(r, "y")
		z := URLParam(r, "z")
		w.Write([]byte(fmt.Sprintf("%s-%s-%s", x, y, z)))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// The response body is always passed as an argument rather than as the
	// format string, so a '%' in it cannot garble the failure message.
	expected := "a-b-c"
	if _, body := testRequest(t, ts, "GET", "/users/a/b/c", nil); body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}

	expected = "--c"
	if _, body := testRequest(t, ts, "GET", "/users///c", nil); body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
}
// TestMuxMissingParams checks that a route whose param is constrained by the
// regexp \d+ matches a numeric segment, and that a request with an empty
// segment falls through to the NotFound handler.
func TestMuxMissingParams(t *testing.T) {
	r := NewRouter()
	r.Get(`/user/{userId:\d+}`, func(w http.ResponseWriter, r *http.Request) {
		userID := URLParam(r, "userId")
		w.Write([]byte(fmt.Sprintf("userId = '%s'", userID)))
	})
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// The response body is always passed as an argument rather than as the
	// format string, so a '%' in it cannot garble the failure message.
	expected := "userId = '123'"
	if _, body := testRequest(t, ts, "GET", "/user/123", nil); body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}

	expected = "nothing here"
	if _, body := testRequest(t, ts, "GET", "/user/", nil); body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
}
// TestMuxContextIsThreadSafe drives the router from 100 goroutines, each
// serving 10000 requests whose context is cancelled from yet another
// goroutine while the request is being served. It makes no assertions of its
// own; it is presumably meant to be run under the race detector.
func TestMuxContextIsThreadSafe(t *testing.T) {
	router := NewRouter()
	router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
		// Block until the request context is done: either the caller's
		// cancel fires or the 1ms timeout expires.
		ctx, cancel := context.WithTimeout(r.Context(), 1*time.Millisecond)
		defer cancel()

		<-ctx.Done()
	})

	var wg sync.WaitGroup

	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 10000; j++ {
				w := httptest.NewRecorder()
				r, err := http.NewRequest("GET", "/ok", nil)
				if err != nil {
					// t.Fatal must only be called from the goroutine running
					// the test function, so report the error and stop this
					// worker instead.
					t.Error(err)
					return
				}

				ctx, cancel := context.WithCancel(r.Context())
				r = r.WithContext(ctx)

				// Cancel concurrently with the request being served.
				go cancel()

				router.ServeHTTP(w, r)
			}
		}()
	}
	wg.Wait()
}
// TestEscapedURLParams requests a path whose first segment contains
// percent-escaped slashes and checks that the handler sees that segment
// still escaped, alongside the plain values of the remaining params.
func TestEscapedURLParams(t *testing.T) {
	m := NewRouter()
	m.Get("/api/{identifier}/{region}/{size}/{rotation}/*", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		rctx := RouteContext(r.Context())
		if rctx == nil {
			t.Error("no context")
			return
		}

		// Expected value of each URL param, checked in order. The first
		// mismatch is reported and ends the handler before "success" is
		// written.
		params := []struct {
			name string
			want string
		}{
			{"identifier", "http:%2f%2fexample.com%2fimage.png"},
			{"region", "full"},
			{"size", "max"},
			{"rotation", "0"},
		}
		for _, p := range params {
			if got := URLParam(r, p.name); got != p.want {
				t.Errorf("%s path parameter incorrect %s", p.name, got)
				return
			}
		}
		w.Write([]byte("success"))
	})

	ts := httptest.NewServer(m)
	defer ts.Close()

	expected := "success"
	if _, body := testRequest(t, ts, "GET", "/api/http:%2f%2fexample.com%2fimage.png/full/max/0/color.png", nil); body != expected {
		// The body is passed as an argument, never as the format string.
		t.Fatalf("expected:%s got:%s", expected, body)
	}
}
// TestMuxMatch exercises Match directly on the router: a GET for a path under
// /users must match, while a HEAD for a path under /articles, where only a
// GET route is registered, must not.
func TestMuxMatch(t *testing.T) {
	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Test", "yes")
		w.Write([]byte("bye"))
	})
	r.Route("/articles", func(r Router) {
		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
			id := URLParam(r, "id")
			w.Header().Set("X-Article", id)
			w.Write([]byte("article:" + id))
		})
	})
	r.Route("/users", func(r Router) {
		r.Head("/{id}", func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("X-User", "-")
			w.Write([]byte("user"))
		})
		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
			id := URLParam(r, "id")
			w.Header().Set("X-User", id)
			w.Write([]byte("user:" + id))
		})
	})

	// One route context is reused for both lookups and reset before each.
	tctx := NewRouteContext()

	tctx.Reset()
	if !r.Match(tctx, "GET", "/users/1") {
		t.Fatal("expecting to find match for route:", "GET", "/users/1")
	}

	tctx.Reset()
	if r.Match(tctx, "HEAD", "/articles/10") {
		t.Fatal("not expecting to find match for route:", "HEAD", "/articles/10")
	}
}
// TestServerBaseContext wraps the router with ServerBaseContext and checks
// that the handler sees the value from the supplied base context as well as
// the server and local-address values that net/http puts in the request
// context.
func TestServerBaseContext(t *testing.T) {
	r := NewRouter()
	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		baseYes := r.Context().Value(ctxKey{"base"}).(string)
		if _, ok := r.Context().Value(http.ServerContextKey).(*http.Server); !ok {
			panic("missing server context")
		}
		if _, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); !ok {
			panic("missing local addr context")
		}
		w.Write([]byte(baseYes))
	})

	// Setup http Server with a base context
	ctx := context.WithValue(context.Background(), ctxKey{"base"}, "yes")
	ts := httptest.NewServer(ServerBaseContext(ctx, r))
	defer ts.Close()

	expected := "yes"
	if _, body := testRequest(t, ts, "GET", "/", nil); body != expected {
		// The body is passed as an argument, never as the format string.
		t.Fatalf("expected:%s got:%s", expected, body)
	}
}
// testRequest sends a request with the given method, path and body to the
// test server ts using http.DefaultClient, and returns the response together
// with its fully read body as a string. A failure to build the request,
// perform it, or read the body aborts the calling test through t.Fatal.
//
// The response body is closed before testRequest returns, so callers may
// inspect the returned response's status and headers but not read its Body.
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so the body is released on every
	// path, including when the read below fails and t.Fatal unwinds.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
// testHandler builds a request from method, path and body, serves it with h
// against an in-memory recorder (no network involved), and returns the
// recorded response together with the recorded body as a string. A request
// that cannot be constructed aborts the calling test through t.Fatal.
func testHandler(t *testing.T, h http.Handler, method, path string, body io.Reader) (*http.Response, string) {
	r, err := http.NewRequest(method, path, body)
	if err != nil {
		// Without this check a nil request would be handed to the handler.
		t.Fatal(err)
	}

	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	return w.Result(), w.Body.String()
}
// testFileSystem is a file system stub for tests: every Open call is
// forwarded to the open function it was constructed with.
type testFileSystem struct {
	open func(name string) (http.File, error)
}

// Open hands name to the configured open function and returns its result
// unchanged.
func (tfs *testFileSystem) Open(name string) (http.File, error) {
	file, err := tfs.open(name)
	return file, err
}
// testFile is an in-memory file stub for tests, made of a name and a fixed
// byte slice. It keeps no read offset: every Read starts again at the
// beginning of contents and Seek does nothing.
type testFile struct {
	name     string
	contents []byte
}

// Close does nothing and never fails.
func (tf *testFile) Close() error {
	return nil
}

// Read copies from the start of contents into p and returns the number of
// bytes actually copied, which is less than len(p) when contents is shorter.
// It returns io.EOF when nothing could be copied into a non-empty p, that is
// when contents is empty.
func (tf *testFile) Read(p []byte) (n int, err error) {
	n = copy(p, tf.contents)
	if n == 0 && len(p) > 0 {
		return 0, io.EOF
	}
	return n, nil
}

// Seek ignores its arguments and always reports offset 0.
func (tf *testFile) Seek(offset int64, whence int) (int64, error) {
	return 0, nil
}

// Readdir ignores count and returns a single entry: the file's own Stat
// result.
func (tf *testFile) Readdir(count int) ([]os.FileInfo, error) {
	stat, _ := tf.Stat() // Stat never returns an error.
	return []os.FileInfo{stat}, nil
}

// Stat returns a testFileInfo carrying the file's name and the length of its
// contents.
func (tf *testFile) Stat() (os.FileInfo, error) {
	return &testFileInfo{tf.name, int64(len(tf.contents))}, nil
}
// testFileInfo is a fixed os.FileInfo stub for tests: only the name and size
// vary, every other attribute is hard-coded.
type testFileInfo struct {
	name string
	size int64
}

// Name returns the stored name.
func (fi *testFileInfo) Name() string {
	return fi.name
}

// Size returns the stored size.
func (fi *testFileInfo) Size() int64 {
	return fi.size
}

// Mode always reports permission bits 0755.
func (fi *testFileInfo) Mode() os.FileMode {
	return os.FileMode(0755)
}

// ModTime reports the current time on every call.
func (fi *testFileInfo) ModTime() time.Time {
	return time.Now()
}

// IsDir always reports false.
func (fi *testFileInfo) IsDir() bool {
	return false
}

// Sys always returns nil.
func (fi *testFileInfo) Sys() interface{} {
	return nil
}
// ctxKey is the key type the tests in this file use for values stored in a
// context.Context; distinct names give distinct keys.
type ctxKey struct {
	name string
}

// String renders the key as "context value " followed by its name.
func (key ctxKey) String() string {
	const prefix = "context value "
	return prefix + key.name
}
// BenchmarkMux measures ServeHTTP on a router with top-level routes, a
// parameterised route and two levels of sub-routers. Each benchmarked path
// runs as its own sub-benchmark, reusing one recorder and one request for
// all iterations.
func BenchmarkMux(b *testing.B) {
	// All handlers do nothing, so the timings cover the router itself.
	rootH := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	hiH := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	supH := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	shareRootH := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	networkH := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	directH := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})

	mux := NewRouter()
	mux.Get("/", rootH)
	mux.Get("/hi", hiH)
	mux.Get("/sup/{id}/and/{this}", supH)

	mux.Route("/sharing/{hash}", func(sharing Router) {
		sharing.Get("/", shareRootH)        // subrouter-1
		sharing.Get("/{network}", networkH) // subrouter-1
		sharing.Get("/twitter", networkH)
		sharing.Route("/direct", func(direct Router) {
			direct.Get("/", directH) // subrouter-2
		})
	})

	benchPaths := []string{
		"/",
		"/sup/123/and/this",
		"/sharing/aBc",         // subrouter-1
		"/sharing/aBc/twitter", // subrouter-1
		"/sharing/aBc/direct",  // subrouter-2
	}

	for _, target := range benchPaths {
		b.Run("route:"+target, func(sb *testing.B) {
			rec := httptest.NewRecorder()
			req, _ := http.NewRequest("GET", target, nil)

			sb.ReportAllocs()
			sb.ResetTimer()

			for i := 0; i < sb.N; i++ {
				mux.ServeHTTP(rec, req)
			}
		})
	}
}