125 lines
2.8 KiB
Go
125 lines
2.8 KiB
Go
|
package middleware
|
||
|
|
||
|
import (
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/go-chi/chi"
|
||
|
)
|
||
|
|
||
|
func TestContentCharset(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
var tests = []struct {
|
||
|
name string
|
||
|
inputValue string
|
||
|
inputContentCharset []string
|
||
|
want int
|
||
|
}{
|
||
|
{
|
||
|
"should accept requests with a matching charset",
|
||
|
"application/json; charset=UTF-8",
|
||
|
[]string{"UTF-8"},
|
||
|
http.StatusOK,
|
||
|
},
|
||
|
{
|
||
|
"should be case-insensitive",
|
||
|
"application/json; charset=utf-8",
|
||
|
[]string{"UTF-8"},
|
||
|
http.StatusOK,
|
||
|
},
|
||
|
{
|
||
|
"should accept requests with a matching charset with extra values",
|
||
|
"application/json; foo=bar; charset=UTF-8; spam=eggs",
|
||
|
[]string{"UTF-8"},
|
||
|
http.StatusOK,
|
||
|
},
|
||
|
{
|
||
|
"should accept requests with a matching charset when multiple charsets are supported",
|
||
|
"text/xml; charset=UTF-8",
|
||
|
[]string{"UTF-8", "Latin-1"},
|
||
|
http.StatusOK,
|
||
|
},
|
||
|
{
|
||
|
"should accept requests with no charset if empty charset headers are allowed",
|
||
|
"text/xml",
|
||
|
[]string{"UTF-8", ""},
|
||
|
http.StatusOK,
|
||
|
},
|
||
|
{
|
||
|
"should not accept requests with no charset if empty charset headers are not allowed",
|
||
|
"text/xml",
|
||
|
[]string{"UTF-8"},
|
||
|
http.StatusUnsupportedMediaType,
|
||
|
},
|
||
|
{
|
||
|
"should not accept requests with a mismatching charset",
|
||
|
"text/plain; charset=Latin-1",
|
||
|
[]string{"UTF-8"},
|
||
|
http.StatusUnsupportedMediaType,
|
||
|
},
|
||
|
{
|
||
|
"should not accept requests with a mismatching charset even if empty charsets are allowed",
|
||
|
"text/plain; charset=Latin-1",
|
||
|
[]string{"UTF-8", ""},
|
||
|
http.StatusUnsupportedMediaType,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
var tt = tt
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
var recorder = httptest.NewRecorder()
|
||
|
|
||
|
var r = chi.NewRouter()
|
||
|
r.Use(ContentCharset(tt.inputContentCharset...))
|
||
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {})
|
||
|
|
||
|
var req, _ = http.NewRequest("GET", "/", nil)
|
||
|
req.Header.Set("Content-Type", tt.inputValue)
|
||
|
|
||
|
r.ServeHTTP(recorder, req)
|
||
|
var res = recorder.Result()
|
||
|
|
||
|
if res.StatusCode != tt.want {
|
||
|
t.Errorf("response is incorrect, got %d, want %d", recorder.Code, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestSplit(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
var s1, s2 = split(" type1;type2 ", ";")
|
||
|
|
||
|
if s1 != "type1" || s2 != "type2" {
|
||
|
t.Errorf("Want type1, type2 got %s, %s", s1, s2)
|
||
|
}
|
||
|
|
||
|
s1, s2 = split("type1 ", ";")
|
||
|
|
||
|
if s1 != "type1" {
|
||
|
t.Errorf("Want \"type1\" got \"%s\"", s1)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestContentEncoding(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
if !contentEncoding("application/json; foo=bar; charset=utf-8; spam=eggs", []string{"utf-8"}...) {
|
||
|
t.Error("Want true, got false")
|
||
|
}
|
||
|
|
||
|
if contentEncoding("text/plain; charset=latin-1", []string{"utf-8"}...) {
|
||
|
t.Error("Want false, got true")
|
||
|
}
|
||
|
|
||
|
if !contentEncoding("text/xml; charset=UTF-8", []string{"latin-1", "utf-8"}...) {
|
||
|
t.Error("Want true, got false")
|
||
|
}
|
||
|
}
|