Add dependencies
This commit is contained in:
parent
bfae5cb265
commit
df7052dd40
|
@ -21,7 +21,5 @@ _testmain.go
|
|||
|
||||
*.exe
|
||||
|
||||
/vendor/src
|
||||
|
||||
# Data, etc
|
||||
/data
|
||||
|
|
|
@ -1,11 +1,97 @@
|
|||
{
|
||||
"version": 0,
|
||||
"dependencies": [
|
||||
{
|
||||
"importpath": "github.com/Xe/middleware",
|
||||
"repository": "https://github.com/Xe/middleware",
|
||||
"revision": "7d23200fbed9e7f3be4ac76b4f7f6bd19cc4aba0",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/codegangsta/negroni",
|
||||
"repository": "https://github.com/codegangsta/negroni",
|
||||
"revision": "c7477ad8e330bef55bf1ebe300cf8aa67c492d1b",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/disintegration/imaging",
|
||||
"repository": "https://github.com/disintegration/imaging",
|
||||
"revision": "3ab6ec550f20d497d2755ed3c48a3e45ad6b7eb9",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/drone/routes",
|
||||
"repository": "https://github.com/drone/routes",
|
||||
"revision": "853bef2b231162bb7b09355720416d3af1510d88",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/garyburd/redigo",
|
||||
"repository": "https://github.com/garyburd/redigo",
|
||||
"revision": "a47585eaae68b1d14b02940d2af1b9194f3caa9c",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/gorilla/context",
|
||||
"repository": "https://github.com/gorilla/context",
|
||||
"revision": "215affda49addc4c8ef7e2534915df2c8c35c6cd",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/gorilla/mux",
|
||||
"repository": "https://github.com/gorilla/mux",
|
||||
"revision": "f15e0c49460fd49eebe2bcc8486b05d1bef68d3a",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/jinzhu/gorm",
|
||||
"repository": "https://github.com/jinzhu/gorm",
|
||||
"revision": "6a7dda9a32e187c044178aadb0a4510f053a73fa",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/lib/pq",
|
||||
"repository": "https://github.com/lib/pq",
|
||||
"revision": "0dad96c0b94f8dee039aa40467f767467392a0af",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/sebest/xff",
|
||||
"repository": "https://github.com/sebest/xff",
|
||||
"revision": "d90d345f39f4e84675192d6662c42f33a46ec830",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/thoj/go-ircevent",
|
||||
"repository": "https://github.com/thoj/go-ircevent",
|
||||
"revision": "c47f9d8e3db1e137c31efbd755bd563d1bf29efc",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/unrolled/render",
|
||||
"repository": "https://github.com/unrolled/render",
|
||||
"revision": "aa61028b1d32873eaa3e261a3ef0e892a153107b",
|
||||
"branch": "v1"
|
||||
},
|
||||
{
|
||||
"importpath": "github.com/yosssi/ace-proxy",
|
||||
"repository": "https://github.com/yosssi/ace-proxy",
|
||||
"revision": "ecd9b785e6023e00b1a451cdf584a30d4eff14c0",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"importpath": "golang.org/x/image/bmp",
|
||||
"repository": "https://go.googlesource.com/image",
|
||||
"revision": "5ec5e003b21ac1f06e175898413ada23a6797fc0",
|
||||
"branch": "master",
|
||||
"path": "/bmp"
|
||||
},
|
||||
{
|
||||
"importpath": "golang.org/x/image/tiff",
|
||||
"repository": "https://go.googlesource.com/image",
|
||||
"revision": "5ec5e003b21ac1f06e175898413ada23a6797fc0",
|
||||
"branch": "master",
|
||||
"path": "/tiff"
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
# middleware
|
||||
All of the useful middlewares I use
|
|
@ -0,0 +1,13 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/Xe/middleware/xff"
|
||||
"github.com/Xe/middleware/xrequestid"
|
||||
"github.com/codegangsta/negroni"
|
||||
)
|
||||
|
||||
// Inject adds x-request-id and x-forwarded-for support to an existing negroni instance.
|
||||
func Inject(n *negroni.Negroni) {
|
||||
n.Use(negroni.HandlerFunc(xff.XFF))
|
||||
n.Use(xrequestid.New(26))
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
# X-Forwarded-For middleware fo Go [](https://godoc.org/github.com/sebest/xff)
|
||||
|
||||
Package `xff` is a `net/http` middleware/handler to parse [Forwarded HTTP Extension](http://tools.ietf.org/html/rfc7239) in Golang.
|
||||
|
||||
## Example usage
|
||||
|
||||
Install `xff`:
|
||||
|
||||
go get github.com/sebest/xff
|
||||
|
||||
Edit `server.go`:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/sebest/xff"
|
||||
)
|
||||
|
||||
func main() {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hello from " + r.RemoteAddr + "\n"))
|
||||
})
|
||||
|
||||
http.ListenAndServe(":8080", xff.Handler(handler))
|
||||
}
|
||||
```
|
||||
|
||||
Then run your server:
|
||||
|
||||
go run server.go
|
||||
|
||||
The server now runs on `localhost:8080`:
|
||||
|
||||
$ curl -D - -H 'X-Forwarded-For: 42.42.42.42' http://localhost:8080/
|
||||
HTTP/1.1 200 OK
|
||||
Date: Fri, 20 Feb 2015 20:03:02 GMT
|
||||
Content-Length: 29
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
|
||||
hello from 42.42.42.42:52661
|
|
@ -0,0 +1,23 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/codegangsta/negroni"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sebest/xff"
|
||||
)
|
||||
|
||||
func main() {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hello from " + r.RemoteAddr + "\n"))
|
||||
})
|
||||
|
||||
mux := mux.NewRouter()
|
||||
mux.Handle("/", handler)
|
||||
|
||||
n := negroni.Classic()
|
||||
n.Use(negroni.HandlerFunc(xff.XFF))
|
||||
n.UseHandler(mux)
|
||||
n.Run(":3000")
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/sebest/xff"
|
||||
)
|
||||
|
||||
func main() {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hello from " + r.RemoteAddr + "\n"))
|
||||
})
|
||||
|
||||
http.ListenAndServe(":3000", xff.Handler(handler))
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package xff
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var privateMasks = func() []net.IPNet {
|
||||
masks := []net.IPNet{}
|
||||
for _, cidr := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "fc00::/7"} {
|
||||
_, net, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
masks = append(masks, *net)
|
||||
}
|
||||
return masks
|
||||
}()
|
||||
|
||||
// IsPublicIP returns true if the given IP can be routed on the Internet
|
||||
func IsPublicIP(ip net.IP) bool {
|
||||
if !ip.IsGlobalUnicast() {
|
||||
return false
|
||||
}
|
||||
for _, mask := range privateMasks {
|
||||
if mask.Contains(ip) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Parse parses the value of the X-Forwarded-For Header and returns the IP address.
|
||||
func Parse(ipList string) string {
|
||||
for _, ip := range strings.Split(ipList, ",") {
|
||||
ip = strings.TrimSpace(ip)
|
||||
if IP := net.ParseIP(ip); IP != nil && IsPublicIP(IP) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetRemoteAddr parses the given request, resolves the X-Forwarded-For header
|
||||
// and returns the resolved remote address.
|
||||
func GetRemoteAddr(r *http.Request) string {
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
var ip string
|
||||
if xff != "" {
|
||||
ip = Parse(xff)
|
||||
}
|
||||
_, oport, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil && ip != "" {
|
||||
return net.JoinHostPort(ip, oport)
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// Handler is a middleware to update RemoteAdd from X-Fowarded-* Headers.
|
||||
func Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.RemoteAddr = GetRemoteAddr(r)
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// HandlerFunc is a Martini compatible handler
|
||||
func HandlerFunc(w http.ResponseWriter, r *http.Request) {
|
||||
r.RemoteAddr = GetRemoteAddr(r)
|
||||
}
|
||||
|
||||
// XFF is a Negroni compatible interface
|
||||
func XFF(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
r.RemoteAddr = GetRemoteAddr(r)
|
||||
next(w, r)
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
package xff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParse_none(t *testing.T) {
|
||||
res := Parse("")
|
||||
assert.Equal(t, "", res)
|
||||
}
|
||||
|
||||
func TestParse_localhost(t *testing.T) {
|
||||
res := Parse("127.0.0.1")
|
||||
assert.Equal(t, "", res)
|
||||
}
|
||||
|
||||
func TestParse_invalid(t *testing.T) {
|
||||
res := Parse("invalid")
|
||||
assert.Equal(t, "", res)
|
||||
}
|
||||
|
||||
func TestParse_invalid_sioux(t *testing.T) {
|
||||
res := Parse("123#1#2#3")
|
||||
assert.Equal(t, "", res)
|
||||
}
|
||||
|
||||
func TestParse_invalid_private_lookalike(t *testing.T) {
|
||||
res := Parse("102.3.2.1")
|
||||
assert.Equal(t, "102.3.2.1", res)
|
||||
}
|
||||
|
||||
func TestParse_valid(t *testing.T) {
|
||||
res := Parse("68.45.152.220")
|
||||
assert.Equal(t, "68.45.152.220", res)
|
||||
}
|
||||
|
||||
func TestParse_multi_first(t *testing.T) {
|
||||
res := Parse("12.13.14.15, 68.45.152.220")
|
||||
assert.Equal(t, "12.13.14.15", res)
|
||||
}
|
||||
|
||||
func TestParse_multi_last(t *testing.T) {
|
||||
res := Parse("192.168.110.162, 190.57.149.90")
|
||||
assert.Equal(t, "190.57.149.90", res)
|
||||
}
|
||||
|
||||
func TestParse_multi_with_invalid(t *testing.T) {
|
||||
res := Parse("192.168.110.162, invalid, 190.57.149.90")
|
||||
assert.Equal(t, "190.57.149.90", res)
|
||||
}
|
||||
|
||||
func TestParse_multi_with_invalid2(t *testing.T) {
|
||||
res := Parse("192.168.110.162, 190.57.149.90, invalid")
|
||||
assert.Equal(t, "190.57.149.90", res)
|
||||
}
|
||||
|
||||
func TestParse_multi_with_invalid_sioux(t *testing.T) {
|
||||
res := Parse("192.168.110.162, 190.57.149.90, 123#1#2#3")
|
||||
assert.Equal(t, "190.57.149.90", res)
|
||||
}
|
||||
|
||||
func TestParse_ipv6_with_port(t *testing.T) {
|
||||
res := Parse("2604:2000:71a9:bf00:f178:a500:9a2d:670d")
|
||||
assert.Equal(t, "2604:2000:71a9:bf00:f178:a500:9a2d:670d", res)
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Andrea Franz (http://gravityblast.com)
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
GO_CMD=go
|
||||
GOLINT_CMD=golint
|
||||
GO_TEST=$(GO_CMD) test -v ./...
|
||||
GO_VET=$(GO_CMD) vet ./...
|
||||
GO_LINT=$(GOLINT_CMD) ./...
|
||||
|
||||
all:
|
||||
$(GO_VET)
|
||||
$(GO_LINT)
|
||||
$(GO_TEST)
|
|
@ -0,0 +1,5 @@
|
|||
# xrequestid
|
||||
|
||||
> Package xrequestid implements an http middleware for Negroni that assigns a random id to each request. It's written in the Go programming language.
|
||||
|
||||
Docs at http://godoc.org/github.com/pilu/xrequestid
|
|
@ -0,0 +1,76 @@
|
|||
// Package xrequestid implements an http middleware for Negroni that assigns a random id to each request
|
||||
//
|
||||
// Example:
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "fmt"
|
||||
// "net/http"
|
||||
//
|
||||
// "github.com/codegangsta/negroni"
|
||||
// "github.com/pilu/xrequestid"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// mux := http.NewServeMux()
|
||||
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// fmt.Fprintf(w, "X-Request-Id is `%s`", r.Header.Get("X-Request-Id"))
|
||||
// })
|
||||
//
|
||||
// n := negroni.New()
|
||||
// n.Use(xrequestid.New(16))
|
||||
// n.UseHandler(mux)
|
||||
// n.Run(":3000")
|
||||
// }
|
||||
package xrequestid
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// By default the middleware set the generated random string to this key in the request header
|
||||
const DefaultHeaderKey = "X-Request-Id"
|
||||
|
||||
// GenerateFunc is the func used by the middleware to generates the random string.
|
||||
type GenerateFunc func(int) (string, error)
|
||||
|
||||
// XRequestID is a middleware that adds a random ID to the request X-Request-Id header
|
||||
type XRequestID struct {
|
||||
// Size specifies the length of the random length. The length of the result string is twice of n.
|
||||
Size int
|
||||
// Generate is a GenerateFunc that generates the random string. The default one uses crypto/rand
|
||||
Generate GenerateFunc
|
||||
// HeaderKey is the header name where the middleware set the random string. By default it uses the DefaultHeaderKey constant value
|
||||
HeaderKey string
|
||||
}
|
||||
|
||||
// New returns a new XRequestID middleware instance. n specifies the length of the random length. The length of the result string is twice of n.
|
||||
func New(n int) *XRequestID {
|
||||
return &XRequestID{
|
||||
Size: n,
|
||||
Generate: generateID,
|
||||
HeaderKey: DefaultHeaderKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *XRequestID) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
id, err := m.Generate(m.Size)
|
||||
if err == nil {
|
||||
r.Header.Set(m.HeaderKey, id)
|
||||
rw.Header().Set(m.HeaderKey, id)
|
||||
}
|
||||
|
||||
next(rw, r)
|
||||
}
|
||||
|
||||
func generateID(n int) (string, error) {
|
||||
r := make([]byte, n)
|
||||
_, err := rand.Read(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return hex.EncodeToString(r), nil
|
||||
}
|
20
vendor/src/github.com/Xe/middleware/xrequestid/xrequestid_middleware_test.go
vendored
Normal file
20
vendor/src/github.com/Xe/middleware/xrequestid/xrequestid_middleware_test.go
vendored
Normal file
|
@ -0,0 +1,20 @@
|
|||
package xrequestid
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestXRequestID(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
middleware := New(16)
|
||||
middleware.Generate = func(n int) (string, error) { return "test-id", nil }
|
||||
middleware.ServeHTTP(recorder, req, func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
if id := req.Header.Get("X-Request-ID"); id != "test-id" {
|
||||
t.Fatalf("Expected X-Request-Id to be `test-id`, got `%v`", id)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Jeremy Saenz
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,181 @@
|
|||
# Negroni [](http://godoc.org/github.com/codegangsta/negroni) [](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
|
||||
|
||||
Negroni is an idiomatic approach to web middleware in Go. It is tiny, non-intrusive, and encourages use of `net/http` Handlers.
|
||||
|
||||
If you like the idea of [Martini](http://github.com/go-martini/martini), but you think it contains too much magic, then Negroni is a great fit.
|
||||
|
||||
|
||||
Language Translations:
|
||||
* [Português Brasileiro (pt_BR)](translations/README_pt_br.md)
|
||||
|
||||
## Getting Started
|
||||
|
||||
After installing Go and setting up your [GOPATH](http://golang.org/doc/code.html#GOPATH), create your first `.go` file. We'll call it `server.go`.
|
||||
|
||||
~~~ go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/codegangsta/negroni"
|
||||
"net/http"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
fmt.Fprintf(w, "Welcome to the home page!")
|
||||
})
|
||||
|
||||
n := negroni.Classic()
|
||||
n.UseHandler(mux)
|
||||
n.Run(":3000")
|
||||
}
|
||||
~~~
|
||||
|
||||
Then install the Negroni package (**go 1.1** and greater is required):
|
||||
~~~
|
||||
go get github.com/codegangsta/negroni
|
||||
~~~
|
||||
|
||||
Then run your server:
|
||||
~~~
|
||||
go run server.go
|
||||
~~~
|
||||
|
||||
You will now have a Go net/http webserver running on `localhost:3000`.
|
||||
|
||||
## Need Help?
|
||||
If you have a question or feature request, [go ask the mailing list](https://groups.google.com/forum/#!forum/negroni-users). The GitHub issues for Negroni will be used exclusively for bug reports and pull requests.
|
||||
|
||||
## Is Negroni a Framework?
|
||||
Negroni is **not** a framework. It is a library that is designed to work directly with net/http.
|
||||
|
||||
## Routing?
|
||||
Negroni is BYOR (Bring your own Router). The Go community already has a number of great http routers available, Negroni tries to play well with all of them by fully supporting `net/http`. For instance, integrating with [Gorilla Mux](http://github.com/gorilla/mux) looks like so:
|
||||
|
||||
~~~ go
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/", HomeHandler)
|
||||
|
||||
n := negroni.New(Middleware1, Middleware2)
|
||||
// Or use a middleware with the Use() function
|
||||
n.Use(Middleware3)
|
||||
// router goes last
|
||||
n.UseHandler(router)
|
||||
|
||||
n.Run(":3000")
|
||||
~~~
|
||||
|
||||
## `negroni.Classic()`
|
||||
`negroni.Classic()` provides some default middleware that is useful for most applications:
|
||||
|
||||
* `negroni.Recovery` - Panic Recovery Middleware.
|
||||
* `negroni.Logging` - Request/Response Logging Middleware.
|
||||
* `negroni.Static` - Static File serving under the "public" directory.
|
||||
|
||||
This makes it really easy to get started with some useful features from Negroni.
|
||||
|
||||
## Handlers
|
||||
Negroni provides a bidirectional middleware flow. This is done through the `negroni.Handler` interface:
|
||||
|
||||
~~~ go
|
||||
type Handler interface {
|
||||
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
|
||||
}
|
||||
~~~
|
||||
|
||||
If a middleware hasn't already written to the ResponseWriter, it should call the next `http.HandlerFunc` in the chain to yield to the next middleware handler. This can be used for great good:
|
||||
|
||||
~~~ go
|
||||
func MyMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
// do some stuff before
|
||||
next(rw, r)
|
||||
// do some stuff after
|
||||
}
|
||||
~~~
|
||||
|
||||
And you can map it to the handler chain with the `Use` function:
|
||||
|
||||
~~~ go
|
||||
n := negroni.New()
|
||||
n.Use(negroni.HandlerFunc(MyMiddleware))
|
||||
~~~
|
||||
|
||||
You can also map plain old `http.Handler`s:
|
||||
|
||||
~~~ go
|
||||
n := negroni.New()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
// map your routes
|
||||
|
||||
n.UseHandler(mux)
|
||||
|
||||
n.Run(":3000")
|
||||
~~~
|
||||
|
||||
## `Run()`
|
||||
Negroni has a convenience function called `Run`. `Run` takes an addr string identical to [http.ListenAndServe](http://golang.org/pkg/net/http#ListenAndServe).
|
||||
|
||||
~~~ go
|
||||
n := negroni.Classic()
|
||||
// ...
|
||||
log.Fatal(http.ListenAndServe(":8080", n))
|
||||
~~~
|
||||
|
||||
## Route Specific Middleware
|
||||
If you have a route group of routes that need specific middleware to be executed, you can simply create a new Negroni instance and use it as your route handler.
|
||||
|
||||
~~~ go
|
||||
router := mux.NewRouter()
|
||||
adminRoutes := mux.NewRouter()
|
||||
// add admin routes here
|
||||
|
||||
// Create a new negroni for the admin middleware
|
||||
router.Handle("/admin", negroni.New(
|
||||
Middleware1,
|
||||
Middleware2,
|
||||
negroni.Wrap(adminRoutes),
|
||||
))
|
||||
~~~
|
||||
|
||||
## Third Party Middleware
|
||||
|
||||
Here is a current list of Negroni compatible middlware. Feel free to put up a PR linking your middleware if you have built one:
|
||||
|
||||
|
||||
| Middleware | Author | Description |
|
||||
| -----------|--------|-------------|
|
||||
| [RestGate](https://github.com/pjebs/restgate) | [Prasanga Siripala](https://github.com/pjebs) | Secure authentication for REST API endpoints |
|
||||
| [Graceful](https://github.com/stretchr/graceful) | [Tyler Bunnell](https://github.com/tylerb) | Graceful HTTP Shutdown |
|
||||
| [secure](https://github.com/unrolled/secure) | [Cory Jacobsen](https://github.com/unrolled) | Middleware that implements a few quick security wins |
|
||||
| [JWT Middleware](https://github.com/auth0/go-jwt-middleware) | [Auth0](https://github.com/auth0) | Middleware checks for a JWT on the `Authorization` header on incoming requests and decodes it|
|
||||
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | Data binding from HTTP requests into structs |
|
||||
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | Logrus-based logger |
|
||||
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | Render JSON, XML and HTML templates |
|
||||
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | New Relic agent for Go runtime |
|
||||
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | GZIP response compression |
|
||||
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | oAuth2 middleware |
|
||||
| [sessions](https://github.com/goincremental/negroni-sessions) | [David Bochenski](https://github.com/bochenski) | Session Management |
|
||||
| [permissions2](https://github.com/xyproto/permissions2) | [Alexander Rødseth](https://github.com/xyproto) | Cookies, users and permissions |
|
||||
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | Generate TinySVG, HTML and CSS on the fly |
|
||||
| [cors](https://github.com/rs/cors) | [Olivier Poitrey](https://github.com/rs) | [Cross Origin Resource Sharing](http://www.w3.org/TR/cors/) (CORS) support |
|
||||
| [xrequestid](https://github.com/pilu/xrequestid) | [Andrea Franz](https://github.com/pilu) | Middleware that assigns a random X-Request-Id header to each request |
|
||||
| [VanGoH](https://github.com/auroratechnologies/vangoh) | [Taylor Wrobel](https://github.com/twrobel3) | Configurable [AWS-Style](http://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) HMAC authentication middleware |
|
||||
| [stats](https://github.com/thoas/stats) | [Florent Messa](https://github.com/thoas) | Store information about your web application (response time, etc.) |
|
||||
|
||||
## Examples
|
||||
[Alexander Rødseth](https://github.com/xyproto) created [mooseware](https://github.com/xyproto/mooseware), a skeleton for writing a Negroni middleware handler.
|
||||
|
||||
## Live code reload?
|
||||
[gin](https://github.com/codegangsta/gin) and [fresh](https://github.com/pilu/fresh) both live reload negroni apps.
|
||||
|
||||
## Essential Reading for Beginners of Go & Negroni
|
||||
|
||||
* [Using a Context to pass information from middleware to end handler](http://elithrar.github.io/article/map-string-interface/)
|
||||
* [Understanding middleware](http://mattstauffer.co/blog/laravel-5.0-middleware-replacing-filters)
|
||||
|
||||
## About
|
||||
|
||||
Negroni is obsessively designed by none other than the [Code Gangsta](http://codegangsta.io/)
|
|
@ -0,0 +1,25 @@
|
|||
// Package negroni is an idiomatic approach to web middleware in Go. It is tiny, non-intrusive, and encourages use of net/http Handlers.
|
||||
//
|
||||
// If you like the idea of Martini, but you think it contains too much magic, then Negroni is a great fit.
|
||||
//
|
||||
// For a full guide visit http://github.com/codegangsta/negroni
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "github.com/codegangsta/negroni"
|
||||
// "net/http"
|
||||
// "fmt"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// mux := http.NewServeMux()
|
||||
// mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
// fmt.Fprintf(w, "Welcome to the home page!")
|
||||
// })
|
||||
//
|
||||
// n := negroni.Classic()
|
||||
// n.UseHandler(mux)
|
||||
// n.Run(":3000")
|
||||
// }
|
||||
package negroni
|
|
@ -0,0 +1,29 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logger is a middleware handler that logs the request as it goes in and the response as it goes out.
|
||||
type Logger struct {
|
||||
// Logger inherits from log.Logger used to log messages with the Logger middleware
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
// NewLogger returns a new Logger instance
|
||||
func NewLogger() *Logger {
|
||||
return &Logger{log.New(os.Stdout, "[negroni] ", 0)}
|
||||
}
|
||||
|
||||
func (l *Logger) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
start := time.Now()
|
||||
l.Printf("Started %s %s", r.Method, r.URL.Path)
|
||||
|
||||
next(rw, r)
|
||||
|
||||
res := rw.(ResponseWriter)
|
||||
l.Printf("Completed %v %s in %v", res.Status(), http.StatusText(res.Status()), time.Since(start))
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_Logger(t *testing.T) {
|
||||
buff := bytes.NewBufferString("")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
l := NewLogger()
|
||||
l.Logger = log.New(buff, "[negroni] ", 0)
|
||||
|
||||
n := New()
|
||||
// replace log for testing
|
||||
n.Use(l)
|
||||
n.UseHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
req, err := http.NewRequest("GET", "http://localhost:3000/foobar", nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
n.ServeHTTP(recorder, req)
|
||||
expect(t, recorder.Code, http.StatusNotFound)
|
||||
refute(t, len(buff.String()), 0)
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Handler handler is an interface that objects can implement to be registered to serve as middleware
|
||||
// in the Negroni middleware stack.
|
||||
// ServeHTTP should yield to the next middleware in the chain by invoking the next http.HandlerFunc
|
||||
// passed in.
|
||||
//
|
||||
// If the Handler writes to the ResponseWriter, the next http.HandlerFunc should not be invoked.
|
||||
type Handler interface {
|
||||
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
|
||||
}
|
||||
|
||||
// HandlerFunc is an adapter to allow the use of ordinary functions as Negroni handlers.
|
||||
// If f is a function with the appropriate signature, HandlerFunc(f) is a Handler object that calls f.
|
||||
type HandlerFunc func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
|
||||
|
||||
func (h HandlerFunc) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
h(rw, r, next)
|
||||
}
|
||||
|
||||
type middleware struct {
|
||||
handler Handler
|
||||
next *middleware
|
||||
}
|
||||
|
||||
func (m middleware) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
m.handler.ServeHTTP(rw, r, m.next.ServeHTTP)
|
||||
}
|
||||
|
||||
// Wrap converts a http.Handler into a negroni.Handler so it can be used as a Negroni
|
||||
// middleware. The next http.HandlerFunc is automatically called after the Handler
|
||||
// is executed.
|
||||
func Wrap(handler http.Handler) Handler {
|
||||
return HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
handler.ServeHTTP(rw, r)
|
||||
next(rw, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Negroni is a stack of Middleware Handlers that can be invoked as an http.Handler.
|
||||
// Negroni middleware is evaluated in the order that they are added to the stack using
|
||||
// the Use and UseHandler methods.
|
||||
type Negroni struct {
|
||||
middleware middleware
|
||||
handlers []Handler
|
||||
}
|
||||
|
||||
// New returns a new Negroni instance with no middleware preconfigured.
|
||||
func New(handlers ...Handler) *Negroni {
|
||||
return &Negroni{
|
||||
handlers: handlers,
|
||||
middleware: build(handlers),
|
||||
}
|
||||
}
|
||||
|
||||
// Classic returns a new Negroni instance with the default middleware already
|
||||
// in the stack.
|
||||
//
|
||||
// Recovery - Panic Recovery Middleware
|
||||
// Logger - Request/Response Logging
|
||||
// Static - Static File Serving
|
||||
func Classic() *Negroni {
|
||||
return New(NewRecovery(), NewLogger(), NewStatic(http.Dir("public")))
|
||||
}
|
||||
|
||||
func (n *Negroni) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
n.middleware.ServeHTTP(NewResponseWriter(rw), r)
|
||||
}
|
||||
|
||||
// Use adds a Handler onto the middleware stack. Handlers are invoked in the order they are added to a Negroni.
|
||||
func (n *Negroni) Use(handler Handler) {
|
||||
n.handlers = append(n.handlers, handler)
|
||||
n.middleware = build(n.handlers)
|
||||
}
|
||||
|
||||
// UseFunc adds a Negroni-style handler function onto the middleware stack.
|
||||
func (n *Negroni) UseFunc(handlerFunc func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)) {
|
||||
n.Use(HandlerFunc(handlerFunc))
|
||||
}
|
||||
|
||||
// UseHandler adds a http.Handler onto the middleware stack. Handlers are invoked in the order they are added to a Negroni.
|
||||
func (n *Negroni) UseHandler(handler http.Handler) {
|
||||
n.Use(Wrap(handler))
|
||||
}
|
||||
|
||||
// UseHandler adds a http.HandlerFunc-style handler function onto the middleware stack.
|
||||
func (n *Negroni) UseHandlerFunc(handlerFunc func(rw http.ResponseWriter, r *http.Request)) {
|
||||
n.UseHandler(http.HandlerFunc(handlerFunc))
|
||||
}
|
||||
|
||||
// Run is a convenience function that runs the negroni stack as an HTTP
|
||||
// server. The addr string takes the same format as http.ListenAndServe.
|
||||
func (n *Negroni) Run(addr string) {
|
||||
l := log.New(os.Stdout, "[negroni] ", 0)
|
||||
l.Printf("listening on %s", addr)
|
||||
l.Fatal(http.ListenAndServe(addr, n))
|
||||
}
|
||||
|
||||
// Returns a list of all the handlers in the current Negroni middleware chain.
|
||||
func (n *Negroni) Handlers() []Handler {
|
||||
return n.handlers
|
||||
}
|
||||
|
||||
func build(handlers []Handler) middleware {
|
||||
var next middleware
|
||||
|
||||
if len(handlers) == 0 {
|
||||
return voidMiddleware()
|
||||
} else if len(handlers) > 1 {
|
||||
next = build(handlers[1:])
|
||||
} else {
|
||||
next = voidMiddleware()
|
||||
}
|
||||
|
||||
return middleware{handlers[0], &next}
|
||||
}
|
||||
|
||||
func voidMiddleware() middleware {
|
||||
return middleware{
|
||||
HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {}),
|
||||
&middleware{},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
/* Test Helpers */
|
||||
func expect(t *testing.T, a interface{}, b interface{}) {
|
||||
if a != b {
|
||||
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
|
||||
}
|
||||
}
|
||||
|
||||
func refute(t *testing.T, a interface{}, b interface{}) {
|
||||
if a == b {
|
||||
t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNegroniRun(t *testing.T) {
|
||||
// just test that Run doesn't bomb
|
||||
go New().Run(":3000")
|
||||
}
|
||||
|
||||
func TestNegroniServeHTTP(t *testing.T) {
|
||||
result := ""
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
n := New()
|
||||
n.Use(HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
result += "foo"
|
||||
next(rw, r)
|
||||
result += "ban"
|
||||
}))
|
||||
n.Use(HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
result += "bar"
|
||||
next(rw, r)
|
||||
result += "baz"
|
||||
}))
|
||||
n.Use(HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
result += "bat"
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
n.ServeHTTP(response, (*http.Request)(nil))
|
||||
|
||||
expect(t, result, "foobarbatbazban")
|
||||
expect(t, response.Code, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Ensures that a Negroni middleware chain
|
||||
// can correctly return all of its handlers.
|
||||
func TestHandlers(t *testing.T) {
|
||||
response := httptest.NewRecorder()
|
||||
n := New()
|
||||
handlers := n.Handlers()
|
||||
expect(t, 0, len(handlers))
|
||||
|
||||
n.Use(HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Expects the length of handlers to be exactly 1
|
||||
// after adding exactly one handler to the middleware chain
|
||||
handlers = n.Handlers()
|
||||
expect(t, 1, len(handlers))
|
||||
|
||||
// Ensures that the first handler that is in sequence behaves
|
||||
// exactly the same as the one that was registered earlier
|
||||
handlers[0].ServeHTTP(response, (*http.Request)(nil), nil)
|
||||
expect(t, response.Code, http.StatusOK)
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// Recovery is a Negroni middleware that recovers from any panics and writes a 500 if there was one.
|
||||
type Recovery struct {
|
||||
Logger *log.Logger
|
||||
PrintStack bool
|
||||
StackAll bool
|
||||
StackSize int
|
||||
}
|
||||
|
||||
// NewRecovery returns a new instance of Recovery
|
||||
func NewRecovery() *Recovery {
|
||||
return &Recovery{
|
||||
Logger: log.New(os.Stdout, "[negroni] ", 0),
|
||||
PrintStack: true,
|
||||
StackAll: false,
|
||||
StackSize: 1024 * 8,
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *Recovery) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
stack := make([]byte, rec.StackSize)
|
||||
stack = stack[:runtime.Stack(stack, rec.StackAll)]
|
||||
|
||||
f := "PANIC: %s\n%s"
|
||||
rec.Logger.Printf(f, err, stack)
|
||||
|
||||
if rec.PrintStack {
|
||||
fmt.Fprintf(rw, f, err, stack)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
next(rw, r)
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
buff := bytes.NewBufferString("")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
rec := NewRecovery()
|
||||
rec.Logger = log.New(buff, "[negroni] ", 0)
|
||||
|
||||
n := New()
|
||||
// replace log for testing
|
||||
n.Use(rec)
|
||||
n.UseHandler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||
panic("here is a panic!")
|
||||
}))
|
||||
n.ServeHTTP(recorder, (*http.Request)(nil))
|
||||
expect(t, recorder.Code, http.StatusInternalServerError)
|
||||
refute(t, recorder.Body.Len(), 0)
|
||||
refute(t, len(buff.String()), 0)
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ResponseWriter is a wrapper around http.ResponseWriter that provides extra information about
|
||||
// the response. It is recommended that middleware handlers use this construct to wrap a responsewriter
|
||||
// if the functionality calls for it.
|
||||
type ResponseWriter interface {
|
||||
http.ResponseWriter
|
||||
http.Flusher
|
||||
// Status returns the status code of the response or 0 if the response has not been written.
|
||||
Status() int
|
||||
// Written returns whether or not the ResponseWriter has been written.
|
||||
Written() bool
|
||||
// Size returns the size of the response body.
|
||||
Size() int
|
||||
// Before allows for a function to be called before the ResponseWriter has been written to. This is
|
||||
// useful for setting headers or any other operations that must happen before a response has been written.
|
||||
Before(func(ResponseWriter))
|
||||
}
|
||||
|
||||
type beforeFunc func(ResponseWriter)
|
||||
|
||||
// NewResponseWriter creates a ResponseWriter that wraps an http.ResponseWriter
|
||||
func NewResponseWriter(rw http.ResponseWriter) ResponseWriter {
|
||||
return &responseWriter{rw, 0, 0, nil}
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
beforeFuncs []beforeFunc
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(s int) {
|
||||
rw.status = s
|
||||
rw.callBefore()
|
||||
rw.ResponseWriter.WriteHeader(s)
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
if !rw.Written() {
|
||||
// The status will be StatusOK if WriteHeader has not been called yet
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
size, err := rw.ResponseWriter.Write(b)
|
||||
rw.size += size
|
||||
return size, err
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Status() int {
|
||||
return rw.status
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Size() int {
|
||||
return rw.size
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Written() bool {
|
||||
return rw.status != 0
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Before(before func(ResponseWriter)) {
|
||||
rw.beforeFuncs = append(rw.beforeFuncs, before)
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface")
|
||||
}
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
||||
func (rw *responseWriter) CloseNotify() <-chan bool {
|
||||
return rw.ResponseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
||||
func (rw *responseWriter) callBefore() {
|
||||
for i := len(rw.beforeFuncs) - 1; i >= 0; i-- {
|
||||
rw.beforeFuncs[i](rw)
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Flush() {
|
||||
flusher, ok := rw.ResponseWriter.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type closeNotifyingRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
closed chan bool
|
||||
}
|
||||
|
||||
func newCloseNotifyingRecorder() *closeNotifyingRecorder {
|
||||
return &closeNotifyingRecorder{
|
||||
httptest.NewRecorder(),
|
||||
make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *closeNotifyingRecorder) close() {
|
||||
c.closed <- true
|
||||
}
|
||||
|
||||
func (c *closeNotifyingRecorder) CloseNotify() <-chan bool {
|
||||
return c.closed
|
||||
}
|
||||
|
||||
type hijackableResponse struct {
|
||||
Hijacked bool
|
||||
}
|
||||
|
||||
func newHijackableResponse() *hijackableResponse {
|
||||
return &hijackableResponse{}
|
||||
}
|
||||
|
||||
func (h *hijackableResponse) Header() http.Header { return nil }
|
||||
func (h *hijackableResponse) Write(buf []byte) (int, error) { return 0, nil }
|
||||
func (h *hijackableResponse) WriteHeader(code int) {}
|
||||
func (h *hijackableResponse) Flush() {}
|
||||
func (h *hijackableResponse) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
h.Hijacked = true
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func TestResponseWriterWritingString(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := NewResponseWriter(rec)
|
||||
|
||||
rw.Write([]byte("Hello world"))
|
||||
|
||||
expect(t, rec.Code, rw.Status())
|
||||
expect(t, rec.Body.String(), "Hello world")
|
||||
expect(t, rw.Status(), http.StatusOK)
|
||||
expect(t, rw.Size(), 11)
|
||||
expect(t, rw.Written(), true)
|
||||
}
|
||||
|
||||
func TestResponseWriterWritingStrings(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := NewResponseWriter(rec)
|
||||
|
||||
rw.Write([]byte("Hello world"))
|
||||
rw.Write([]byte("foo bar bat baz"))
|
||||
|
||||
expect(t, rec.Code, rw.Status())
|
||||
expect(t, rec.Body.String(), "Hello worldfoo bar bat baz")
|
||||
expect(t, rw.Status(), http.StatusOK)
|
||||
expect(t, rw.Size(), 26)
|
||||
}
|
||||
|
||||
func TestResponseWriterWritingHeader(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := NewResponseWriter(rec)
|
||||
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
|
||||
expect(t, rec.Code, rw.Status())
|
||||
expect(t, rec.Body.String(), "")
|
||||
expect(t, rw.Status(), http.StatusNotFound)
|
||||
expect(t, rw.Size(), 0)
|
||||
}
|
||||
|
||||
func TestResponseWriterBefore(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := NewResponseWriter(rec)
|
||||
result := ""
|
||||
|
||||
rw.Before(func(ResponseWriter) {
|
||||
result += "foo"
|
||||
})
|
||||
rw.Before(func(ResponseWriter) {
|
||||
result += "bar"
|
||||
})
|
||||
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
|
||||
expect(t, rec.Code, rw.Status())
|
||||
expect(t, rec.Body.String(), "")
|
||||
expect(t, rw.Status(), http.StatusNotFound)
|
||||
expect(t, rw.Size(), 0)
|
||||
expect(t, result, "barfoo")
|
||||
}
|
||||
|
||||
func TestResponseWriterHijack(t *testing.T) {
|
||||
hijackable := newHijackableResponse()
|
||||
rw := NewResponseWriter(hijackable)
|
||||
hijacker, ok := rw.(http.Hijacker)
|
||||
expect(t, ok, true)
|
||||
_, _, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
expect(t, hijackable.Hijacked, true)
|
||||
}
|
||||
|
||||
func TestResponseWriteHijackNotOK(t *testing.T) {
|
||||
hijackable := new(http.ResponseWriter)
|
||||
rw := NewResponseWriter(*hijackable)
|
||||
hijacker, ok := rw.(http.Hijacker)
|
||||
expect(t, ok, true)
|
||||
_, _, err := hijacker.Hijack()
|
||||
|
||||
refute(t, err, nil)
|
||||
}
|
||||
|
||||
func TestResponseWriterCloseNotify(t *testing.T) {
|
||||
rec := newCloseNotifyingRecorder()
|
||||
rw := NewResponseWriter(rec)
|
||||
closed := false
|
||||
notifier := rw.(http.CloseNotifier).CloseNotify()
|
||||
rec.close()
|
||||
select {
|
||||
case <-notifier:
|
||||
closed = true
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
expect(t, closed, true)
|
||||
}
|
||||
|
||||
func TestResponseWriterFlusher(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := NewResponseWriter(rec)
|
||||
|
||||
_, ok := rw.(http.Flusher)
|
||||
expect(t, ok, true)
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Static is a middleware handler that serves static files in the given directory/filesystem.
|
||||
type Static struct {
|
||||
// Dir is the directory to serve static files from
|
||||
Dir http.FileSystem
|
||||
// Prefix is the optional prefix used to serve the static directory content
|
||||
Prefix string
|
||||
// IndexFile defines which file to serve as index if it exists.
|
||||
IndexFile string
|
||||
}
|
||||
|
||||
// NewStatic returns a new instance of Static
|
||||
func NewStatic(directory http.FileSystem) *Static {
|
||||
return &Static{
|
||||
Dir: directory,
|
||||
Prefix: "",
|
||||
IndexFile: "index.html",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Static) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if r.Method != "GET" && r.Method != "HEAD" {
|
||||
next(rw, r)
|
||||
return
|
||||
}
|
||||
file := r.URL.Path
|
||||
// if we have a prefix, filter requests by stripping the prefix
|
||||
if s.Prefix != "" {
|
||||
if !strings.HasPrefix(file, s.Prefix) {
|
||||
next(rw, r)
|
||||
return
|
||||
}
|
||||
file = file[len(s.Prefix):]
|
||||
if file != "" && file[0] != '/' {
|
||||
next(rw, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
f, err := s.Dir.Open(file)
|
||||
if err != nil {
|
||||
// discard the error?
|
||||
next(rw, r)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
next(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
// try to serve index file
|
||||
if fi.IsDir() {
|
||||
// redirect if missing trailing slash
|
||||
if !strings.HasSuffix(r.URL.Path, "/") {
|
||||
http.Redirect(rw, r, r.URL.Path+"/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
file = path.Join(file, s.IndexFile)
|
||||
f, err = s.Dir.Open(file)
|
||||
if err != nil {
|
||||
next(rw, r)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err = f.Stat()
|
||||
if err != nil || fi.IsDir() {
|
||||
next(rw, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
http.ServeContent(rw, r, file, fi.ModTime(), f)
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
package negroni
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStatic(t *testing.T) {
|
||||
response := httptest.NewRecorder()
|
||||
response.Body = new(bytes.Buffer)
|
||||
|
||||
n := New()
|
||||
n.Use(NewStatic(http.Dir(".")))
|
||||
|
||||
req, err := http.NewRequest("GET", "http://localhost:3000/negroni.go", nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
n.ServeHTTP(response, req)
|
||||
expect(t, response.Code, http.StatusOK)
|
||||
expect(t, response.Header().Get("Expires"), "")
|
||||
if response.Body.Len() == 0 {
|
||||
t.Errorf("Got empty body for GET request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticHead(t *testing.T) {
|
||||
response := httptest.NewRecorder()
|
||||
response.Body = new(bytes.Buffer)
|
||||
|
||||
n := New()
|
||||
n.Use(NewStatic(http.Dir(".")))
|
||||
n.UseHandler(http.NotFoundHandler())
|
||||
|
||||
req, err := http.NewRequest("HEAD", "http://localhost:3000/negroni.go", nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
n.ServeHTTP(response, req)
|
||||
expect(t, response.Code, http.StatusOK)
|
||||
if response.Body.Len() != 0 {
|
||||
t.Errorf("Got non-empty body for HEAD request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticAsPost(t *testing.T) {
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
n := New()
|
||||
n.Use(NewStatic(http.Dir(".")))
|
||||
n.UseHandler(http.NotFoundHandler())
|
||||
|
||||
req, err := http.NewRequest("POST", "http://localhost:3000/negroni.go", nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
n.ServeHTTP(response, req)
|
||||
expect(t, response.Code, http.StatusNotFound)
|
||||
}
|
||||
|
||||
func TestStaticBadDir(t *testing.T) {
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
n := Classic()
|
||||
n.UseHandler(http.NotFoundHandler())
|
||||
|
||||
req, err := http.NewRequest("GET", "http://localhost:3000/negroni.go", nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
n.ServeHTTP(response, req)
|
||||
refute(t, response.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestStaticOptionsServeIndex(t *testing.T) {
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
n := New()
|
||||
s := NewStatic(http.Dir("."))
|
||||
s.IndexFile = "negroni.go"
|
||||
n.Use(s)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://localhost:3000/", nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
n.ServeHTTP(response, req)
|
||||
expect(t, response.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestStaticOptionsPrefix(t *testing.T) {
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
n := New()
|
||||
s := NewStatic(http.Dir("."))
|
||||
s.Prefix = "/public"
|
||||
n.Use(s)
|
||||
|
||||
// Check file content behaviour
|
||||
req, err := http.NewRequest("GET", "http://localhost:3000/public/negroni.go", nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
n.ServeHTTP(response, req)
|
||||
expect(t, response.Code, http.StatusOK)
|
||||
}
|
|
@ -0,0 +1,170 @@
|
|||
# Negroni [](http://godoc.org/github.com/codegangsta/negroni) [](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
|
||||
|
||||
Negroni é uma abordagem idiomática para middleware web em Go. É pequeno, não intrusivo, e incentiva uso da biblioteca `net/http`.
|
||||
|
||||
Se gosta da idéia do [Martini](http://github.com/go-martini/martini), mas acha que contém muita mágica, então Negroni é ideal.
|
||||
|
||||
## Começando
|
||||
|
||||
Depois de instalar Go e definir seu [GOPATH](http://golang.org/doc/code.html#GOPATH), criar seu primeirto arquivo `.go`. Iremos chamá-lo `server.go`.
|
||||
|
||||
~~~ go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/codegangsta/negroni"
|
||||
"net/http"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
fmt.Fprintf(w, "Welcome to the home page!")
|
||||
})
|
||||
|
||||
n := negroni.Classic()
|
||||
n.UseHandler(mux)
|
||||
n.Run(":3000")
|
||||
}
|
||||
~~~
|
||||
|
||||
Depois instale o pacote Negroni (**go 1.1** ou superior)
|
||||
~~~
|
||||
go get github.com/codegangsta/negroni
|
||||
~~~
|
||||
|
||||
Depois execute seu servidor:
|
||||
~~~
|
||||
go run server.go
|
||||
~~~
|
||||
|
||||
Agora terá um servidor web Go net/http rodando em `localhost:3000`.
|
||||
|
||||
## Precisa de Ajuda?
|
||||
Se você tem uma pergunta ou pedido de recurso,[go ask the mailing list](https://groups.google.com/forum/#!forum/negroni-users). O Github issues para o Negroni será usado exclusivamente para Reportar bugs e pull requests.
|
||||
|
||||
## Negroni é um Framework?
|
||||
Negroni **não** é a framework. É uma biblioteca que é desenhada para trabalhar diretamente com net/http.
|
||||
|
||||
## Roteamento?
|
||||
Negroni é TSPR(Traga seu próprio Roteamento). A comunidade Go já tem um grande número de roteadores http disponíveis, Negroni tenta rodar bem com todos eles pelo suporte total `net/http`/ Por exemplo, a integração com [Gorilla Mux](http://github.com/gorilla/mux) se parece com isso:
|
||||
|
||||
~~~ go
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/", HomeHandler)
|
||||
|
||||
n := negroni.New(Middleware1, Middleware2)
|
||||
// Or use a middleware with the Use() function
|
||||
n.Use(Middleware3)
|
||||
// router goes last
|
||||
n.UseHandler(router)
|
||||
|
||||
n.Run(":3000")
|
||||
~~~
|
||||
|
||||
## `negroni.Classic()`
|
||||
`negroni.Classic()` fornece alguns middlewares padrão que são úteis para maioria das aplicações:
|
||||
|
||||
* `negroni.Recovery` - Panic Recovery Middleware.
|
||||
* `negroni.Logging` - Request/Response Logging Middleware.
|
||||
* `negroni.Static` - Static File serving under the "public" directory.
|
||||
|
||||
Isso torna muito fácil começar com alguns recursos úteis do Negroni.
|
||||
|
||||
## Handlers
|
||||
Negroni fornece um middleware de fluxo bidirecional. Isso é feito através da interface `negroni.Handler`:
|
||||
|
||||
~~~ go
|
||||
type Handler interface {
|
||||
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
|
||||
}
|
||||
~~~
|
||||
|
||||
Se um middleware não tenha escrito o ResponseWriter, ele deve chamar a próxima `http.HandlerFunc` na cadeia para produzir o próximo handler middleware. Isso pode ser usado muito bem:
|
||||
|
||||
~~~ go
|
||||
func MyMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
// do some stuff before
|
||||
next(rw, r)
|
||||
// do some stuff after
|
||||
}
|
||||
~~~
|
||||
|
||||
E pode mapear isso para a cadeia de handler com a função `Use`:
|
||||
|
||||
~~~ go
|
||||
n := negroni.New()
|
||||
n.Use(negroni.HandlerFunc(MyMiddleware))
|
||||
~~~
|
||||
|
||||
Você também pode mapear `http.Handler` antigos:
|
||||
|
||||
~~~ go
|
||||
n := negroni.New()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
// map your routes
|
||||
|
||||
n.UseHandler(mux)
|
||||
|
||||
n.Run(":3000")
|
||||
~~~
|
||||
|
||||
## `Run()`
|
||||
Negroni tem uma função de conveniência chamada `Run`. `Run` pega um endereço de string idêntico para [http.ListenAndServe](http://golang.org/pkg/net/http#ListenAndServe).
|
||||
|
||||
~~~ go
|
||||
n := negroni.Classic()
|
||||
// ...
|
||||
log.Fatal(http.ListenAndServe(":8080", n))
|
||||
~~~
|
||||
|
||||
## Middleware para Rotas Específicas
|
||||
Se você tem um grupo de rota com rotas que precisam ser executadas por um middleware específico, pode simplesmente criar uma nova instância de Negroni e usar no seu Manipulador de rota.
|
||||
|
||||
~~~ go
|
||||
router := mux.NewRouter()
|
||||
adminRoutes := mux.NewRouter()
|
||||
// add admin routes here
|
||||
|
||||
// Criar um middleware negroni para admin
|
||||
router.Handle("/admin", negroni.New(
|
||||
Middleware1,
|
||||
Middleware2,
|
||||
negroni.Wrap(adminRoutes),
|
||||
))
|
||||
~~~
|
||||
|
||||
## Middleware de Terceiros
|
||||
|
||||
Aqui está uma lista atual de Middleware Compatíveis com Negroni. Sinta se livre para mandar um PR vinculando seu middleware se construiu um:
|
||||
|
||||
|
||||
| Middleware | Autor | Descrição |
|
||||
| -----------|--------|-------------|
|
||||
| [Graceful](https://github.com/stretchr/graceful) | [Tyler Bunnell](https://github.com/tylerb) | Graceful HTTP Shutdown |
|
||||
| [secure](https://github.com/unrolled/secure) | [Cory Jacobsen](https://github.com/unrolled) | Implementa rapidamente itens de segurança.|
|
||||
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | Handler para mapeamento/validação de um request a estrutura. |
|
||||
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | Logrus-based logger |
|
||||
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | Pacote para renderizar JSON, XML, e templates HTML. |
|
||||
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | New Relic agent for Go runtime |
|
||||
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | Handler para adicionar compreção gzip para as requisições |
|
||||
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | Handler que prove sistema de login OAuth 2.0 para aplicações Martini. Google Sign-in, Facebook Connect e Github login são suportados. |
|
||||
| [sessions](https://github.com/goincremental/negroni-sessions) | [David Bochenski](https://github.com/bochenski) | Handler que provê o serviço de sessão. |
|
||||
| [permissions](https://github.com/xyproto/permissions) | [Alexander Rødseth](https://github.com/xyproto) | Cookies, usuários e permissões. |
|
||||
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | Pacote para gerar TinySVG, HTML e CSS em tempo real. |
|
||||
|
||||
## Exemplos
|
||||
[Alexander Rødseth](https://github.com/xyproto) criou [mooseware](https://github.com/xyproto/mooseware), uma estrutura para escrever um handler middleware Negroni.
|
||||
|
||||
## Servidor com autoreload?
|
||||
[gin](https://github.com/codegangsta/gin) e [fresh](https://github.com/pilu/fresh) são aplicativos para autoreload do Negroni.
|
||||
|
||||
## Leitura Essencial para Iniciantes em Go & Negroni
|
||||
* [Usando um contexto para passar informação de um middleware para o manipulador final](http://elithrar.github.io/article/map-string-interface/)
|
||||
* [Entendendo middleware](http://mattstauffer.co/blog/laravel-5.0-middleware-replacing-filters)
|
||||
|
||||
|
||||
## Sobre
|
||||
Negroni é obsessivamente desenhado por ninguém menos que [Code Gangsta](http://codegangsta.io/)
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2012-2014 Grigory Dryapak
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,163 @@
|
|||
# Imaging
|
||||
|
||||
Package imaging provides basic image manipulation functions (resize, rotate, flip, crop, etc.).
|
||||
This package is based on the standard Go image package and works best along with it.
|
||||
|
||||
Image manipulation functions provided by the package take any image type
|
||||
that implements `image.Image` interface as an input, and return a new image of
|
||||
`*image.NRGBA` type (32bit RGBA colors, not premultiplied by alpha).
|
||||
|
||||
## Installation
|
||||
|
||||
Imaging requires Go version 1.2 or greater.
|
||||
|
||||
go get -u github.com/disintegration/imaging
|
||||
|
||||
## Documentation
|
||||
|
||||
http://godoc.org/github.com/disintegration/imaging
|
||||
|
||||
## Usage examples
|
||||
|
||||
A few usage examples can be found below. See the documentation for the full list of supported functions.
|
||||
|
||||
### Image resizing
|
||||
```go
|
||||
// resize srcImage to size = 128x128px using the Lanczos filter
|
||||
dstImage128 := imaging.Resize(srcImage, 128, 128, imaging.Lanczos)
|
||||
|
||||
// resize srcImage to width = 800px preserving the aspect ratio
|
||||
dstImage800 := imaging.Resize(srcImage, 800, 0, imaging.Lanczos)
|
||||
|
||||
// scale down srcImage to fit the 800x600px bounding box
|
||||
dstImageFit := imaging.Fit(srcImage, 800, 600, imaging.Lanczos)
|
||||
|
||||
// resize and crop the srcImage to make a 100x100px thumbnail
|
||||
dstImageThumb := imaging.Thumbnail(srcImage, 100, 100, imaging.Lanczos)
|
||||
```
|
||||
|
||||
Imaging supports image resizing using various resampling filters. The most notable ones:
|
||||
- `NearestNeighbor` - Fastest resampling filter, no antialiasing.
|
||||
- `Box` - Simple and fast averaging filter appropriate for downscaling. When upscaling it's similar to NearestNeighbor.
|
||||
- `Linear` - Bilinear filter, smooth and reasonably fast.
|
||||
- `MitchellNetravali` - А smooth bicubic filter.
|
||||
- `CatmullRom` - A sharp bicubic filter.
|
||||
- `Gaussian` - Blurring filter that uses gaussian function, useful for noise removal.
|
||||
- `Lanczos` - High-quality resampling filter for photographic images yielding sharp results, but it's slower than cubic filters.
|
||||
|
||||
The full list of supported filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali, CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine. Custom filters can be created using ResampleFilter struct.
|
||||
|
||||
**Resampling filters comparison**
|
||||
|
||||
Original image. Will be resized from 512x512px to 128x128px.
|
||||
|
||||

|
||||
|
||||
Filter | Resize result
|
||||
---|---
|
||||
`imaging.NearestNeighbor` | 
|
||||
`imaging.Box` | 
|
||||
`imaging.Linear` | 
|
||||
`imaging.MitchellNetravali` | 
|
||||
`imaging.CatmullRom` | 
|
||||
`imaging.Gaussian` | 
|
||||
`imaging.Lanczos` | 
|
||||
|
||||
### Gaussian Blur
|
||||
```go
|
||||
dstImage := imaging.Blur(srcImage, 0.5)
|
||||
```
|
||||
|
||||
Sigma parameter allows to control the strength of the blurring effect.
|
||||
|
||||
Original image | Sigma = 0.5 | Sigma = 1.5
|
||||
---|---|---
|
||||
 |  | 
|
||||
|
||||
### Sharpening
|
||||
```go
|
||||
dstImage := imaging.Sharpen(srcImage, 0.5)
|
||||
```
|
||||
|
||||
Uses gaussian function internally. Sigma parameter allows to control the strength of the sharpening effect.
|
||||
|
||||
Original image | Sigma = 0.5 | Sigma = 1.5
|
||||
---|---|---
|
||||
 |  | 
|
||||
|
||||
### Gamma correction
|
||||
```go
|
||||
dstImage := imaging.AdjustGamma(srcImage, 0.75)
|
||||
```
|
||||
|
||||
Original image | Gamma = 0.75 | Gamma = 1.25
|
||||
---|---|---
|
||||
 |  | 
|
||||
|
||||
### Contrast adjustment
|
||||
```go
|
||||
dstImage := imaging.AdjustContrast(srcImage, 20)
|
||||
```
|
||||
|
||||
Original image | Contrast = 20 | Contrast = -20
|
||||
---|---|---
|
||||
 |  | 
|
||||
|
||||
### Brightness adjustment
|
||||
```go
|
||||
dstImage := imaging.AdjustBrightness(srcImage, 20)
|
||||
```
|
||||
|
||||
Original image | Brightness = 20 | Brightness = -20
|
||||
---|---|---
|
||||
 |  | 
|
||||
|
||||
|
||||
### Complete code example
|
||||
Here is the code example that loads several images, makes thumbnails of them
|
||||
and combines them together side-by-side.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
"runtime"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// use all CPU cores for maximum performance
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
|
||||
// input files
|
||||
files := []string{"01.jpg", "02.jpg", "03.jpg"}
|
||||
|
||||
// load images and make 100x100 thumbnails of them
|
||||
var thumbnails []image.Image
|
||||
for _, file := range files {
|
||||
img, err := imaging.Open(file)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
thumb := imaging.Thumbnail(img, 100, 100, imaging.CatmullRom)
|
||||
thumbnails = append(thumbnails, thumb)
|
||||
}
|
||||
|
||||
// create a new blank image
|
||||
dst := imaging.New(100*len(thumbnails), 100, color.NRGBA{0, 0, 0, 0})
|
||||
|
||||
// paste thumbnails into the new image side by side
|
||||
for i, thumb := range thumbnails {
|
||||
dst = imaging.Paste(dst, thumb, image.Pt(i*100, 0))
|
||||
}
|
||||
|
||||
// save the combined image to file
|
||||
err := imaging.Save(dst, "dst.jpg")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,200 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
"math"
|
||||
)
|
||||
|
||||
// AdjustFunc applies the fn function to each pixel of the img image and returns the adjusted image.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// dstImage = imaging.AdjustFunc(
|
||||
// srcImage,
|
||||
// func(c color.NRGBA) color.NRGBA {
|
||||
// // shift the red channel by 16
|
||||
// r := int(c.R) + 16
|
||||
// if r > 255 {
|
||||
// r = 255
|
||||
// }
|
||||
// return color.NRGBA{uint8(r), c.G, c.B, c.A}
|
||||
// }
|
||||
// )
|
||||
//
|
||||
func AdjustFunc(img image.Image, fn func(c color.NRGBA) color.NRGBA) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
width := src.Bounds().Max.X
|
||||
height := src.Bounds().Max.Y
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
parallel(height, func(partStart, partEnd int) {
|
||||
for y := partStart; y < partEnd; y++ {
|
||||
for x := 0; x < width; x++ {
|
||||
i := y*src.Stride + x*4
|
||||
j := y*dst.Stride + x*4
|
||||
|
||||
r := src.Pix[i+0]
|
||||
g := src.Pix[i+1]
|
||||
b := src.Pix[i+2]
|
||||
a := src.Pix[i+3]
|
||||
|
||||
c := fn(color.NRGBA{r, g, b, a})
|
||||
|
||||
dst.Pix[j+0] = c.R
|
||||
dst.Pix[j+1] = c.G
|
||||
dst.Pix[j+2] = c.B
|
||||
dst.Pix[j+3] = c.A
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// AdjustGamma performs a gamma correction on the image and returns the adjusted image.
|
||||
// Gamma parameter must be positive. Gamma = 1.0 gives the original image.
|
||||
// Gamma less than 1.0 darkens the image and gamma greater than 1.0 lightens it.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// dstImage = imaging.AdjustGamma(srcImage, 0.7)
|
||||
//
|
||||
func AdjustGamma(img image.Image, gamma float64) *image.NRGBA {
|
||||
e := 1.0 / math.Max(gamma, 0.0001)
|
||||
lut := make([]uint8, 256)
|
||||
|
||||
for i := 0; i < 256; i++ {
|
||||
lut[i] = clamp(math.Pow(float64(i)/255.0, e) * 255.0)
|
||||
}
|
||||
|
||||
fn := func(c color.NRGBA) color.NRGBA {
|
||||
return color.NRGBA{lut[c.R], lut[c.G], lut[c.B], c.A}
|
||||
}
|
||||
|
||||
return AdjustFunc(img, fn)
|
||||
}
|
||||
|
||||
func sigmoid(a, b, x float64) float64 {
|
||||
return 1 / (1 + math.Exp(b*(a-x)))
|
||||
}
|
||||
|
||||
// AdjustSigmoid changes the contrast of the image using a sigmoidal function and returns the adjusted image.
|
||||
// It's a non-linear contrast change useful for photo adjustments as it preserves highlight and shadow detail.
|
||||
// The midpoint parameter is the midpoint of contrast that must be between 0 and 1, typically 0.5.
|
||||
// The factor parameter indicates how much to increase or decrease the contrast, typically in range (-10, 10).
|
||||
// If the factor parameter is positive the image contrast is increased otherwise the contrast is decreased.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// dstImage = imaging.AdjustSigmoid(srcImage, 0.5, 3.0) // increase the contrast
|
||||
// dstImage = imaging.AdjustSigmoid(srcImage, 0.5, -3.0) // decrease the contrast
|
||||
//
|
||||
func AdjustSigmoid(img image.Image, midpoint, factor float64) *image.NRGBA {
|
||||
if factor == 0 {
|
||||
return Clone(img)
|
||||
}
|
||||
|
||||
lut := make([]uint8, 256)
|
||||
a := math.Min(math.Max(midpoint, 0.0), 1.0)
|
||||
b := math.Abs(factor)
|
||||
sig0 := sigmoid(a, b, 0)
|
||||
sig1 := sigmoid(a, b, 1)
|
||||
e := 1.0e-6
|
||||
|
||||
if factor > 0 {
|
||||
for i := 0; i < 256; i++ {
|
||||
x := float64(i) / 255.0
|
||||
sigX := sigmoid(a, b, x)
|
||||
f := (sigX - sig0) / (sig1 - sig0)
|
||||
lut[i] = clamp(f * 255.0)
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < 256; i++ {
|
||||
x := float64(i) / 255.0
|
||||
arg := math.Min(math.Max((sig1-sig0)*x+sig0, e), 1.0-e)
|
||||
f := a - math.Log(1.0/arg-1.0)/b
|
||||
lut[i] = clamp(f * 255.0)
|
||||
}
|
||||
}
|
||||
|
||||
fn := func(c color.NRGBA) color.NRGBA {
|
||||
return color.NRGBA{lut[c.R], lut[c.G], lut[c.B], c.A}
|
||||
}
|
||||
|
||||
return AdjustFunc(img, fn)
|
||||
}
|
||||
|
||||
// AdjustContrast changes the contrast of the image using the percentage parameter and returns the adjusted image.
|
||||
// The percentage must be in range (-100, 100). The percentage = 0 gives the original image.
|
||||
// The percentage = -100 gives solid grey image.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// dstImage = imaging.AdjustContrast(srcImage, -10) // decrease image contrast by 10%
|
||||
// dstImage = imaging.AdjustContrast(srcImage, 20) // increase image contrast by 20%
|
||||
//
|
||||
func AdjustContrast(img image.Image, percentage float64) *image.NRGBA {
|
||||
percentage = math.Min(math.Max(percentage, -100.0), 100.0)
|
||||
lut := make([]uint8, 256)
|
||||
|
||||
v := (100.0 + percentage) / 100.0
|
||||
for i := 0; i < 256; i++ {
|
||||
if 0 <= v && v <= 1 {
|
||||
lut[i] = clamp((0.5 + (float64(i)/255.0-0.5)*v) * 255.0)
|
||||
} else if 1 < v && v < 2 {
|
||||
lut[i] = clamp((0.5 + (float64(i)/255.0-0.5)*(1/(2.0-v))) * 255.0)
|
||||
} else {
|
||||
lut[i] = uint8(float64(i)/255.0+0.5) * 255
|
||||
}
|
||||
}
|
||||
|
||||
fn := func(c color.NRGBA) color.NRGBA {
|
||||
return color.NRGBA{lut[c.R], lut[c.G], lut[c.B], c.A}
|
||||
}
|
||||
|
||||
return AdjustFunc(img, fn)
|
||||
}
|
||||
|
||||
// AdjustBrightness changes the brightness of the image using the percentage parameter and returns the adjusted image.
|
||||
// The percentage must be in range (-100, 100). The percentage = 0 gives the original image.
|
||||
// The percentage = -100 gives solid black image. The percentage = 100 gives solid white image.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// dstImage = imaging.AdjustBrightness(srcImage, -15) // decrease image brightness by 15%
|
||||
// dstImage = imaging.AdjustBrightness(srcImage, 10) // increase image brightness by 10%
|
||||
//
|
||||
func AdjustBrightness(img image.Image, percentage float64) *image.NRGBA {
|
||||
percentage = math.Min(math.Max(percentage, -100.0), 100.0)
|
||||
lut := make([]uint8, 256)
|
||||
|
||||
shift := 255.0 * percentage / 100.0
|
||||
for i := 0; i < 256; i++ {
|
||||
lut[i] = clamp(float64(i) + shift)
|
||||
}
|
||||
|
||||
fn := func(c color.NRGBA) color.NRGBA {
|
||||
return color.NRGBA{lut[c.R], lut[c.G], lut[c.B], c.A}
|
||||
}
|
||||
|
||||
return AdjustFunc(img, fn)
|
||||
}
|
||||
|
||||
// Grayscale produces grayscale version of the image.
|
||||
func Grayscale(img image.Image) *image.NRGBA {
|
||||
fn := func(c color.NRGBA) color.NRGBA {
|
||||
f := 0.299*float64(c.R) + 0.587*float64(c.G) + 0.114*float64(c.B)
|
||||
y := uint8(f + 0.5)
|
||||
return color.NRGBA{y, y, y, c.A}
|
||||
}
|
||||
return AdjustFunc(img, fn)
|
||||
}
|
||||
|
||||
// Invert produces inverted (negated) version of the image.
|
||||
func Invert(img image.Image) *image.NRGBA {
|
||||
fn := func(c color.NRGBA) color.NRGBA {
|
||||
return color.NRGBA{255 - c.R, 255 - c.G, 255 - c.B, c.A}
|
||||
}
|
||||
return AdjustFunc(img, fn)
|
||||
}
|
|
@ -0,0 +1,504 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGrayscale(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Grayscale 3x3",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x3d, 0x3d, 0x3d, 0x01, 0x78, 0x78, 0x78, 0x02, 0x17, 0x17, 0x17, 0x03,
|
||||
0x1f, 0x1f, 0x1f, 0xff, 0x25, 0x25, 0x25, 0xff, 0x66, 0x66, 0x66, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Grayscale(d.src)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvert(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Invert 3x3",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x33, 0xff, 0xff, 0x01, 0xff, 0x33, 0xff, 0x02, 0xff, 0xff, 0x33, 0x03,
|
||||
0xee, 0xdd, 0xcc, 0xff, 0xcc, 0xdd, 0xee, 0xff, 0x55, 0xcc, 0x44, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xcc, 0xcc, 0xcc, 0xff, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Invert(d.src)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdjustContrast(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
p float64
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"AdjustContrast 3x3 10",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
10,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xd5, 0x00, 0x00, 0x01, 0x00, 0xd5, 0x00, 0x02, 0x00, 0x00, 0xd5, 0x03,
|
||||
0x05, 0x18, 0x2b, 0xff, 0x2b, 0x18, 0x05, 0xff, 0xaf, 0x2b, 0xc2, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x2b, 0x2b, 0x2b, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustContrast 3x3 100",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
100,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0x00, 0x00, 0x01, 0x00, 0xff, 0x00, 0x02, 0x00, 0x00, 0xff, 0x03,
|
||||
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0xff, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustContrast 3x3 -10",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
-10,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xc4, 0x0d, 0x0d, 0x01, 0x0d, 0xc4, 0x0d, 0x02, 0x0d, 0x0d, 0xc4, 0x03,
|
||||
0x1c, 0x2b, 0x3b, 0xff, 0x3b, 0x2b, 0x1c, 0xff, 0xa6, 0x3b, 0xb5, 0xff,
|
||||
0x0d, 0x0d, 0x0d, 0xff, 0x3b, 0x3b, 0x3b, 0xff, 0xf2, 0xf2, 0xf2, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustContrast 3x3 -100",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
-100,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x80, 0x80, 0x80, 0x01, 0x80, 0x80, 0x80, 0x02, 0x80, 0x80, 0x80, 0x03,
|
||||
0x80, 0x80, 0x80, 0xff, 0x80, 0x80, 0x80, 0xff, 0x80, 0x80, 0x80, 0xff,
|
||||
0x80, 0x80, 0x80, 0xff, 0x80, 0x80, 0x80, 0xff, 0x80, 0x80, 0x80, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustContrast 3x3 0",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
0,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := AdjustContrast(d.src, d.p)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdjustBrightness(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
p float64
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"AdjustBrightness 3x3 10",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
10,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xe6, 0x1a, 0x1a, 0x01, 0x1a, 0xe6, 0x1a, 0x02, 0x1a, 0x1a, 0xe6, 0x03,
|
||||
0x2b, 0x3c, 0x4d, 0xff, 0x4d, 0x3c, 0x2b, 0xff, 0xc4, 0x4d, 0xd5, 0xff,
|
||||
0x1a, 0x1a, 0x1a, 0xff, 0x4d, 0x4d, 0x4d, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustBrightness 3x3 100",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
100,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0xff, 0xff, 0x01, 0xff, 0xff, 0xff, 0x02, 0xff, 0xff, 0xff, 0x03,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustBrightness 3x3 -10",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
-10,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xb3, 0x00, 0x00, 0x01, 0x00, 0xb3, 0x00, 0x02, 0x00, 0x00, 0xb3, 0x03,
|
||||
0x00, 0x09, 0x1a, 0xff, 0x1a, 0x09, 0x00, 0xff, 0x91, 0x1a, 0xa2, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x1a, 0x1a, 0x1a, 0xff, 0xe6, 0xe6, 0xe6, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustBrightness 3x3 -100",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
-100,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03,
|
||||
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustBrightness 3x3 0",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
0,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := AdjustBrightness(d.src, d.p)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdjustGamma(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
p float64
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"AdjustGamma 3x3 0.75",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
0.75,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xbd, 0x00, 0x00, 0x01, 0x00, 0xbd, 0x00, 0x02, 0x00, 0x00, 0xbd, 0x03,
|
||||
0x07, 0x11, 0x1e, 0xff, 0x1e, 0x11, 0x07, 0xff, 0x95, 0x1e, 0xa9, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x1e, 0x1e, 0x1e, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustGamma 3x3 1.5",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
1.5,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xdc, 0x00, 0x00, 0x01, 0x00, 0xdc, 0x00, 0x02, 0x00, 0x00, 0xdc, 0x03,
|
||||
0x2a, 0x43, 0x57, 0xff, 0x57, 0x43, 0x2a, 0xff, 0xc3, 0x57, 0xcf, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x57, 0x57, 0x57, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustGamma 3x3 1.0",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
1.0,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := AdjustGamma(d.src, d.p)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdjustSigmoid(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
m float64
|
||||
p float64
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"AdjustSigmoid 3x3 0.5 3.0",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
0.5,
|
||||
3.0,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xd4, 0x00, 0x00, 0x01, 0x00, 0xd4, 0x00, 0x02, 0x00, 0x00, 0xd4, 0x03,
|
||||
0x0d, 0x1b, 0x2b, 0xff, 0x2b, 0x1b, 0x0d, 0xff, 0xb1, 0x2b, 0xc3, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x2b, 0x2b, 0x2b, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustSigmoid 3x3 0.5 -3.0",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
0.5,
|
||||
-3.0,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xc4, 0x00, 0x00, 0x01, 0x00, 0xc4, 0x00, 0x02, 0x00, 0x00, 0xc4, 0x03,
|
||||
0x16, 0x2a, 0x3b, 0xff, 0x3b, 0x2a, 0x16, 0xff, 0xa4, 0x3b, 0xb3, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x3b, 0x3b, 0x3b, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"AdjustSigmoid 3x3 0.5 0.0",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
0.5,
|
||||
0.0,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0x00, 0x00, 0x01, 0x00, 0xcc, 0x00, 0x02, 0x00, 0x00, 0xcc, 0x03,
|
||||
0x11, 0x22, 0x33, 0xff, 0x33, 0x22, 0x11, 0xff, 0xaa, 0x33, 0xbb, 0xff,
|
||||
0x00, 0x00, 0x00, 0xff, 0x33, 0x33, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := AdjustSigmoid(d.src, d.m, d.p)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
"math"
|
||||
)
|
||||
|
||||
func gaussianBlurKernel(x, sigma float64) float64 {
|
||||
return math.Exp(-(x*x)/(2*sigma*sigma)) / (sigma * math.Sqrt(2*math.Pi))
|
||||
}
|
||||
|
||||
// Blur produces a blurred version of the image using a Gaussian function.
|
||||
// Sigma parameter must be positive and indicates how much the image will be blurred.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// dstImage := imaging.Blur(srcImage, 3.5)
|
||||
//
|
||||
func Blur(img image.Image, sigma float64) *image.NRGBA {
|
||||
if sigma <= 0 {
|
||||
// sigma parameter must be positive!
|
||||
return Clone(img)
|
||||
}
|
||||
|
||||
src := toNRGBA(img)
|
||||
radius := int(math.Ceil(sigma * 3.0))
|
||||
kernel := make([]float64, radius+1)
|
||||
|
||||
for i := 0; i <= radius; i++ {
|
||||
kernel[i] = gaussianBlurKernel(float64(i), sigma)
|
||||
}
|
||||
|
||||
var dst *image.NRGBA
|
||||
dst = blurHorizontal(src, kernel)
|
||||
dst = blurVertical(dst, kernel)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func blurHorizontal(src *image.NRGBA, kernel []float64) *image.NRGBA {
|
||||
radius := len(kernel) - 1
|
||||
width := src.Bounds().Max.X
|
||||
height := src.Bounds().Max.Y
|
||||
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
parallel(width, func(partStart, partEnd int) {
|
||||
for x := partStart; x < partEnd; x++ {
|
||||
start := x - radius
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
end := x + radius
|
||||
if end > width-1 {
|
||||
end = width - 1
|
||||
}
|
||||
|
||||
weightSum := 0.0
|
||||
for ix := start; ix <= end; ix++ {
|
||||
weightSum += kernel[absint(x-ix)]
|
||||
}
|
||||
|
||||
for y := 0; y < height; y++ {
|
||||
|
||||
r, g, b, a := 0.0, 0.0, 0.0, 0.0
|
||||
for ix := start; ix <= end; ix++ {
|
||||
weight := kernel[absint(x-ix)]
|
||||
i := y*src.Stride + ix*4
|
||||
r += float64(src.Pix[i+0]) * weight
|
||||
g += float64(src.Pix[i+1]) * weight
|
||||
b += float64(src.Pix[i+2]) * weight
|
||||
a += float64(src.Pix[i+3]) * weight
|
||||
}
|
||||
|
||||
r = math.Min(math.Max(r/weightSum, 0.0), 255.0)
|
||||
g = math.Min(math.Max(g/weightSum, 0.0), 255.0)
|
||||
b = math.Min(math.Max(b/weightSum, 0.0), 255.0)
|
||||
a = math.Min(math.Max(a/weightSum, 0.0), 255.0)
|
||||
|
||||
j := y*dst.Stride + x*4
|
||||
dst.Pix[j+0] = uint8(r + 0.5)
|
||||
dst.Pix[j+1] = uint8(g + 0.5)
|
||||
dst.Pix[j+2] = uint8(b + 0.5)
|
||||
dst.Pix[j+3] = uint8(a + 0.5)
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func blurVertical(src *image.NRGBA, kernel []float64) *image.NRGBA {
|
||||
radius := len(kernel) - 1
|
||||
width := src.Bounds().Max.X
|
||||
height := src.Bounds().Max.Y
|
||||
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
parallel(height, func(partStart, partEnd int) {
|
||||
for y := partStart; y < partEnd; y++ {
|
||||
start := y - radius
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
end := y + radius
|
||||
if end > height-1 {
|
||||
end = height - 1
|
||||
}
|
||||
|
||||
weightSum := 0.0
|
||||
for iy := start; iy <= end; iy++ {
|
||||
weightSum += kernel[absint(y-iy)]
|
||||
}
|
||||
|
||||
for x := 0; x < width; x++ {
|
||||
|
||||
r, g, b, a := 0.0, 0.0, 0.0, 0.0
|
||||
for iy := start; iy <= end; iy++ {
|
||||
weight := kernel[absint(y-iy)]
|
||||
i := iy*src.Stride + x*4
|
||||
r += float64(src.Pix[i+0]) * weight
|
||||
g += float64(src.Pix[i+1]) * weight
|
||||
b += float64(src.Pix[i+2]) * weight
|
||||
a += float64(src.Pix[i+3]) * weight
|
||||
}
|
||||
|
||||
r = math.Min(math.Max(r/weightSum, 0.0), 255.0)
|
||||
g = math.Min(math.Max(g/weightSum, 0.0), 255.0)
|
||||
b = math.Min(math.Max(b/weightSum, 0.0), 255.0)
|
||||
a = math.Min(math.Max(a/weightSum, 0.0), 255.0)
|
||||
|
||||
j := y*dst.Stride + x*4
|
||||
dst.Pix[j+0] = uint8(r + 0.5)
|
||||
dst.Pix[j+1] = uint8(g + 0.5)
|
||||
dst.Pix[j+2] = uint8(b + 0.5)
|
||||
dst.Pix[j+3] = uint8(a + 0.5)
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// Sharpen produces a sharpened version of the image.
|
||||
// Sigma parameter must be positive and indicates how much the image will be sharpened.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// dstImage := imaging.Sharpen(srcImage, 3.5)
|
||||
//
|
||||
func Sharpen(img image.Image, sigma float64) *image.NRGBA {
|
||||
if sigma <= 0 {
|
||||
// sigma parameter must be positive!
|
||||
return Clone(img)
|
||||
}
|
||||
|
||||
src := toNRGBA(img)
|
||||
blurred := Blur(img, sigma)
|
||||
|
||||
width := src.Bounds().Max.X
|
||||
height := src.Bounds().Max.Y
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
parallel(height, func(partStart, partEnd int) {
|
||||
for y := partStart; y < partEnd; y++ {
|
||||
for x := 0; x < width; x++ {
|
||||
i := y*src.Stride + x*4
|
||||
for j := 0; j < 4; j++ {
|
||||
k := i + j
|
||||
val := int(src.Pix[k]) + (int(src.Pix[k]) - int(blurred.Pix[k]))
|
||||
if val < 0 {
|
||||
val = 0
|
||||
} else if val > 255 {
|
||||
val = 255
|
||||
}
|
||||
dst.Pix[k] = uint8(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBlur(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
sigma float64
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Blur 3x3 0.5",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x66, 0xaa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
0.5,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x01, 0x02, 0x04, 0x04, 0x0a, 0x10, 0x18, 0x18, 0x01, 0x02, 0x04, 0x04,
|
||||
0x09, 0x10, 0x18, 0x18, 0x3f, 0x69, 0x9e, 0x9e, 0x09, 0x10, 0x18, 0x18,
|
||||
0x01, 0x02, 0x04, 0x04, 0x0a, 0x10, 0x18, 0x18, 0x01, 0x02, 0x04, 0x04,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
"Blur 3x3 10",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x66, 0xaa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
10,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c,
|
||||
0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c,
|
||||
0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c, 0x0b, 0x13, 0x1c, 0x1c,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Blur(d.src, d.sigma)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharpen(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
sigma float64
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Sharpen 3x3 0.5",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
|
||||
0x66, 0x66, 0x66, 0x66, 0x77, 0x77, 0x77, 0x77, 0x66, 0x66, 0x66, 0x66,
|
||||
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
|
||||
},
|
||||
},
|
||||
0.5,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x66, 0x66, 0x66, 0x66, 0x64, 0x64, 0x64, 0x64, 0x66, 0x66, 0x66, 0x66,
|
||||
0x64, 0x64, 0x64, 0x64, 0x7e, 0x7e, 0x7e, 0x7e, 0x64, 0x64, 0x64, 0x64,
|
||||
0x66, 0x66, 0x66, 0x66, 0x64, 0x64, 0x64, 0x64, 0x66, 0x66, 0x66, 0x66},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
"Sharpen 3x3 10",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
|
||||
0x66, 0x66, 0x66, 0x66, 0x77, 0x77, 0x77, 0x77, 0x66, 0x66, 0x66, 0x66,
|
||||
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66},
|
||||
},
|
||||
100,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 3),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64,
|
||||
0x64, 0x64, 0x64, 0x64, 0x86, 0x86, 0x86, 0x86, 0x64, 0x64, 0x64, 0x64,
|
||||
0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64, 0x64,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Sharpen(d.src, d.sigma)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,436 @@
|
|||
/*
|
||||
Package imaging provides basic image manipulation functions (resize, rotate, flip, crop, etc.).
|
||||
This package is based on the standard Go image package and works best along with it.
|
||||
|
||||
Image manipulation functions provided by the package take any image type
|
||||
that implements `image.Image` interface as an input, and return a new image of
|
||||
`*image.NRGBA` type (32bit RGBA colors, not premultiplied by alpha).
|
||||
|
||||
Imaging package uses parallel goroutines for faster image processing.
|
||||
To achieve maximum performance, make sure to allow Go to utilize all CPU cores:
|
||||
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
*/
|
||||
package imaging
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/gif"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/image/bmp"
|
||||
"golang.org/x/image/tiff"
|
||||
)
|
||||
|
||||
type Format int
|
||||
|
||||
const (
|
||||
JPEG Format = iota
|
||||
PNG
|
||||
GIF
|
||||
TIFF
|
||||
BMP
|
||||
)
|
||||
|
||||
func (f Format) String() string {
|
||||
switch f {
|
||||
case JPEG:
|
||||
return "JPEG"
|
||||
case PNG:
|
||||
return "PNG"
|
||||
case GIF:
|
||||
return "GIF"
|
||||
case TIFF:
|
||||
return "TIFF"
|
||||
case BMP:
|
||||
return "BMP"
|
||||
default:
|
||||
return "Unsupported"
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUnsupportedFormat = errors.New("imaging: unsupported image format")
|
||||
)
|
||||
|
||||
// Decode reads an image from r.
|
||||
func Decode(r io.Reader) (image.Image, error) {
|
||||
img, _, err := image.Decode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNRGBA(img), nil
|
||||
}
|
||||
|
||||
// Open loads an image from file
|
||||
func Open(filename string) (image.Image, error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
img, err := Decode(file)
|
||||
return img, err
|
||||
}
|
||||
|
||||
// Encode writes the image img to w in the specified format (JPEG, PNG, GIF, TIFF or BMP).
|
||||
func Encode(w io.Writer, img image.Image, format Format) error {
|
||||
var err error
|
||||
switch format {
|
||||
case JPEG:
|
||||
var rgba *image.RGBA
|
||||
if nrgba, ok := img.(*image.NRGBA); ok {
|
||||
if nrgba.Opaque() {
|
||||
rgba = &image.RGBA{
|
||||
Pix: nrgba.Pix,
|
||||
Stride: nrgba.Stride,
|
||||
Rect: nrgba.Rect,
|
||||
}
|
||||
}
|
||||
}
|
||||
if rgba != nil {
|
||||
err = jpeg.Encode(w, rgba, &jpeg.Options{Quality: 95})
|
||||
} else {
|
||||
err = jpeg.Encode(w, img, &jpeg.Options{Quality: 95})
|
||||
}
|
||||
|
||||
case PNG:
|
||||
err = png.Encode(w, img)
|
||||
case GIF:
|
||||
err = gif.Encode(w, img, &gif.Options{NumColors: 256})
|
||||
case TIFF:
|
||||
err = tiff.Encode(w, img, &tiff.Options{Compression: tiff.Deflate, Predictor: true})
|
||||
case BMP:
|
||||
err = bmp.Encode(w, img)
|
||||
default:
|
||||
err = ErrUnsupportedFormat
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Save saves the image to file with the specified filename.
|
||||
// The format is determined from the filename extension: "jpg" (or "jpeg"), "png", "gif", "tif" (or "tiff") and "bmp" are supported.
|
||||
func Save(img image.Image, filename string) (err error) {
|
||||
formats := map[string]Format{
|
||||
".jpg": JPEG,
|
||||
".jpeg": JPEG,
|
||||
".png": PNG,
|
||||
".tif": TIFF,
|
||||
".tiff": TIFF,
|
||||
".bmp": BMP,
|
||||
".gif": GIF,
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
f, ok := formats[ext]
|
||||
if !ok {
|
||||
return ErrUnsupportedFormat
|
||||
}
|
||||
|
||||
file, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
return Encode(file, img, f)
|
||||
}
|
||||
|
||||
// New creates a new image with the specified width and height, and fills it with the specified color.
|
||||
func New(width, height int, fillColor color.Color) *image.NRGBA {
|
||||
if width <= 0 || height <= 0 {
|
||||
return &image.NRGBA{}
|
||||
}
|
||||
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
|
||||
c := color.NRGBAModel.Convert(fillColor).(color.NRGBA)
|
||||
|
||||
if c.R == 0 && c.G == 0 && c.B == 0 && c.A == 0 {
|
||||
return dst
|
||||
}
|
||||
|
||||
cs := []uint8{c.R, c.G, c.B, c.A}
|
||||
|
||||
// fill the first row
|
||||
for x := 0; x < width; x++ {
|
||||
copy(dst.Pix[x*4:(x+1)*4], cs)
|
||||
}
|
||||
// copy the first row to other rows
|
||||
for y := 1; y < height; y++ {
|
||||
copy(dst.Pix[y*dst.Stride:y*dst.Stride+width*4], dst.Pix[0:width*4])
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// Clone returns a copy of the given image.
|
||||
func Clone(img image.Image) *image.NRGBA {
|
||||
srcBounds := img.Bounds()
|
||||
srcMinX := srcBounds.Min.X
|
||||
srcMinY := srcBounds.Min.Y
|
||||
|
||||
dstBounds := srcBounds.Sub(srcBounds.Min)
|
||||
dstW := dstBounds.Dx()
|
||||
dstH := dstBounds.Dy()
|
||||
dst := image.NewNRGBA(dstBounds)
|
||||
|
||||
switch src := img.(type) {
|
||||
|
||||
case *image.NRGBA:
|
||||
rowSize := srcBounds.Dx() * 4
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
di := dst.PixOffset(0, dstY)
|
||||
si := src.PixOffset(srcMinX, srcMinY+dstY)
|
||||
copy(dst.Pix[di:di+rowSize], src.Pix[si:si+rowSize])
|
||||
}
|
||||
})
|
||||
|
||||
case *image.NRGBA64:
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
di := dst.PixOffset(0, dstY)
|
||||
si := src.PixOffset(srcMinX, srcMinY+dstY)
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
|
||||
dst.Pix[di+0] = src.Pix[si+0]
|
||||
dst.Pix[di+1] = src.Pix[si+2]
|
||||
dst.Pix[di+2] = src.Pix[si+4]
|
||||
dst.Pix[di+3] = src.Pix[si+6]
|
||||
|
||||
di += 4
|
||||
si += 8
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
case *image.RGBA:
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
di := dst.PixOffset(0, dstY)
|
||||
si := src.PixOffset(srcMinX, srcMinY+dstY)
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
|
||||
a := src.Pix[si+3]
|
||||
dst.Pix[di+3] = a
|
||||
switch a {
|
||||
case 0:
|
||||
dst.Pix[di+0] = 0
|
||||
dst.Pix[di+1] = 0
|
||||
dst.Pix[di+2] = 0
|
||||
case 0xff:
|
||||
dst.Pix[di+0] = src.Pix[si+0]
|
||||
dst.Pix[di+1] = src.Pix[si+1]
|
||||
dst.Pix[di+2] = src.Pix[si+2]
|
||||
default:
|
||||
dst.Pix[di+0] = uint8(uint16(src.Pix[si+0]) * 0xff / uint16(a))
|
||||
dst.Pix[di+1] = uint8(uint16(src.Pix[si+1]) * 0xff / uint16(a))
|
||||
dst.Pix[di+2] = uint8(uint16(src.Pix[si+2]) * 0xff / uint16(a))
|
||||
}
|
||||
|
||||
di += 4
|
||||
si += 4
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
case *image.RGBA64:
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
di := dst.PixOffset(0, dstY)
|
||||
si := src.PixOffset(srcMinX, srcMinY+dstY)
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
|
||||
a := src.Pix[si+6]
|
||||
dst.Pix[di+3] = a
|
||||
switch a {
|
||||
case 0:
|
||||
dst.Pix[di+0] = 0
|
||||
dst.Pix[di+1] = 0
|
||||
dst.Pix[di+2] = 0
|
||||
case 0xff:
|
||||
dst.Pix[di+0] = src.Pix[si+0]
|
||||
dst.Pix[di+1] = src.Pix[si+2]
|
||||
dst.Pix[di+2] = src.Pix[si+4]
|
||||
default:
|
||||
dst.Pix[di+0] = uint8(uint16(src.Pix[si+0]) * 0xff / uint16(a))
|
||||
dst.Pix[di+1] = uint8(uint16(src.Pix[si+2]) * 0xff / uint16(a))
|
||||
dst.Pix[di+2] = uint8(uint16(src.Pix[si+4]) * 0xff / uint16(a))
|
||||
}
|
||||
|
||||
di += 4
|
||||
si += 8
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
case *image.Gray:
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
di := dst.PixOffset(0, dstY)
|
||||
si := src.PixOffset(srcMinX, srcMinY+dstY)
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
|
||||
c := src.Pix[si]
|
||||
dst.Pix[di+0] = c
|
||||
dst.Pix[di+1] = c
|
||||
dst.Pix[di+2] = c
|
||||
dst.Pix[di+3] = 0xff
|
||||
|
||||
di += 4
|
||||
si += 1
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
case *image.Gray16:
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
di := dst.PixOffset(0, dstY)
|
||||
si := src.PixOffset(srcMinX, srcMinY+dstY)
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
|
||||
c := src.Pix[si]
|
||||
dst.Pix[di+0] = c
|
||||
dst.Pix[di+1] = c
|
||||
dst.Pix[di+2] = c
|
||||
dst.Pix[di+3] = 0xff
|
||||
|
||||
di += 4
|
||||
si += 2
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
case *image.YCbCr:
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
di := dst.PixOffset(0, dstY)
|
||||
switch src.SubsampleRatio {
|
||||
case image.YCbCrSubsampleRatio422:
|
||||
siy0 := dstY * src.YStride
|
||||
sic0 := dstY * src.CStride
|
||||
for dstX := 0; dstX < dstW; dstX = dstX + 1 {
|
||||
siy := siy0 + dstX
|
||||
sic := sic0 + ((srcMinX+dstX)/2 - srcMinX/2)
|
||||
r, g, b := color.YCbCrToRGB(src.Y[siy], src.Cb[sic], src.Cr[sic])
|
||||
dst.Pix[di+0] = r
|
||||
dst.Pix[di+1] = g
|
||||
dst.Pix[di+2] = b
|
||||
dst.Pix[di+3] = 0xff
|
||||
di += 4
|
||||
}
|
||||
case image.YCbCrSubsampleRatio420:
|
||||
siy0 := dstY * src.YStride
|
||||
sic0 := ((srcMinY+dstY)/2 - srcMinY/2) * src.CStride
|
||||
for dstX := 0; dstX < dstW; dstX = dstX + 1 {
|
||||
siy := siy0 + dstX
|
||||
sic := sic0 + ((srcMinX+dstX)/2 - srcMinX/2)
|
||||
r, g, b := color.YCbCrToRGB(src.Y[siy], src.Cb[sic], src.Cr[sic])
|
||||
dst.Pix[di+0] = r
|
||||
dst.Pix[di+1] = g
|
||||
dst.Pix[di+2] = b
|
||||
dst.Pix[di+3] = 0xff
|
||||
di += 4
|
||||
}
|
||||
case image.YCbCrSubsampleRatio440:
|
||||
siy0 := dstY * src.YStride
|
||||
sic0 := ((srcMinY+dstY)/2 - srcMinY/2) * src.CStride
|
||||
for dstX := 0; dstX < dstW; dstX = dstX + 1 {
|
||||
siy := siy0 + dstX
|
||||
sic := sic0 + dstX
|
||||
r, g, b := color.YCbCrToRGB(src.Y[siy], src.Cb[sic], src.Cr[sic])
|
||||
dst.Pix[di+0] = r
|
||||
dst.Pix[di+1] = g
|
||||
dst.Pix[di+2] = b
|
||||
dst.Pix[di+3] = 0xff
|
||||
di += 4
|
||||
}
|
||||
default:
|
||||
siy0 := dstY * src.YStride
|
||||
sic0 := dstY * src.CStride
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
siy := siy0 + dstX
|
||||
sic := sic0 + dstX
|
||||
r, g, b := color.YCbCrToRGB(src.Y[siy], src.Cb[sic], src.Cr[sic])
|
||||
dst.Pix[di+0] = r
|
||||
dst.Pix[di+1] = g
|
||||
dst.Pix[di+2] = b
|
||||
dst.Pix[di+3] = 0xff
|
||||
di += 4
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
case *image.Paletted:
|
||||
plen := len(src.Palette)
|
||||
pnew := make([]color.NRGBA, plen)
|
||||
for i := 0; i < plen; i++ {
|
||||
pnew[i] = color.NRGBAModel.Convert(src.Palette[i]).(color.NRGBA)
|
||||
}
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
di := dst.PixOffset(0, dstY)
|
||||
si := src.PixOffset(srcMinX, srcMinY+dstY)
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
|
||||
c := pnew[src.Pix[si]]
|
||||
dst.Pix[di+0] = c.R
|
||||
dst.Pix[di+1] = c.G
|
||||
dst.Pix[di+2] = c.B
|
||||
dst.Pix[di+3] = c.A
|
||||
|
||||
di += 4
|
||||
si += 1
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
default:
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
di := dst.PixOffset(0, dstY)
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
|
||||
c := color.NRGBAModel.Convert(img.At(srcMinX+dstX, srcMinY+dstY)).(color.NRGBA)
|
||||
dst.Pix[di+0] = c.R
|
||||
dst.Pix[di+1] = c.G
|
||||
dst.Pix[di+2] = c.B
|
||||
dst.Pix[di+3] = c.A
|
||||
|
||||
di += 4
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// This function used internally to convert any image type to NRGBA if needed.
|
||||
func toNRGBA(img image.Image) *image.NRGBA {
|
||||
srcBounds := img.Bounds()
|
||||
if srcBounds.Min.X == 0 && srcBounds.Min.Y == 0 {
|
||||
if src0, ok := img.(*image.NRGBA); ok {
|
||||
return src0
|
||||
}
|
||||
}
|
||||
return Clone(img)
|
||||
}
|
|
@ -0,0 +1,361 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/color"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func compareNRGBA(img1, img2 *image.NRGBA, delta int) bool {
|
||||
if !img1.Rect.Eq(img2.Rect) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(img1.Pix) != len(img2.Pix) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < len(img1.Pix); i++ {
|
||||
if absint(int(img1.Pix[i])-int(img2.Pix[i])) > delta {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
imgWithAlpha := image.NewNRGBA(image.Rect(0, 0, 3, 3))
|
||||
imgWithAlpha.Pix = []uint8{
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
|
||||
244, 245, 246, 247, 248, 249, 250, 252, 252, 253, 254, 255,
|
||||
}
|
||||
|
||||
imgWithoutAlpha := image.NewNRGBA(image.Rect(0, 0, 3, 3))
|
||||
imgWithoutAlpha.Pix = []uint8{
|
||||
0, 1, 2, 255, 4, 5, 6, 255, 8, 9, 10, 255,
|
||||
127, 128, 129, 255, 131, 132, 133, 255, 135, 136, 137, 255,
|
||||
244, 245, 246, 255, 248, 249, 250, 255, 252, 253, 254, 255,
|
||||
}
|
||||
|
||||
for _, format := range []Format{JPEG, PNG, GIF, BMP, TIFF} {
|
||||
img := imgWithoutAlpha
|
||||
if format == PNG {
|
||||
img = imgWithAlpha
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
err := Encode(buf, img, format)
|
||||
if err != nil {
|
||||
t.Errorf("fail encoding format %s", format)
|
||||
continue
|
||||
}
|
||||
|
||||
img2, err := Decode(buf)
|
||||
if err != nil {
|
||||
t.Errorf("fail decoding format %s", format)
|
||||
continue
|
||||
}
|
||||
img2cloned := Clone(img2)
|
||||
|
||||
delta := 0
|
||||
if format == JPEG {
|
||||
delta = 3
|
||||
} else if format == GIF {
|
||||
delta = 16
|
||||
}
|
||||
|
||||
if !compareNRGBA(img, img2cloned, delta) {
|
||||
t.Errorf("test [DecodeEncode %s] failed: %#v %#v", format, img, img2cloned)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
err := Encode(buf, imgWithAlpha, Format(100))
|
||||
if err != ErrUnsupportedFormat {
|
||||
t.Errorf("expected ErrUnsupportedFormat")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
w, h int
|
||||
c color.Color
|
||||
dstBounds image.Rectangle
|
||||
dstPix []uint8
|
||||
}{
|
||||
{
|
||||
"New 1x1 black",
|
||||
1, 1,
|
||||
color.NRGBA{0, 0, 0, 0},
|
||||
image.Rect(0, 0, 1, 1),
|
||||
[]uint8{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
"New 1x2 red",
|
||||
1, 2,
|
||||
color.NRGBA{255, 0, 0, 255},
|
||||
image.Rect(0, 0, 1, 2),
|
||||
[]uint8{0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff},
|
||||
},
|
||||
{
|
||||
"New 2x1 white",
|
||||
2, 1,
|
||||
color.NRGBA{255, 255, 255, 255},
|
||||
image.Rect(0, 0, 2, 1),
|
||||
[]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
},
|
||||
}
|
||||
|
||||
for _, d := range td {
|
||||
got := New(d.w, d.h, d.c)
|
||||
want := image.NewNRGBA(d.dstBounds)
|
||||
want.Pix = d.dstPix
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Clone NRGBA",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 0, 1),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 2),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone NRGBA64",
|
||||
&image.NRGBA64{
|
||||
Rect: image.Rect(-1, -1, 0, 1),
|
||||
Stride: 1 * 8,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x11, 0x11, 0x22, 0x22, 0x33, 0x33,
|
||||
0xcc, 0xcc, 0xdd, 0xdd, 0xee, 0xee, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 2),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone RGBA",
|
||||
&image.RGBA{
|
||||
Rect: image.Rect(-1, -1, 0, 1),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 2),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x00, 0x55, 0xaa, 0x33, 0xcc, 0xdd, 0xee, 0xff},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone RGBA64",
|
||||
&image.RGBA64{
|
||||
Rect: image.Rect(-1, -1, 0, 1),
|
||||
Stride: 1 * 8,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x11, 0x11, 0x22, 0x22, 0x33, 0x33,
|
||||
0xcc, 0xcc, 0xdd, 0xdd, 0xee, 0xee, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 2),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x00, 0x55, 0xaa, 0x33, 0xcc, 0xdd, 0xee, 0xff},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone Gray",
|
||||
&image.Gray{
|
||||
Rect: image.Rect(-1, -1, 0, 1),
|
||||
Stride: 1 * 1,
|
||||
Pix: []uint8{0x11, 0xee},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 2),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x11, 0x11, 0x11, 0xff, 0xee, 0xee, 0xee, 0xff},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone Gray16",
|
||||
&image.Gray16{
|
||||
Rect: image.Rect(-1, -1, 0, 1),
|
||||
Stride: 1 * 2,
|
||||
Pix: []uint8{0x11, 0x11, 0xee, 0xee},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 2),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x11, 0x11, 0x11, 0xff, 0xee, 0xee, 0xee, 0xff},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone Alpha",
|
||||
&image.Alpha{
|
||||
Rect: image.Rect(-1, -1, 0, 1),
|
||||
Stride: 1 * 1,
|
||||
Pix: []uint8{0x11, 0xee},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 2),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0xff, 0xff, 0xff, 0x11, 0xff, 0xff, 0xff, 0xee},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone YCbCr",
|
||||
&image.YCbCr{
|
||||
Rect: image.Rect(-1, -1, 5, 0),
|
||||
SubsampleRatio: image.YCbCrSubsampleRatio444,
|
||||
YStride: 6,
|
||||
CStride: 6,
|
||||
Y: []uint8{0x00, 0xff, 0x7f, 0x26, 0x4b, 0x0e},
|
||||
Cb: []uint8{0x80, 0x80, 0x80, 0x6b, 0x56, 0xc0},
|
||||
Cr: []uint8{0x80, 0x80, 0x80, 0xc0, 0x4b, 0x76},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 6, 1),
|
||||
Stride: 6 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0x7f, 0x7f, 0x7f, 0xff,
|
||||
0x7f, 0x00, 0x00, 0xff,
|
||||
0x00, 0x7f, 0x00, 0xff,
|
||||
0x00, 0x00, 0x7f, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone YCbCr 444",
|
||||
&image.YCbCr{
|
||||
Y: []uint8{0x4c, 0x69, 0x1d, 0xb1, 0x96, 0xe2, 0x26, 0x34, 0xe, 0x59, 0x4b, 0x71, 0x0, 0x4c, 0x99, 0xff},
|
||||
Cb: []uint8{0x55, 0xd4, 0xff, 0x8e, 0x2c, 0x01, 0x6b, 0xaa, 0xc0, 0x95, 0x56, 0x40, 0x80, 0x80, 0x80, 0x80},
|
||||
Cr: []uint8{0xff, 0xeb, 0x6b, 0x36, 0x15, 0x95, 0xc0, 0xb5, 0x76, 0x41, 0x4b, 0x8c, 0x80, 0x80, 0x80, 0x80},
|
||||
YStride: 4,
|
||||
CStride: 4,
|
||||
SubsampleRatio: image.YCbCrSubsampleRatio444,
|
||||
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Pix: []uint8{0xff, 0x0, 0x0, 0xff, 0xff, 0x0, 0xff, 0xff, 0x0, 0x0, 0xff, 0xff, 0x49, 0xe1, 0xca, 0xff, 0x0, 0xff, 0x0, 0xff, 0xff, 0xff, 0x0, 0xff, 0x7f, 0x0, 0x0, 0xff, 0x7f, 0x0, 0x7f, 0xff, 0x0, 0x0, 0x7f, 0xff, 0x0, 0x7f, 0x7f, 0xff, 0x0, 0x7f, 0x0, 0xff, 0x82, 0x7f, 0x0, 0xff, 0x0, 0x0, 0x0, 0xff, 0x4c, 0x4c, 0x4c, 0xff, 0x99, 0x99, 0x99, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
Stride: 16,
|
||||
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone YCbCr 440",
|
||||
&image.YCbCr{
|
||||
Y: []uint8{0x4c, 0x69, 0x1d, 0xb1, 0x96, 0xe2, 0x26, 0x34, 0xe, 0x59, 0x4b, 0x71, 0x0, 0x4c, 0x99, 0xff},
|
||||
Cb: []uint8{0x2c, 0x01, 0x6b, 0xaa, 0x80, 0x80, 0x80, 0x80},
|
||||
Cr: []uint8{0x15, 0x95, 0xc0, 0xb5, 0x80, 0x80, 0x80, 0x80},
|
||||
YStride: 4,
|
||||
CStride: 4,
|
||||
SubsampleRatio: image.YCbCrSubsampleRatio440,
|
||||
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Pix: []uint8{0x0, 0xb5, 0x0, 0xff, 0x86, 0x86, 0x0, 0xff, 0x77, 0x0, 0x0, 0xff, 0xfb, 0x7d, 0xfb, 0xff, 0x0, 0xff, 0x1, 0xff, 0xff, 0xff, 0x1, 0xff, 0x80, 0x0, 0x1, 0xff, 0x7e, 0x0, 0x7e, 0xff, 0xe, 0xe, 0xe, 0xff, 0x59, 0x59, 0x59, 0xff, 0x4b, 0x4b, 0x4b, 0xff, 0x71, 0x71, 0x71, 0xff, 0x0, 0x0, 0x0, 0xff, 0x4c, 0x4c, 0x4c, 0xff, 0x99, 0x99, 0x99, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
Stride: 16,
|
||||
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone YCbCr 422",
|
||||
&image.YCbCr{
|
||||
Y: []uint8{0x4c, 0x69, 0x1d, 0xb1, 0x96, 0xe2, 0x26, 0x34, 0xe, 0x59, 0x4b, 0x71, 0x0, 0x4c, 0x99, 0xff},
|
||||
Cb: []uint8{0xd4, 0x8e, 0x01, 0xaa, 0x95, 0x40, 0x80, 0x80},
|
||||
Cr: []uint8{0xeb, 0x36, 0x95, 0xb5, 0x41, 0x8c, 0x80, 0x80},
|
||||
YStride: 4,
|
||||
CStride: 2,
|
||||
SubsampleRatio: image.YCbCrSubsampleRatio422,
|
||||
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Pix: []uint8{0xe2, 0x0, 0xe1, 0xff, 0xff, 0x0, 0xfe, 0xff, 0x0, 0x4d, 0x36, 0xff, 0x49, 0xe1, 0xca, 0xff, 0xb3, 0xb3, 0x0, 0xff, 0xff, 0xff, 0x1, 0xff, 0x70, 0x0, 0x70, 0xff, 0x7e, 0x0, 0x7e, 0xff, 0x0, 0x34, 0x33, 0xff, 0x1, 0x7f, 0x7e, 0xff, 0x5c, 0x58, 0x0, 0xff, 0x82, 0x7e, 0x0, 0xff, 0x0, 0x0, 0x0, 0xff, 0x4c, 0x4c, 0x4c, 0xff, 0x99, 0x99, 0x99, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
Stride: 16,
|
||||
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone YCbCr 420",
|
||||
&image.YCbCr{
|
||||
Y: []uint8{0x4c, 0x69, 0x1d, 0xb1, 0x96, 0xe2, 0x26, 0x34, 0xe, 0x59, 0x4b, 0x71, 0x0, 0x4c, 0x99, 0xff},
|
||||
Cb: []uint8{0x01, 0xaa, 0x80, 0x80},
|
||||
Cr: []uint8{0x95, 0xb5, 0x80, 0x80},
|
||||
YStride: 4, CStride: 2,
|
||||
SubsampleRatio: image.YCbCrSubsampleRatio420,
|
||||
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Pix: []uint8{0x69, 0x69, 0x0, 0xff, 0x86, 0x86, 0x0, 0xff, 0x67, 0x0, 0x67, 0xff, 0xfb, 0x7d, 0xfb, 0xff, 0xb3, 0xb3, 0x0, 0xff, 0xff, 0xff, 0x1, 0xff, 0x70, 0x0, 0x70, 0xff, 0x7e, 0x0, 0x7e, 0xff, 0xe, 0xe, 0xe, 0xff, 0x59, 0x59, 0x59, 0xff, 0x4b, 0x4b, 0x4b, 0xff, 0x71, 0x71, 0x71, 0xff, 0x0, 0x0, 0x0, 0xff, 0x4c, 0x4c, 0x4c, 0xff, 0x99, 0x99, 0x99, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
Stride: 16,
|
||||
Rect: image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: 4, Y: 4}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Clone Paletted",
|
||||
&image.Paletted{
|
||||
Rect: image.Rect(-1, -1, 5, 0),
|
||||
Stride: 6 * 1,
|
||||
Palette: color.Palette{
|
||||
color.NRGBA{R: 0x00, G: 0x00, B: 0x00, A: 0xff},
|
||||
color.NRGBA{R: 0xff, G: 0xff, B: 0xff, A: 0xff},
|
||||
color.NRGBA{R: 0x7f, G: 0x7f, B: 0x7f, A: 0xff},
|
||||
color.NRGBA{R: 0x7f, G: 0x00, B: 0x00, A: 0xff},
|
||||
color.NRGBA{R: 0x00, G: 0x7f, B: 0x00, A: 0xff},
|
||||
color.NRGBA{R: 0x00, G: 0x00, B: 0x7f, A: 0xff},
|
||||
},
|
||||
Pix: []uint8{0x0, 0x1, 0x2, 0x3, 0x4, 0x5},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 6, 1),
|
||||
Stride: 6 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0x7f, 0x7f, 0x7f, 0xff,
|
||||
0x7f, 0x00, 0x00, 0xff,
|
||||
0x00, 0x7f, 0x00, 0xff,
|
||||
0x00, 0x00, 0x7f, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, d := range td {
|
||||
got := Clone(d.src)
|
||||
want := d.want
|
||||
|
||||
delta := 0
|
||||
if _, ok := d.src.(*image.YCbCr); ok {
|
||||
delta = 1
|
||||
}
|
||||
|
||||
if !compareNRGBA(got, want, delta) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,564 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
"math"
|
||||
)
|
||||
|
||||
type iwpair struct {
|
||||
i int
|
||||
w int32
|
||||
}
|
||||
|
||||
type pweights struct {
|
||||
iwpairs []iwpair
|
||||
wsum int32
|
||||
}
|
||||
|
||||
func precomputeWeights(dstSize, srcSize int, filter ResampleFilter) []pweights {
|
||||
du := float64(srcSize) / float64(dstSize)
|
||||
scale := du
|
||||
if scale < 1.0 {
|
||||
scale = 1.0
|
||||
}
|
||||
ru := math.Ceil(scale * filter.Support)
|
||||
|
||||
out := make([]pweights, dstSize)
|
||||
|
||||
for v := 0; v < dstSize; v++ {
|
||||
fu := (float64(v)+0.5)*du - 0.5
|
||||
|
||||
startu := int(math.Ceil(fu - ru))
|
||||
if startu < 0 {
|
||||
startu = 0
|
||||
}
|
||||
endu := int(math.Floor(fu + ru))
|
||||
if endu > srcSize-1 {
|
||||
endu = srcSize - 1
|
||||
}
|
||||
|
||||
wsum := int32(0)
|
||||
for u := startu; u <= endu; u++ {
|
||||
w := int32(0xff * filter.Kernel((float64(u)-fu)/scale))
|
||||
if w != 0 {
|
||||
wsum += w
|
||||
out[v].iwpairs = append(out[v].iwpairs, iwpair{u, w})
|
||||
}
|
||||
}
|
||||
out[v].wsum = wsum
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Resize resizes the image to the specified width and height using the specified resampling
|
||||
// filter and returns the transformed image. If one of width or height is 0, the image aspect
|
||||
// ratio is preserved.
|
||||
//
|
||||
// Supported resample filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali,
|
||||
// CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// dstImage := imaging.Resize(srcImage, 800, 600, imaging.Lanczos)
|
||||
//
|
||||
func Resize(img image.Image, width, height int, filter ResampleFilter) *image.NRGBA {
|
||||
dstW, dstH := width, height
|
||||
|
||||
if dstW < 0 || dstH < 0 {
|
||||
return &image.NRGBA{}
|
||||
}
|
||||
if dstW == 0 && dstH == 0 {
|
||||
return &image.NRGBA{}
|
||||
}
|
||||
|
||||
src := toNRGBA(img)
|
||||
|
||||
srcW := src.Bounds().Max.X
|
||||
srcH := src.Bounds().Max.Y
|
||||
|
||||
if srcW <= 0 || srcH <= 0 {
|
||||
return &image.NRGBA{}
|
||||
}
|
||||
|
||||
// if new width or height is 0 then preserve aspect ratio, minimum 1px
|
||||
if dstW == 0 {
|
||||
tmpW := float64(dstH) * float64(srcW) / float64(srcH)
|
||||
dstW = int(math.Max(1.0, math.Floor(tmpW+0.5)))
|
||||
}
|
||||
if dstH == 0 {
|
||||
tmpH := float64(dstW) * float64(srcH) / float64(srcW)
|
||||
dstH = int(math.Max(1.0, math.Floor(tmpH+0.5)))
|
||||
}
|
||||
|
||||
var dst *image.NRGBA
|
||||
|
||||
if filter.Support <= 0.0 {
|
||||
// nearest-neighbor special case
|
||||
dst = resizeNearest(src, dstW, dstH)
|
||||
|
||||
} else {
|
||||
// two-pass resize
|
||||
if srcW != dstW {
|
||||
dst = resizeHorizontal(src, dstW, filter)
|
||||
} else {
|
||||
dst = src
|
||||
}
|
||||
|
||||
if srcH != dstH {
|
||||
dst = resizeVertical(dst, dstH, filter)
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func resizeHorizontal(src *image.NRGBA, width int, filter ResampleFilter) *image.NRGBA {
|
||||
srcBounds := src.Bounds()
|
||||
srcW := srcBounds.Max.X
|
||||
srcH := srcBounds.Max.Y
|
||||
|
||||
dstW := width
|
||||
dstH := srcH
|
||||
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
weights := precomputeWeights(dstW, srcW, filter)
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
var c [4]int32
|
||||
for _, iw := range weights[dstX].iwpairs {
|
||||
i := dstY*src.Stride + iw.i*4
|
||||
c[0] += int32(src.Pix[i+0]) * iw.w
|
||||
c[1] += int32(src.Pix[i+1]) * iw.w
|
||||
c[2] += int32(src.Pix[i+2]) * iw.w
|
||||
c[3] += int32(src.Pix[i+3]) * iw.w
|
||||
}
|
||||
j := dstY*dst.Stride + dstX*4
|
||||
sum := weights[dstX].wsum
|
||||
dst.Pix[j+0] = clampint32(int32(float32(c[0])/float32(sum) + 0.5))
|
||||
dst.Pix[j+1] = clampint32(int32(float32(c[1])/float32(sum) + 0.5))
|
||||
dst.Pix[j+2] = clampint32(int32(float32(c[2])/float32(sum) + 0.5))
|
||||
dst.Pix[j+3] = clampint32(int32(float32(c[3])/float32(sum) + 0.5))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func resizeVertical(src *image.NRGBA, height int, filter ResampleFilter) *image.NRGBA {
|
||||
srcBounds := src.Bounds()
|
||||
srcW := srcBounds.Max.X
|
||||
srcH := srcBounds.Max.Y
|
||||
|
||||
dstW := srcW
|
||||
dstH := height
|
||||
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
weights := precomputeWeights(dstH, srcH, filter)
|
||||
|
||||
parallel(dstW, func(partStart, partEnd int) {
|
||||
|
||||
for dstX := partStart; dstX < partEnd; dstX++ {
|
||||
for dstY := 0; dstY < dstH; dstY++ {
|
||||
var c [4]int32
|
||||
for _, iw := range weights[dstY].iwpairs {
|
||||
i := iw.i*src.Stride + dstX*4
|
||||
c[0] += int32(src.Pix[i+0]) * iw.w
|
||||
c[1] += int32(src.Pix[i+1]) * iw.w
|
||||
c[2] += int32(src.Pix[i+2]) * iw.w
|
||||
c[3] += int32(src.Pix[i+3]) * iw.w
|
||||
}
|
||||
j := dstY*dst.Stride + dstX*4
|
||||
sum := weights[dstY].wsum
|
||||
dst.Pix[j+0] = clampint32(int32(float32(c[0])/float32(sum) + 0.5))
|
||||
dst.Pix[j+1] = clampint32(int32(float32(c[1])/float32(sum) + 0.5))
|
||||
dst.Pix[j+2] = clampint32(int32(float32(c[2])/float32(sum) + 0.5))
|
||||
dst.Pix[j+3] = clampint32(int32(float32(c[3])/float32(sum) + 0.5))
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// fast nearest-neighbor resize, no filtering
|
||||
func resizeNearest(src *image.NRGBA, width, height int) *image.NRGBA {
|
||||
dstW, dstH := width, height
|
||||
|
||||
srcBounds := src.Bounds()
|
||||
srcW := srcBounds.Max.X
|
||||
srcH := srcBounds.Max.Y
|
||||
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
dx := float64(srcW) / float64(dstW)
|
||||
dy := float64(srcH) / float64(dstH)
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
fy := (float64(dstY)+0.5)*dy - 0.5
|
||||
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
fx := (float64(dstX)+0.5)*dx - 0.5
|
||||
|
||||
srcX := int(math.Min(math.Max(math.Floor(fx+0.5), 0.0), float64(srcW)))
|
||||
srcY := int(math.Min(math.Max(math.Floor(fy+0.5), 0.0), float64(srcH)))
|
||||
|
||||
srcOff := srcY*src.Stride + srcX*4
|
||||
dstOff := dstY*dst.Stride + dstX*4
|
||||
|
||||
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// Fit scales down the image using the specified resample filter to fit the specified
|
||||
// maximum width and height and returns the transformed image.
|
||||
//
|
||||
// Supported resample filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali,
|
||||
// CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// dstImage := imaging.Fit(srcImage, 800, 600, imaging.Lanczos)
|
||||
//
|
||||
func Fit(img image.Image, width, height int, filter ResampleFilter) *image.NRGBA {
|
||||
maxW, maxH := width, height
|
||||
|
||||
if maxW <= 0 || maxH <= 0 {
|
||||
return &image.NRGBA{}
|
||||
}
|
||||
|
||||
srcBounds := img.Bounds()
|
||||
srcW := srcBounds.Dx()
|
||||
srcH := srcBounds.Dy()
|
||||
|
||||
if srcW <= 0 || srcH <= 0 {
|
||||
return &image.NRGBA{}
|
||||
}
|
||||
|
||||
if srcW <= maxW && srcH <= maxH {
|
||||
return Clone(img)
|
||||
}
|
||||
|
||||
srcAspectRatio := float64(srcW) / float64(srcH)
|
||||
maxAspectRatio := float64(maxW) / float64(maxH)
|
||||
|
||||
var newW, newH int
|
||||
if srcAspectRatio > maxAspectRatio {
|
||||
newW = maxW
|
||||
newH = int(float64(newW) / srcAspectRatio)
|
||||
} else {
|
||||
newH = maxH
|
||||
newW = int(float64(newH) * srcAspectRatio)
|
||||
}
|
||||
|
||||
return Resize(img, newW, newH, filter)
|
||||
}
|
||||
|
||||
// Thumbnail scales the image up or down using the specified resample filter, crops it
|
||||
// to the specified width and hight and returns the transformed image.
|
||||
//
|
||||
// Supported resample filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali,
|
||||
// CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// dstImage := imaging.Thumbnail(srcImage, 100, 100, imaging.Lanczos)
|
||||
//
|
||||
func Thumbnail(img image.Image, width, height int, filter ResampleFilter) *image.NRGBA {
|
||||
thumbW, thumbH := width, height
|
||||
|
||||
if thumbW <= 0 || thumbH <= 0 {
|
||||
return &image.NRGBA{}
|
||||
}
|
||||
|
||||
srcBounds := img.Bounds()
|
||||
srcW := srcBounds.Dx()
|
||||
srcH := srcBounds.Dy()
|
||||
|
||||
if srcW <= 0 || srcH <= 0 {
|
||||
return &image.NRGBA{}
|
||||
}
|
||||
|
||||
srcAspectRatio := float64(srcW) / float64(srcH)
|
||||
thumbAspectRatio := float64(thumbW) / float64(thumbH)
|
||||
|
||||
var tmp image.Image
|
||||
if srcAspectRatio > thumbAspectRatio {
|
||||
tmp = Resize(img, 0, thumbH, filter)
|
||||
} else {
|
||||
tmp = Resize(img, thumbW, 0, filter)
|
||||
}
|
||||
|
||||
return CropCenter(tmp, thumbW, thumbH)
|
||||
}
|
||||
|
||||
// Resample filter struct. It can be used to make custom filters.
|
||||
//
|
||||
// Supported resample filters: NearestNeighbor, Box, Linear, Hermite, MitchellNetravali,
|
||||
// CatmullRom, BSpline, Gaussian, Lanczos, Hann, Hamming, Blackman, Bartlett, Welch, Cosine.
|
||||
//
|
||||
// General filter recommendations:
|
||||
//
|
||||
// - Lanczos
|
||||
// Probably the best resampling filter for photographic images yielding sharp results,
|
||||
// but it's slower than cubic filters (see below).
|
||||
//
|
||||
// - CatmullRom
|
||||
// A sharp cubic filter. It's a good filter for both upscaling and downscaling if sharp results are needed.
|
||||
//
|
||||
// - MitchellNetravali
|
||||
// A high quality cubic filter that produces smoother results with less ringing than CatmullRom.
|
||||
//
|
||||
// - BSpline
|
||||
// A good filter if a very smooth output is needed.
|
||||
//
|
||||
// - Linear
|
||||
// Bilinear interpolation filter, produces reasonably good, smooth output. It's faster than cubic filters.
|
||||
//
|
||||
// - Box
|
||||
// Simple and fast resampling filter appropriate for downscaling.
|
||||
// When upscaling it's similar to NearestNeighbor.
|
||||
//
|
||||
// - NearestNeighbor
|
||||
// Fastest resample filter, no antialiasing at all. Rarely used.
|
||||
//
|
||||
type ResampleFilter struct {
|
||||
Support float64
|
||||
Kernel func(float64) float64
|
||||
}
|
||||
|
||||
// Nearest-neighbor filter, no anti-aliasing.
|
||||
var NearestNeighbor ResampleFilter
|
||||
|
||||
// Box filter (averaging pixels).
|
||||
var Box ResampleFilter
|
||||
|
||||
// Linear filter.
|
||||
var Linear ResampleFilter
|
||||
|
||||
// Hermite cubic spline filter (BC-spline; B=0; C=0).
|
||||
var Hermite ResampleFilter
|
||||
|
||||
// Mitchell-Netravali cubic filter (BC-spline; B=1/3; C=1/3).
|
||||
var MitchellNetravali ResampleFilter
|
||||
|
||||
// Catmull-Rom - sharp cubic filter (BC-spline; B=0; C=0.5).
|
||||
var CatmullRom ResampleFilter
|
||||
|
||||
// Cubic B-spline - smooth cubic filter (BC-spline; B=1; C=0).
|
||||
var BSpline ResampleFilter
|
||||
|
||||
// Gaussian Blurring Filter.
|
||||
var Gaussian ResampleFilter
|
||||
|
||||
// Bartlett-windowed sinc filter (3 lobes).
|
||||
var Bartlett ResampleFilter
|
||||
|
||||
// Lanczos filter (3 lobes).
|
||||
var Lanczos ResampleFilter
|
||||
|
||||
// Hann-windowed sinc filter (3 lobes).
|
||||
var Hann ResampleFilter
|
||||
|
||||
// Hamming-windowed sinc filter (3 lobes).
|
||||
var Hamming ResampleFilter
|
||||
|
||||
// Blackman-windowed sinc filter (3 lobes).
|
||||
var Blackman ResampleFilter
|
||||
|
||||
// Welch-windowed sinc filter (parabolic window, 3 lobes).
|
||||
var Welch ResampleFilter
|
||||
|
||||
// Cosine-windowed sinc filter (3 lobes).
|
||||
var Cosine ResampleFilter
|
||||
|
||||
func bcspline(x, b, c float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 1.0 {
|
||||
return ((12-9*b-6*c)*x*x*x + (-18+12*b+6*c)*x*x + (6 - 2*b)) / 6
|
||||
}
|
||||
if x < 2.0 {
|
||||
return ((-b-6*c)*x*x*x + (6*b+30*c)*x*x + (-12*b-48*c)*x + (8*b + 24*c)) / 6
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func sinc(x float64) float64 {
|
||||
if x == 0 {
|
||||
return 1
|
||||
}
|
||||
return math.Sin(math.Pi*x) / (math.Pi * x)
|
||||
}
|
||||
|
||||
func init() {
|
||||
NearestNeighbor = ResampleFilter{
|
||||
Support: 0.0, // special case - not applying the filter
|
||||
}
|
||||
|
||||
Box = ResampleFilter{
|
||||
Support: 0.5,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x <= 0.5 {
|
||||
return 1.0
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Linear = ResampleFilter{
|
||||
Support: 1.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 1.0 {
|
||||
return 1.0 - x
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Hermite = ResampleFilter{
|
||||
Support: 1.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 1.0 {
|
||||
return bcspline(x, 0.0, 0.0)
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
MitchellNetravali = ResampleFilter{
|
||||
Support: 2.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 2.0 {
|
||||
return bcspline(x, 1.0/3.0, 1.0/3.0)
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
CatmullRom = ResampleFilter{
|
||||
Support: 2.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 2.0 {
|
||||
return bcspline(x, 0.0, 0.5)
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
BSpline = ResampleFilter{
|
||||
Support: 2.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 2.0 {
|
||||
return bcspline(x, 1.0, 0.0)
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Gaussian = ResampleFilter{
|
||||
Support: 2.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 2.0 {
|
||||
return math.Exp(-2 * x * x)
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Bartlett = ResampleFilter{
|
||||
Support: 3.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 3.0 {
|
||||
return sinc(x) * (3.0 - x) / 3.0
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Lanczos = ResampleFilter{
|
||||
Support: 3.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 3.0 {
|
||||
return sinc(x) * sinc(x/3.0)
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Hann = ResampleFilter{
|
||||
Support: 3.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 3.0 {
|
||||
return sinc(x) * (0.5 + 0.5*math.Cos(math.Pi*x/3.0))
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Hamming = ResampleFilter{
|
||||
Support: 3.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 3.0 {
|
||||
return sinc(x) * (0.54 + 0.46*math.Cos(math.Pi*x/3.0))
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Blackman = ResampleFilter{
|
||||
Support: 3.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 3.0 {
|
||||
return sinc(x) * (0.42 - 0.5*math.Cos(math.Pi*x/3.0+math.Pi) + 0.08*math.Cos(2.0*math.Pi*x/3.0))
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Welch = ResampleFilter{
|
||||
Support: 3.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 3.0 {
|
||||
return sinc(x) * (1.0 - (x * x / 9.0))
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
|
||||
Cosine = ResampleFilter{
|
||||
Support: 3.0,
|
||||
Kernel: func(x float64) float64 {
|
||||
x = math.Abs(x)
|
||||
if x < 3.0 {
|
||||
return sinc(x) * math.Cos((math.Pi/2.0)*(x/3.0))
|
||||
}
|
||||
return 0
|
||||
},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,281 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResize(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
w, h int
|
||||
f ResampleFilter
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Resize 2x2 1x1 box",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
1, 1,
|
||||
Box,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 1),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Resize 2x2 2x2 box",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
2, 2,
|
||||
Box,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Resize 3x1 1x1 nearest",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 2, 0),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
1, 1,
|
||||
NearestNeighbor,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 1),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x00, 0xff, 0x00, 0xff},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Resize 2x2 0x4 box",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
0, 4,
|
||||
Box,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 4, 4),
|
||||
Stride: 4 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Resize 2x2 4x0 linear",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
4, 0,
|
||||
Linear,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 4, 4),
|
||||
Stride: 4 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0xbf, 0x00, 0x00, 0xbf, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0x40, 0x00, 0x40, 0x30, 0x30, 0x10, 0x70, 0x8f, 0x10, 0x30, 0xcf, 0xbf, 0x00, 0x40, 0xff,
|
||||
0x00, 0xbf, 0x00, 0xbf, 0x10, 0x8f, 0x30, 0xcf, 0x30, 0x30, 0x8f, 0xef, 0x40, 0x00, 0xbf, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0xbf, 0x40, 0xff, 0x00, 0x40, 0xbf, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Resize(d.src, d.w, d.h, d.f)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 1) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFit(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
w, h int
|
||||
f ResampleFilter
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Fit 2x2 1x10 box",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
1, 10,
|
||||
Box,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 1),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Fit 2x2 10x1 box",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
10, 1,
|
||||
Box,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 1),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Fit 2x2 10x10 box",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
10, 10,
|
||||
Box,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Fit(d.src, d.w, d.h, d.f)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestThumbnail(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
w, h int
|
||||
f ResampleFilter
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Thumbnail 6x2 1x1 box",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 5, 1),
|
||||
Stride: 6 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
1, 1,
|
||||
Box,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 1),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Thumbnail 2x6 1x1 box",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 5),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff,
|
||||
0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
1, 1,
|
||||
Box,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 1, 1),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{0x40, 0x40, 0x40, 0xc0},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Thumbnail 1x3 2x2 box",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 0, 2),
|
||||
Stride: 1 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0xff, 0x00, 0x00, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
},
|
||||
2, 2,
|
||||
Box,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff,
|
||||
0xff, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Thumbnail(d.src, d.w, d.h, d.f)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,139 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
"math"
|
||||
)
|
||||
|
||||
// Crop cuts out a rectangular region with the specified bounds
|
||||
// from the image and returns the cropped image.
|
||||
func Crop(img image.Image, rect image.Rectangle) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
srcRect := rect.Sub(img.Bounds().Min)
|
||||
sub := src.SubImage(srcRect)
|
||||
return Clone(sub) // New image Bounds().Min point will be (0, 0)
|
||||
}
|
||||
|
||||
// CropCenter cuts out a rectangular region with the specified size
|
||||
// from the center of the image and returns the cropped image.
|
||||
func CropCenter(img image.Image, width, height int) *image.NRGBA {
|
||||
cropW, cropH := width, height
|
||||
|
||||
srcBounds := img.Bounds()
|
||||
srcW := srcBounds.Dx()
|
||||
srcH := srcBounds.Dy()
|
||||
srcMinX := srcBounds.Min.X
|
||||
srcMinY := srcBounds.Min.Y
|
||||
|
||||
centerX := srcMinX + srcW/2
|
||||
centerY := srcMinY + srcH/2
|
||||
|
||||
x0 := centerX - cropW/2
|
||||
y0 := centerY - cropH/2
|
||||
x1 := x0 + cropW
|
||||
y1 := y0 + cropH
|
||||
|
||||
return Crop(img, image.Rect(x0, y0, x1, y1))
|
||||
}
|
||||
|
||||
// Paste pastes the img image to the background image at the specified position and returns the combined image.
|
||||
func Paste(background, img image.Image, pos image.Point) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
dst := Clone(background) // cloned image bounds start at (0, 0)
|
||||
startPt := pos.Sub(background.Bounds().Min) // so we should translate start point
|
||||
endPt := startPt.Add(src.Bounds().Size())
|
||||
pasteBounds := image.Rectangle{startPt, endPt}
|
||||
|
||||
if dst.Bounds().Overlaps(pasteBounds) {
|
||||
intersectBounds := dst.Bounds().Intersect(pasteBounds)
|
||||
|
||||
rowSize := intersectBounds.Dx() * 4
|
||||
numRows := intersectBounds.Dy()
|
||||
|
||||
srcStartX := intersectBounds.Min.X - pasteBounds.Min.X
|
||||
srcStartY := intersectBounds.Min.Y - pasteBounds.Min.Y
|
||||
|
||||
i0 := dst.PixOffset(intersectBounds.Min.X, intersectBounds.Min.Y)
|
||||
j0 := src.PixOffset(srcStartX, srcStartY)
|
||||
|
||||
di := dst.Stride
|
||||
dj := src.Stride
|
||||
|
||||
for row := 0; row < numRows; row++ {
|
||||
copy(dst.Pix[i0:i0+rowSize], src.Pix[j0:j0+rowSize])
|
||||
i0 += di
|
||||
j0 += dj
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// PasteCenter pastes the img image to the center of the background image and returns the combined image.
|
||||
func PasteCenter(background, img image.Image) *image.NRGBA {
|
||||
bgBounds := background.Bounds()
|
||||
bgW := bgBounds.Dx()
|
||||
bgH := bgBounds.Dy()
|
||||
bgMinX := bgBounds.Min.X
|
||||
bgMinY := bgBounds.Min.Y
|
||||
|
||||
centerX := bgMinX + bgW/2
|
||||
centerY := bgMinY + bgH/2
|
||||
|
||||
x0 := centerX - img.Bounds().Dx()/2
|
||||
y0 := centerY - img.Bounds().Dy()/2
|
||||
|
||||
return Paste(background, img, image.Pt(x0, y0))
|
||||
}
|
||||
|
||||
// Overlay draws the img image over the background image at given position
|
||||
// and returns the combined image. Opacity parameter is the opacity of the img
|
||||
// image layer, used to compose the images, it must be from 0.0 to 1.0.
|
||||
//
|
||||
// Usage examples:
|
||||
//
|
||||
// // draw the sprite over the background at position (50, 50)
|
||||
// dstImage := imaging.Overlay(backgroundImage, spriteImage, image.Pt(50, 50), 1.0)
|
||||
//
|
||||
// // blend two opaque images of the same size
|
||||
// dstImage := imaging.Overlay(imageOne, imageTwo, image.Pt(0, 0), 0.5)
|
||||
//
|
||||
func Overlay(background, img image.Image, pos image.Point, opacity float64) *image.NRGBA {
|
||||
opacity = math.Min(math.Max(opacity, 0.0), 1.0) // check: 0.0 <= opacity <= 1.0
|
||||
|
||||
src := toNRGBA(img)
|
||||
dst := Clone(background) // cloned image bounds start at (0, 0)
|
||||
startPt := pos.Sub(background.Bounds().Min) // so we should translate start point
|
||||
endPt := startPt.Add(src.Bounds().Size())
|
||||
pasteBounds := image.Rectangle{startPt, endPt}
|
||||
|
||||
if dst.Bounds().Overlaps(pasteBounds) {
|
||||
intersectBounds := dst.Bounds().Intersect(pasteBounds)
|
||||
|
||||
for y := intersectBounds.Min.Y; y < intersectBounds.Max.Y; y++ {
|
||||
for x := intersectBounds.Min.X; x < intersectBounds.Max.X; x++ {
|
||||
i := y*dst.Stride + x*4
|
||||
|
||||
srcX := x - pasteBounds.Min.X
|
||||
srcY := y - pasteBounds.Min.Y
|
||||
j := srcY*src.Stride + srcX*4
|
||||
|
||||
a1 := float64(dst.Pix[i+3])
|
||||
a2 := float64(src.Pix[j+3])
|
||||
|
||||
coef2 := opacity * a2 / 255.0
|
||||
coef1 := (1 - coef2) * a1 / 255.0
|
||||
coefSum := coef1 + coef2
|
||||
coef1 /= coefSum
|
||||
coef2 /= coefSum
|
||||
|
||||
dst.Pix[i+0] = uint8(float64(dst.Pix[i+0])*coef1 + float64(src.Pix[j+0])*coef2)
|
||||
dst.Pix[i+1] = uint8(float64(dst.Pix[i+1])*coef1 + float64(src.Pix[j+1])*coef2)
|
||||
dst.Pix[i+2] = uint8(float64(dst.Pix[i+2])*coef1 + float64(src.Pix[j+2])*coef2)
|
||||
dst.Pix[i+3] = uint8(math.Min(a1+a2*opacity*(255.0-a1)/255.0, 255.0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
|
@ -0,0 +1,250 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCrop(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
r image.Rectangle
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Crop 2x3 2x1",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
image.Rect(-1, 0, 1, 1),
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Crop(d.src, d.r)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCropCenter(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
w, h int
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"CropCenter 2x3 2x1",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
2, 1,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := CropCenter(d.src, d.w, d.h)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaste(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src1 image.Image
|
||||
src2 image.Image
|
||||
p image.Point
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Paste 2x3 2x1",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(1, 1, 3, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
},
|
||||
},
|
||||
image.Pt(-1, 0),
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 3),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Paste(d.src1, d.src2, d.p)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasteCenter(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src1 image.Image
|
||||
src2 image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"PasteCenter 2x3 2x1",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(1, 1, 3, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 3),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := PasteCenter(d.src1, d.src2)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOverlay(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src1 image.Image
|
||||
src2 image.Image
|
||||
p image.Point
|
||||
a float64
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Overlay 2x3 2x1 1.0",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0x60, 0x00, 0x90, 0xff, 0xff, 0x00, 0x99, 0x7f,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(1, 1, 3, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x20, 0x40, 0x80, 0x7f, 0xaa, 0xbb, 0xcc, 0xff,
|
||||
},
|
||||
},
|
||||
image.Pt(-1, 0),
|
||||
1.0,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 3),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0x40, 0x1f, 0x88, 0xff, 0xaa, 0xbb, 0xcc, 0xff,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Overlay 2x2 2x2 0.5",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff,
|
||||
0x00, 0x00, 0xff, 0xff, 0x20, 0x20, 0x20, 0x00,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 1),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
|
||||
0xff, 0xff, 0x00, 0xff, 0x20, 0x20, 0x20, 0xff,
|
||||
},
|
||||
},
|
||||
image.Pt(-1, -1),
|
||||
0.5,
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0xff, 0x7f, 0x7f, 0xff, 0x00, 0xff, 0x00, 0xff,
|
||||
0x7f, 0x7f, 0x7f, 0xff, 0x20, 0x20, 0x20, 0x7f,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Overlay(d.src1, d.src2, d.p, d.a)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 1) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,201 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
)
|
||||
|
||||
// Rotate90 rotates the image 90 degrees counterclockwise and returns the transformed image.
|
||||
func Rotate90(img image.Image) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
srcW := src.Bounds().Max.X
|
||||
srcH := src.Bounds().Max.Y
|
||||
dstW := srcH
|
||||
dstH := srcW
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
srcX := dstH - dstY - 1
|
||||
srcY := dstX
|
||||
|
||||
srcOff := srcY*src.Stride + srcX*4
|
||||
dstOff := dstY*dst.Stride + dstX*4
|
||||
|
||||
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// Rotate180 rotates the image 180 degrees counterclockwise and returns the transformed image.
|
||||
func Rotate180(img image.Image) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
srcW := src.Bounds().Max.X
|
||||
srcH := src.Bounds().Max.Y
|
||||
dstW := srcW
|
||||
dstH := srcH
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
srcX := dstW - dstX - 1
|
||||
srcY := dstH - dstY - 1
|
||||
|
||||
srcOff := srcY*src.Stride + srcX*4
|
||||
dstOff := dstY*dst.Stride + dstX*4
|
||||
|
||||
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// Rotate270 rotates the image 270 degrees counterclockwise and returns the transformed image.
|
||||
func Rotate270(img image.Image) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
srcW := src.Bounds().Max.X
|
||||
srcH := src.Bounds().Max.Y
|
||||
dstW := srcH
|
||||
dstH := srcW
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
srcX := dstY
|
||||
srcY := dstW - dstX - 1
|
||||
|
||||
srcOff := srcY*src.Stride + srcX*4
|
||||
dstOff := dstY*dst.Stride + dstX*4
|
||||
|
||||
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// FlipH flips the image horizontally (from left to right) and returns the transformed image.
|
||||
func FlipH(img image.Image) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
srcW := src.Bounds().Max.X
|
||||
srcH := src.Bounds().Max.Y
|
||||
dstW := srcW
|
||||
dstH := srcH
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
srcX := dstW - dstX - 1
|
||||
srcY := dstY
|
||||
|
||||
srcOff := srcY*src.Stride + srcX*4
|
||||
dstOff := dstY*dst.Stride + dstX*4
|
||||
|
||||
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// FlipV flips the image vertically (from top to bottom) and returns the transformed image.
|
||||
func FlipV(img image.Image) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
srcW := src.Bounds().Max.X
|
||||
srcH := src.Bounds().Max.Y
|
||||
dstW := srcW
|
||||
dstH := srcH
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
srcX := dstX
|
||||
srcY := dstH - dstY - 1
|
||||
|
||||
srcOff := srcY*src.Stride + srcX*4
|
||||
dstOff := dstY*dst.Stride + dstX*4
|
||||
|
||||
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// Transpose flips the image horizontally and rotates 90 degrees counter-clockwise.
|
||||
func Transpose(img image.Image) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
srcW := src.Bounds().Max.X
|
||||
srcH := src.Bounds().Max.Y
|
||||
dstW := srcH
|
||||
dstH := srcW
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
srcX := dstY
|
||||
srcY := dstX
|
||||
|
||||
srcOff := srcY*src.Stride + srcX*4
|
||||
dstOff := dstY*dst.Stride + dstX*4
|
||||
|
||||
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// Transverse flips the image vertically and rotates 90 degrees counter-clockwise.
|
||||
func Transverse(img image.Image) *image.NRGBA {
|
||||
src := toNRGBA(img)
|
||||
srcW := src.Bounds().Max.X
|
||||
srcH := src.Bounds().Max.Y
|
||||
dstW := srcH
|
||||
dstH := srcW
|
||||
dst := image.NewNRGBA(image.Rect(0, 0, dstW, dstH))
|
||||
|
||||
parallel(dstH, func(partStart, partEnd int) {
|
||||
|
||||
for dstY := partStart; dstY < partEnd; dstY++ {
|
||||
for dstX := 0; dstX < dstW; dstX++ {
|
||||
srcX := dstH - dstY - 1
|
||||
srcY := dstW - dstX - 1
|
||||
|
||||
srcOff := srcY*src.Stride + srcX*4
|
||||
dstOff := dstY*dst.Stride + dstX*4
|
||||
|
||||
copy(dst.Pix[dstOff:dstOff+4], src.Pix[srcOff:srcOff+4])
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
return dst
|
||||
}
|
|
@ -0,0 +1,261 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"image"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRotate90(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Rotate90 2x3",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0xdd, 0xee, 0xff, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
0x00, 0x11, 0x22, 0x33, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Rotate90(d.src)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotate180(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Rotate180 2x3",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 3),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0x00,
|
||||
0x00, 0xff, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00,
|
||||
0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Rotate180(d.src)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotate270(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Rotate270 2x3",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x11, 0x22, 0x33,
|
||||
0x00, 0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xcc, 0xdd, 0xee, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Rotate270(d.src)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlipV(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"FlipV 2x3",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 3),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := FlipV(d.src)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlipH(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"FlipH 2x3",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 2, 3),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33,
|
||||
0x00, 0xff, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0xff, 0x00,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := FlipH(d.src)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranspose(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Transpose 2x3",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00,
|
||||
0xcc, 0xdd, 0xee, 0xff, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Transpose(d.src)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransverse(t *testing.T) {
|
||||
td := []struct {
|
||||
desc string
|
||||
src image.Image
|
||||
want *image.NRGBA
|
||||
}{
|
||||
{
|
||||
"Transverse 2x3",
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(-1, -1, 1, 2),
|
||||
Stride: 2 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x11, 0x22, 0x33, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00,
|
||||
0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
},
|
||||
},
|
||||
&image.NRGBA{
|
||||
Rect: image.Rect(0, 0, 3, 2),
|
||||
Stride: 3 * 4,
|
||||
Pix: []uint8{
|
||||
0x00, 0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0xcc, 0xdd, 0xee, 0xff,
|
||||
0x00, 0x00, 0xff, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x11, 0x22, 0x33,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, d := range td {
|
||||
got := Transverse(d.src)
|
||||
want := d.want
|
||||
if !compareNRGBA(got, want, 0) {
|
||||
t.Errorf("test [%s] failed: %#v", d.desc, got)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"math"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var parallelizationEnabled = true
|
||||
|
||||
// if GOMAXPROCS = 1: no goroutines used
|
||||
// if GOMAXPROCS > 1: spawn N=GOMAXPROCS workers in separate goroutines
|
||||
func parallel(dataSize int, fn func(partStart, partEnd int)) {
|
||||
numGoroutines := 1
|
||||
partSize := dataSize
|
||||
|
||||
if parallelizationEnabled {
|
||||
numProcs := runtime.GOMAXPROCS(0)
|
||||
if numProcs > 1 {
|
||||
numGoroutines = numProcs
|
||||
partSize = dataSize / (numGoroutines * 10)
|
||||
if partSize < 1 {
|
||||
partSize = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if numGoroutines == 1 {
|
||||
fn(0, dataSize)
|
||||
} else {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
idx := uint64(0)
|
||||
|
||||
for p := 0; p < numGoroutines; p++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
partStart := int(atomic.AddUint64(&idx, uint64(partSize))) - partSize
|
||||
if partStart >= dataSize {
|
||||
break
|
||||
}
|
||||
partEnd := partStart + partSize
|
||||
if partEnd > dataSize {
|
||||
partEnd = dataSize
|
||||
}
|
||||
fn(partStart, partEnd)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func absint(i int) int {
|
||||
if i < 0 {
|
||||
return -i
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// clamp & round float64 to uint8 (0..255)
|
||||
func clamp(v float64) uint8 {
|
||||
return uint8(math.Min(math.Max(v, 0.0), 255.0) + 0.5)
|
||||
}
|
||||
|
||||
// clamp int32 to uint8 (0..255)
|
||||
func clampint32(v int32) uint8 {
|
||||
if v < 0 {
|
||||
return 0
|
||||
} else if v > 255 {
|
||||
return 255
|
||||
}
|
||||
return uint8(v)
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package imaging
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testParallelN(enabled bool, n, procs int) bool {
|
||||
data := make([]bool, n)
|
||||
before := runtime.GOMAXPROCS(0)
|
||||
runtime.GOMAXPROCS(procs)
|
||||
parallel(n, func(start, end int) {
|
||||
for i := start; i < end; i++ {
|
||||
data[i] = true
|
||||
}
|
||||
})
|
||||
for i := 0; i < n; i++ {
|
||||
if data[i] != true {
|
||||
return false
|
||||
}
|
||||
}
|
||||
runtime.GOMAXPROCS(before)
|
||||
return true
|
||||
}
|
||||
|
||||
func TestParallel(t *testing.T) {
|
||||
for _, e := range []bool{true, false} {
|
||||
for _, n := range []int{1, 10, 100, 1000} {
|
||||
for _, p := range []int{1, 2, 4, 8, 16, 100} {
|
||||
if testParallelN(e, n, p) != true {
|
||||
t.Errorf("test [parallel %v %d %d] failed", e, n, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClamp(t *testing.T) {
|
||||
td := []struct {
|
||||
f float64
|
||||
u uint8
|
||||
}{
|
||||
{0, 0},
|
||||
{255, 255},
|
||||
{128, 128},
|
||||
{0.49, 0},
|
||||
{0.50, 1},
|
||||
{254.9, 255},
|
||||
{254.0, 254},
|
||||
{256, 255},
|
||||
{2500, 255},
|
||||
{-10, 0},
|
||||
{127.6, 128},
|
||||
}
|
||||
|
||||
for _, d := range td {
|
||||
if clamp(d.f) != d.u {
|
||||
t.Errorf("test [clamp %v %v] failed: %v", d.f, d.u, clamp(d.f))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
Copyright (c) 2012 Brad Rydzewski
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
|
@ -0,0 +1,100 @@
|
|||
# routes.go
|
||||
a simple http routing API for the Go programming language
|
||||
|
||||
go get github.com/drone/routes
|
||||
|
||||
for more information see:
|
||||
http://gopkgdoc.appspot.com/pkg/github.com/bradrydzewski/routes
|
||||
|
||||
[](https://drone.io/drone/routes/latest)
|
||||
|
||||
## Getting Started
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/drone/routes"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func Whoami(w http.ResponseWriter, r *http.Request) {
|
||||
params := r.URL.Query()
|
||||
lastName := params.Get(":last")
|
||||
firstName := params.Get(":first")
|
||||
fmt.Fprintf(w, "you are %s %s", firstName, lastName)
|
||||
}
|
||||
|
||||
func main() {
|
||||
mux := routes.New()
|
||||
mux.Get("/:last/:first", Whoami)
|
||||
|
||||
http.Handle("/", mux)
|
||||
http.ListenAndServe(":8088", nil)
|
||||
}
|
||||
|
||||
### Route Examples
|
||||
You can create routes for all http methods:
|
||||
|
||||
mux.Get("/:param", handler)
|
||||
mux.Put("/:param", handler)
|
||||
mux.Post("/:param", handler)
|
||||
mux.Patch("/:param", handler)
|
||||
mux.Del("/:param", handler)
|
||||
|
||||
You can specify custom regular expressions for routes:
|
||||
|
||||
mux.Get("/files/:param(.+)", handler)
|
||||
|
||||
You can also create routes for static files:
|
||||
|
||||
pwd, _ := os.Getwd()
|
||||
mux.Static("/static", pwd)
|
||||
|
||||
this will serve any files in `/static`, including files in subdirectories. For example `/static/logo.gif` or `/static/style/main.css`.
|
||||
|
||||
## Filters / Middleware
|
||||
You can apply filters to routes, which is useful for enforcing security,
|
||||
redirects, etc.
|
||||
|
||||
You can, for example, filter all request to enforce some type of security:
|
||||
|
||||
var FilterUser = func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.User == nil || r.URL.User.Username() != "admin" {
|
||||
http.Error(w, "", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
r.Filter(FilterUser)
|
||||
|
||||
You can also apply filters only when certain REST URL Parameters exist:
|
||||
|
||||
r.Get("/:id", handler)
|
||||
r.Filter("id", func(rw http.ResponseWriter, r *http.Request) {
|
||||
...
|
||||
})
|
||||
|
||||
## Helper Functions
|
||||
You can use helper functions for serializing to Json and Xml. I found myself constantly writing code to serialize, set content type, content length, etc. Feel free to use these functions to eliminate redundant code in your app.
|
||||
|
||||
Helper function for serving Json, sets content type to `application/json`:
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
mystruct := { ... }
|
||||
routes.ServeJson(w, &mystruct)
|
||||
}
|
||||
|
||||
Helper function for serving Xml, sets content type to `application/xml`:
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
mystruct := { ... }
|
||||
routes.ServeXml(w, &mystruct)
|
||||
}
|
||||
|
||||
Helper function to serve Xml OR Json, depending on the value of the `Accept` header:
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
mystruct := { ... }
|
||||
routes.ServeFormatted(w, r, &mystruct)
|
||||
}
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
package bench
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/drone/routes"
|
||||
gorilla "code.google.com/p/gorilla/mux"
|
||||
"github.com/bmizerany/pat"
|
||||
)
|
||||
|
||||
func HandlerOk(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "hello world")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// Benchmark_Routes runs a benchmark against our custom Mux using the
|
||||
// default settings.
|
||||
func Benchmark_Routes(b *testing.B) {
|
||||
|
||||
handler := routes.New()
|
||||
handler.Get("/person/:last/:first", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Benchmark_Web runs a benchmark against the pat.go Mux using the
|
||||
// default settings.
|
||||
func Benchmark_Pat(b *testing.B) {
|
||||
|
||||
|
||||
|
||||
m := pat.New()
|
||||
m.Get("/person/:last/:first", http.HandlerFunc(HandlerOk))
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
m.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_Gorilla runs a benchmark against the Gorilla Mux using
|
||||
// the default settings.
|
||||
func Benchmark_GorillaHandler(b *testing.B) {
|
||||
|
||||
|
||||
handler := gorilla.NewRouter()
|
||||
handler.HandleFunc("/person/{last}/{first}", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_ServeMux runs a benchmark against the ServeMux Go function.
|
||||
// We use this to determine performance impact of our library, when compared
|
||||
// to the out-of-the-box Mux provided by Go.
|
||||
func Benchmark_ServeMux(b *testing.B) {
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
Package routes a simple http routing API for the Go programming language,
|
||||
compatible with the standard http.ListenAndServe function.
|
||||
|
||||
Create a new route multiplexer:
|
||||
|
||||
mux := routes.New()
|
||||
|
||||
Define a simple route with a given method (ie Get, Put, Post ...), path and
|
||||
http.HandleFunc.
|
||||
|
||||
mux.Get("/foo", fooHandler)
|
||||
|
||||
Define a route with restful parameters in the path:
|
||||
|
||||
mux.Get("/:foo/:bar", func(w http.ResponseWriter, r *http.Request) {
|
||||
params := r.URL.Query()
|
||||
foo := params.Get(":foo")
|
||||
bar := params.Get(":bar")
|
||||
fmt.Fprintf(w, "%s %s", foo, bar)
|
||||
})
|
||||
|
||||
The parameters are parsed from the URL, and appended to the Request URL's
|
||||
query parameters.
|
||||
|
||||
More control over the route's parameter matching is possible by providing
|
||||
a custom regular expression:
|
||||
|
||||
mux.Get("/files/:file(.+)", handler)
|
||||
|
||||
To start the web server, use the standard http.ListenAndServe
|
||||
function, and provide the route multiplexer:
|
||||
|
||||
http.Handle("/", mux)
|
||||
http.ListenAndServe(":8000", nil)
|
||||
|
||||
*/
|
||||
package routes
|
|
@ -0,0 +1,107 @@
|
|||
# routes.go
|
||||
a simple http routing API for the Go programming language
|
||||
|
||||
go get github.com/drone/routes
|
||||
|
||||
for more information see:
|
||||
http://gopkgdoc.appspot.com/pkg/github.com/drone/routes
|
||||
|
||||
[](https://drone.io/drone/routes/latest)
|
||||
|
||||
## Getting Started
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/drone/routes"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func foobar (w http.ResponseWriter, r *http.Request) {
|
||||
c := routes.NewContext(r)
|
||||
foo := c.Params.Get(":foo")
|
||||
bar := c.Params.Get(":bar")
|
||||
fmt.Fprintf(w, "%s %s", foo, bar)
|
||||
}
|
||||
|
||||
func main() {
|
||||
r := routes.NewRouter()
|
||||
r.Get("/:bar/:foo", foobar)
|
||||
|
||||
http.Handle("/", r)
|
||||
http.ListenAndServe(":8088", nil)
|
||||
}
|
||||
|
||||
### Route Examples
|
||||
You can create routes for all http methods:
|
||||
|
||||
r.Get("/:param", handler)
|
||||
r.Put("/:param", handler)
|
||||
r.Post("/:param", handler)
|
||||
r.Patch("/:param", handler)
|
||||
r.Del("/:param", handler)
|
||||
|
||||
You can specify custom regular expressions for routes:
|
||||
|
||||
r.Get("/files/:param(.+)", handler)
|
||||
|
||||
You can also create routes for static files:
|
||||
|
||||
pwd, _ := os.Getwd()
|
||||
r.Static("/static", pwd)
|
||||
|
||||
this will serve any files in `/static`, including files in subdirectories. For
|
||||
example `/static/logo.gif` or `/static/style/main.css`.
|
||||
|
||||
## Filters / Middleware
|
||||
You can implement route filters to do things like enforce security, set session
|
||||
variables, etc
|
||||
|
||||
You can, for example, filter all request to enforce some type of security:
|
||||
|
||||
r.Filter(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.User != "admin" {
|
||||
http.Error(w, "", http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
|
||||
You can also apply filters only when certain REST URL Parameters exist:
|
||||
|
||||
r.Get("/:id", handler)
|
||||
r.Filter("id", func(rw http.ResponseWriter, r *http.Request) {
|
||||
c := routes.NewContext(r)
|
||||
id := c.Params.Get("id")
|
||||
|
||||
// verify the user has access to the specified resource id
|
||||
user := r.URL.User.Username()
|
||||
if HasAccess(user, id) == false {
|
||||
http.Error(w, "", http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
|
||||
## Helper Functions
|
||||
You can use helper functions for serializing to Json and Xml. I found myself
|
||||
constantly writing code to serialize, set content type, content length, etc.
|
||||
Feel free to use these functions to eliminate redundant code in your app.
|
||||
|
||||
Helper function for serving Json, sets content type to `application/json`:
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
mystruct := { ... }
|
||||
routes.ServeJson(w, &mystruct)
|
||||
}
|
||||
|
||||
Helper function for serving Xml, sets content type to `application/xml`:
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
mystruct := { ... }
|
||||
routes.ServeXml(w, &mystruct)
|
||||
}
|
||||
|
||||
Helper function to serve Xml OR Json, depending on the value of the `Accept` header:
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
mystruct := { ... }
|
||||
routes.ServeFormatted(w, r, &mystruct)
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
package context
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Context stores data for the duration of the http.Request
|
||||
type Context struct {
|
||||
// named parameters that are passed in via RESTful URL Parameters
|
||||
Params Params
|
||||
|
||||
// named attributes that persist for the lifetime of the request
|
||||
Values Values
|
||||
|
||||
// reference to the parent http.Request
|
||||
req *http.Request
|
||||
}
|
||||
|
||||
// Retruns the Context associated with the http.Request.
|
||||
func Get(r *http.Request) *Context {
|
||||
|
||||
// get the context bound to the http.Request
|
||||
if v, ok := r.Body.(*wrapper); ok {
|
||||
return v.context
|
||||
}
|
||||
|
||||
// create a new context
|
||||
c := Context{ }
|
||||
c.Params = make(Params)
|
||||
c.Values = make(Values)
|
||||
c.req = r
|
||||
|
||||
// wrap the request and bind the context
|
||||
wrapper := wrap(r)
|
||||
wrapper.context = &c
|
||||
return &c
|
||||
}
|
||||
|
||||
// Retruns the parent http.Request to which the context is bound.
|
||||
func (c *Context) Request() *http.Request {
|
||||
return c.req
|
||||
}
|
||||
|
||||
// wrapper decorates an http.Request's Body (io.ReadCloser) so that we can
|
||||
// bind a Context to the Request. This is obviously a hack that i'd rather
|
||||
// avoid, however, it is for the greater good ...
|
||||
//
|
||||
// NOTE: If this turns out to be a really stupid approach we can use this
|
||||
// approach from the go mailing list: http://goo.gl/Vw13f which I
|
||||
// avoided because I didn't want a global lock
|
||||
type wrapper struct {
|
||||
body io.ReadCloser // the original message body
|
||||
context *Context
|
||||
}
|
||||
|
||||
func wrap(r *http.Request) *wrapper {
|
||||
w := wrapper{ body: r.Body }
|
||||
r.Body = &w
|
||||
return &w
|
||||
}
|
||||
|
||||
func (w *wrapper) Read(p []byte) (n int, err error) {
|
||||
return w.body.Read(p)
|
||||
}
|
||||
|
||||
func (w *wrapper) Close() error {
|
||||
return w.body.Close()
|
||||
}
|
||||
|
||||
// Parameter Map ---------------------------------------------------------------
|
||||
|
||||
// Params maps a string key to a list of values.
|
||||
type Params map[string]string
|
||||
|
||||
// Get gets the first value associated with the given key. If there are
|
||||
// no values associated with the key, Get returns the empty string.
|
||||
func (p Params) Get(key string) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return p[key]
|
||||
}
|
||||
|
||||
// Set sets the key to value. It replaces any existing values.
|
||||
func (p Params) Set(key, value string) {
|
||||
p[key] = value
|
||||
}
|
||||
|
||||
// Del deletes the values associated with key.
|
||||
func (p Params) Del(key string) {
|
||||
delete(p, key)
|
||||
}
|
||||
|
||||
// Value Map -------------------------------------------------------------------
|
||||
|
||||
// Values maps a string key to a list of values.
|
||||
type Values map[interface{}]interface{}
|
||||
|
||||
// Get gets the value associated with the given key. If there are
|
||||
// no values associated with the key, Get returns nil.
|
||||
func (v Values) Get(key interface{}) interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v[key]
|
||||
}
|
||||
|
||||
// GetStr gets the value associated with the given key in string format.
|
||||
// If there are no values associated with the key, Get returns an
|
||||
// empty string.
|
||||
func (v Values) GetStr(key interface{}) interface{} {
|
||||
if v == nil { return "" }
|
||||
|
||||
val := v.Get(key)
|
||||
if val == nil { return "" }
|
||||
|
||||
str, ok := val.(string)
|
||||
if !ok { return "" }
|
||||
return str
|
||||
}
|
||||
|
||||
// Set sets the key to value. It replaces any existing values.
|
||||
func (v Values) Set(key, value interface{}) {
|
||||
v[key] = value
|
||||
}
|
||||
|
||||
// Del deletes the values associated with key.
|
||||
func (v Values) Del(key interface{}) {
|
||||
delete(v, key)
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
Copyright (c) 2011 Dmitry Chestnykh
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
|
@ -0,0 +1,99 @@
|
|||
Package authcookie
|
||||
=====================
|
||||
|
||||
import "github.com/dchest/authcookie"
|
||||
|
||||
Package authcookie implements creation and verification of signed
|
||||
authentication cookies.
|
||||
|
||||
Cookie is a Base64 encoded (using URLEncoding, from RFC 4648) string, which
|
||||
consists of concatenation of expiration time, login, and signature:
|
||||
|
||||
expiration time || login || signature
|
||||
|
||||
where expiration time is the number of seconds since Unix epoch UTC
|
||||
indicating when this cookie must expire (4 bytes, big-endian, uint32), login
|
||||
is a byte string of arbitrary length (at least 1 byte, not null-terminated),
|
||||
and signature is 32 bytes of HMAC-SHA256(expiration_time || login, k), where
|
||||
k = HMAC-SHA256(expiration_time || login, secret key).
|
||||
|
||||
Example:
|
||||
|
||||
secret := []byte("my secret key")
|
||||
|
||||
// Generate cookie valid for 24 hours for user "bender"
|
||||
cookie := authcookie.NewSinceNow("bender", 24 * time.Hour, secret)
|
||||
|
||||
// cookie is now:
|
||||
// Tajh02JlbmRlcskYMxowgwPj5QZ94jaxhDoh3n0Yp4hgGtUpeO0YbMTY
|
||||
// send it to user's browser..
|
||||
|
||||
// To authenticate a user later, receive cookie and:
|
||||
login := authcookie.Login(cookie, secret)
|
||||
if login != "" {
|
||||
// access for login granted
|
||||
} else {
|
||||
// access denied
|
||||
}
|
||||
|
||||
Note that login and expiration time are not encrypted, they are only signed
|
||||
and Base64 encoded.
|
||||
|
||||
|
||||
Variables
|
||||
---------
|
||||
|
||||
var (
|
||||
ErrMalformedCookie = errors.New("malformed cookie")
|
||||
ErrWrongSignature = errors.New("wrong cookie signature")
|
||||
)
|
||||
|
||||
|
||||
var MinLength = base64.URLEncoding.EncodedLen(decodedMinLength)
|
||||
|
||||
MinLength is the minimum allowed length of cookie string.
|
||||
|
||||
It is useful for avoiding DoS attacks with too long cookies: before passing
|
||||
a cookie to Parse or Login functions, check that it has length less than the
|
||||
[maximum login length allowed in your application] + MinLength.
|
||||
|
||||
|
||||
Functions
|
||||
---------
|
||||
|
||||
### func Login
|
||||
|
||||
func Login(cookie string, secret []byte) string
|
||||
|
||||
Login returns a valid login extracted from the given cookie and verified
|
||||
using the given secret key. If verification fails or the cookie expired,
|
||||
the function returns an empty string.
|
||||
|
||||
### func New
|
||||
|
||||
func New(login string, expires time.Time, secret []byte) string
|
||||
|
||||
New returns a signed authentication cookie for the given login,
|
||||
expiration time, and secret key.
|
||||
If the login is empty, the function returns an empty string.
|
||||
|
||||
### func NewSinceNow
|
||||
|
||||
func NewSinceNow(login string, dur time.Duration, secret []byte) string
|
||||
|
||||
NewSinceNow returns a signed authetication cookie for the given login,
|
||||
duration time since current time, and secret key.
|
||||
|
||||
### func Parse
|
||||
|
||||
func Parse(cookie string, secret []byte) (login string, expires time.Time, err error)
|
||||
|
||||
Parse verifies the given cookie with the secret key and returns login and
|
||||
expiration time extracted from the cookie. If the cookie fails verification
|
||||
or is not well-formed, the function returns an error.
|
||||
|
||||
Callers must:
|
||||
|
||||
1. Check for the returned error and deny access if it's present.
|
||||
|
||||
2. Check the returned expiration time and deny access if it's in the past.
|
|
@ -0,0 +1,154 @@
|
|||
// Package authcookie implements creation and verification of signed
|
||||
// authentication cookies.
|
||||
//
|
||||
// Cookie is a Base64 encoded (using URLEncoding, from RFC 4648) string, which
|
||||
// consists of concatenation of expiration time, login, and signature:
|
||||
//
|
||||
// expiration time || login || signature
|
||||
//
|
||||
// where expiration time is the number of seconds since Unix epoch UTC
|
||||
// indicating when this cookie must expire (4 bytes, big-endian, uint32), login
|
||||
// is a byte string of arbitrary length (at least 1 byte, not null-terminated),
|
||||
// and signature is 32 bytes of HMAC-SHA256(expiration_time || login, k), where
|
||||
// k = HMAC-SHA256(expiration_time || login, secret key).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// secret := []byte("my secret key")
|
||||
//
|
||||
// // Generate cookie valid for 24 hours for user "bender"
|
||||
// cookie := authcookie.NewSinceNow("bender", 24 * time.Hour, secret)
|
||||
//
|
||||
// // cookie is now:
|
||||
// // Tajh02JlbmRlcskYMxowgwPj5QZ94jaxhDoh3n0Yp4hgGtUpeO0YbMTY
|
||||
// // send it to user's browser..
|
||||
//
|
||||
// // To authenticate a user later, receive cookie and:
|
||||
// login := authcookie.Login(cookie, secret)
|
||||
// if login != "" {
|
||||
// // access for login granted
|
||||
// } else {
|
||||
// // access denied
|
||||
// }
|
||||
//
|
||||
// Note that login and expiration time are not encrypted, they are only signed
|
||||
// and Base64 encoded.
|
||||
//
|
||||
// For safety, the maximum length of base64-decoded cookie is limited to 1024
|
||||
// bytes.
|
||||
package authcookie
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
decodedMinLength = 4 /*expiration*/ + 1 /*login*/ + 32 /*signature*/
|
||||
decodedMaxLength = 1024 /* maximum decoded length, for safety */
|
||||
)
|
||||
|
||||
// MinLength is the minimum allowed length of cookie string.
|
||||
//
|
||||
// It is useful for avoiding DoS attacks with too long cookies: before passing
|
||||
// a cookie to Parse or Login functions, check that it has length less than the
|
||||
// [maximum login length allowed in your application] + MinLength.
|
||||
var MinLength = base64.URLEncoding.EncodedLen(decodedMinLength)
|
||||
|
||||
func getSignature(b []byte, secret []byte) []byte {
|
||||
keym := hmac.New(sha256.New, secret)
|
||||
keym.Write(b)
|
||||
m := hmac.New(sha256.New, keym.Sum(nil))
|
||||
m.Write(b)
|
||||
return m.Sum(nil)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMalformedCookie = errors.New("malformed cookie")
|
||||
ErrWrongSignature = errors.New("wrong cookie signature")
|
||||
)
|
||||
|
||||
// New returns a signed authentication cookie for the given login,
|
||||
// expiration time, and secret key.
|
||||
// If the login is empty, the function returns an empty string.
|
||||
func New(login string, expires time.Time, secret []byte) string {
|
||||
if login == "" {
|
||||
return ""
|
||||
}
|
||||
llen := len(login)
|
||||
b := make([]byte, llen+4+32)
|
||||
// Put expiration time.
|
||||
binary.BigEndian.PutUint32(b, uint32(expires.Unix()))
|
||||
// Put login.
|
||||
copy(b[4:], []byte(login))
|
||||
// Calculate and put signature.
|
||||
sig := getSignature([]byte(b[:4+llen]), secret)
|
||||
copy(b[4+llen:], sig)
|
||||
// Base64-encode.
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// NewSinceNow returns a signed authetication cookie for the given login,
|
||||
// duration since current time, and secret key.
|
||||
func NewSinceNow(login string, dur time.Duration, secret []byte) string {
|
||||
return New(login, time.Now().Add(dur), secret)
|
||||
}
|
||||
|
||||
// Parse verifies the given cookie with the secret key and returns login and
|
||||
// expiration time extracted from the cookie. If the cookie fails verification
|
||||
// or is not well-formed, the function returns an error.
|
||||
//
|
||||
// Callers must:
|
||||
//
|
||||
// 1. Check for the returned error and deny access if it's present.
|
||||
//
|
||||
// 2. Check the returned expiration time and deny access if it's in the past.
|
||||
//
|
||||
func Parse(cookie string, secret []byte) (login string, expires time.Time, err error) {
|
||||
blen := base64.URLEncoding.DecodedLen(len(cookie))
|
||||
// Avoid allocation if cookie is too short or too long.
|
||||
if blen < decodedMinLength || blen > decodedMaxLength {
|
||||
err = ErrMalformedCookie
|
||||
return
|
||||
}
|
||||
b, err := base64.URLEncoding.DecodeString(cookie)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Decoded length may be different from max length, which
|
||||
// we allocated, so check it, and set new length for b.
|
||||
blen = len(b)
|
||||
if blen < decodedMinLength {
|
||||
err = ErrMalformedCookie
|
||||
return
|
||||
}
|
||||
b = b[:blen]
|
||||
|
||||
sig := b[blen-32:]
|
||||
data := b[:blen-32]
|
||||
|
||||
realSig := getSignature(data, secret)
|
||||
if subtle.ConstantTimeCompare(realSig, sig) != 1 {
|
||||
err = ErrWrongSignature
|
||||
return
|
||||
}
|
||||
expires = time.Unix(int64(binary.BigEndian.Uint32(data[:4])), 0)
|
||||
login = string(data[4:])
|
||||
return
|
||||
}
|
||||
|
||||
// Login returns a valid login extracted from the given cookie and verified
|
||||
// using the given secret key. If verification fails or the cookie expired,
|
||||
// the function returns an empty string.
|
||||
func Login(cookie string, secret []byte) string {
|
||||
l, exp, err := Parse(cookie, secret)
|
||||
if err != nil || exp.Before(time.Now()) {
|
||||
return ""
|
||||
}
|
||||
return l
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package authcookie
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
secret := []byte("secret key")
|
||||
good := "AAAAKmhlbGxvIHdvcmxk9p6koQvSacAeliAm445i7errSk1NPkYJGYZhF93wG9U="
|
||||
c := New("hello world", time.Unix(42, 0), secret)
|
||||
if c != good {
|
||||
t.Errorf("expected %q, got %q", good, c)
|
||||
}
|
||||
// Test empty login
|
||||
c = New("", time.Unix(42, 0), secret)
|
||||
if c != "" {
|
||||
t.Errorf(`allowed empty login: got %q, expected ""`, c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
// good
|
||||
sec := time.Now()
|
||||
login := "bender"
|
||||
key := []byte("another secret key")
|
||||
c := New(login, sec, key)
|
||||
l, e, err := Parse(c, key)
|
||||
if err != nil {
|
||||
t.Errorf("error parsing valid cookie: %s", err)
|
||||
}
|
||||
if l != login {
|
||||
t.Errorf("login: expected %q, got %q", login, l)
|
||||
}
|
||||
// NOTE: nanos are discarded internally since only 4 bytes of timestamp are used
|
||||
// so we can only compare seconds here
|
||||
if e.Unix() != sec.Unix() {
|
||||
t.Errorf("expiration: expected %v, got %v", sec, e)
|
||||
}
|
||||
// bad
|
||||
key = []byte("secret key")
|
||||
bad := []string{
|
||||
"",
|
||||
"AAAAKvgQ2I_RGePVk9oAu55q-Valnf__Fx_hlTM-dLwYxXOf",
|
||||
"badcookie",
|
||||
"AAAAAKmhlbGxvIHdvcmxk9p6koQvSacAeliAm445i7errSk1NPkYJGYZhF93wG9U=",
|
||||
"zAAAKmhlbGxvIHdvcmxk9p6koQvSacAeliAm445i7errSk1NPkYJGYZhF93wG9U=",
|
||||
"AAAAAKmhlbGxvIHdvcmxk9p6kiQvSacAeliAm445i7errSk1NPkYJGYZhF93wG9U=",
|
||||
}
|
||||
for _, v := range bad {
|
||||
_, _, err := Parse(v, key)
|
||||
if err == nil {
|
||||
t.Errorf("bad cookie didn't return error: %q", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
login := "~~~!|zoidberg|!~~~"
|
||||
key := []byte("(:€")
|
||||
exp := time.Now().Add(time.Second * 120)
|
||||
c := New(login, exp, key)
|
||||
l := Login(c, key)
|
||||
if l != login {
|
||||
t.Errorf("login: expected %q, got %q", login, l)
|
||||
}
|
||||
c = "no" + c
|
||||
l = Login(c, key)
|
||||
if l != "" {
|
||||
t.Errorf("login expected empty string, got %q", l)
|
||||
}
|
||||
exp = time.Now().Add(-(time.Second * 30))
|
||||
c = New(login, exp, key)
|
||||
l = Login(c, key)
|
||||
if l != "" {
|
||||
t.Errorf("returned login from expired cookie")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package cookie
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/drone/routes/exp/cookie/authcookie"
|
||||
)
|
||||
|
||||
// Sign signs and timestamps a cookie so it cannot be forged.
|
||||
func Sign(cookie *http.Cookie, secret string, expires time.Time) {
|
||||
val := SignStr(cookie.Value, secret, expires)
|
||||
cookie.Value = val
|
||||
}
|
||||
|
||||
// SignStr signs and timestamps a string so it cannot be forged.
|
||||
//
|
||||
// Normally used via Sign, but provided as a separate method for
|
||||
// non-cookie uses. To decode a value not stored as a cookie use the
|
||||
// DecodeStr function.
|
||||
func SignStr(value, secret string, expires time.Time) string {
|
||||
return authcookie.New(value, expires, []byte(secret))
|
||||
}
|
||||
|
||||
// DecodeStr returns the given signed cookie value if it validates,
|
||||
// else returns an empty string.
|
||||
func Decode(cookie *http.Cookie, secret string) string {
|
||||
return DecodeStr(cookie.Value, secret)
|
||||
}
|
||||
|
||||
// DecodeStr returns the given signed value if it validates,
|
||||
// else returns an empty string.
|
||||
func DecodeStr(value, secret string) string {
|
||||
return authcookie.Login(value, []byte(secret))
|
||||
}
|
||||
|
||||
// Clear deletes the cookie with the given name.
|
||||
func Clear(w http.ResponseWriter, r *http.Request, name string) {
|
||||
cookie := http.Cookie{
|
||||
Name: name,
|
||||
Value: "deleted",
|
||||
Path: "/",
|
||||
Domain: r.URL.Host,
|
||||
MaxAge: -1,
|
||||
}
|
||||
|
||||
http.SetCookie(w, &cookie)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,241 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/drone/routes/exp/context"
|
||||
)
|
||||
|
||||
const (
|
||||
DELETE = "DELETE"
|
||||
GET = "GET"
|
||||
HEAD = "HEAD"
|
||||
OPTIONS = "OPTIONS"
|
||||
PATCH = "PATCH"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
)
|
||||
|
||||
type route struct {
|
||||
method string
|
||||
regex *regexp.Regexp
|
||||
params map[int]string
|
||||
handler http.HandlerFunc
|
||||
}
|
||||
|
||||
type Router struct {
|
||||
sync.RWMutex
|
||||
routes []*route
|
||||
filters []http.HandlerFunc
|
||||
params map[string]interface{}
|
||||
}
|
||||
|
||||
func New() *Router {
|
||||
r := Router{}
|
||||
r.params = make(map[string]interface{})
|
||||
return &r
|
||||
}
|
||||
|
||||
// Get adds a new Route for GET requests.
|
||||
func (r *Router) Get(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(GET, pattern, handler)
|
||||
}
|
||||
|
||||
// Put adds a new Route for PUT requests.
|
||||
func (r *Router) Put(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(PUT, pattern, handler)
|
||||
}
|
||||
|
||||
// Del adds a new Route for DELETE requests.
|
||||
func (r *Router) Del(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(DELETE, pattern, handler)
|
||||
}
|
||||
|
||||
// Patch adds a new Route for PATCH requests.
|
||||
func (r *Router) Patch(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(PATCH, pattern, handler)
|
||||
}
|
||||
|
||||
// Post adds a new Route for POST requests.
|
||||
func (r *Router) Post(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(POST, pattern, handler)
|
||||
}
|
||||
|
||||
// Adds a new Route for Static http requests. Serves
|
||||
// static files from the specified directory
|
||||
func (r *Router) Static(pattern string, dir string) {
|
||||
//append a regex to the param to match everything
|
||||
// that comes after the prefix
|
||||
pattern = pattern + "(.+)"
|
||||
r.Get(pattern, func(w http.ResponseWriter, req *http.Request) {
|
||||
path := filepath.Clean(req.URL.Path)
|
||||
path = filepath.Join(dir, path)
|
||||
http.ServeFile(w, req, path)
|
||||
})
|
||||
}
|
||||
|
||||
// Adds a new Route to the Handler
|
||||
func (r *Router) AddRoute(method string, pattern string, handler http.HandlerFunc) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
//split the url into sections
|
||||
parts := strings.Split(pattern, "/")
|
||||
|
||||
//find params that start with ":"
|
||||
//replace with regular expressions
|
||||
j := 0
|
||||
params := make(map[int]string)
|
||||
for i, part := range parts {
|
||||
if strings.HasPrefix(part, ":") {
|
||||
expr := "([^/]+)"
|
||||
//a user may choose to override the defult expression
|
||||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||||
if index := strings.Index(part, "("); index != -1 {
|
||||
expr = part[index:]
|
||||
part = part[:index]
|
||||
}
|
||||
params[j] = part[1:]
|
||||
parts[i] = expr
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
//recreate the url pattern, with parameters replaced
|
||||
//by regular expressions. then compile the regex
|
||||
pattern = strings.Join(parts, "/")
|
||||
regex := regexp.MustCompile(pattern)
|
||||
|
||||
route := &route{
|
||||
method : method,
|
||||
regex : regex,
|
||||
handler : handler,
|
||||
params : params,
|
||||
}
|
||||
|
||||
//append to the list of Routes
|
||||
r.routes = append(r.routes, route)
|
||||
}
|
||||
|
||||
// Filter adds the middleware filter.
|
||||
func (r *Router) Filter(filter http.HandlerFunc) {
|
||||
r.Lock()
|
||||
r.filters = append(r.filters, filter)
|
||||
r.Unlock()
|
||||
}
|
||||
|
||||
// FilterParam adds the middleware filter iff the URL parameter exists.
|
||||
func (r *Router) FilterParam(param string, filter http.HandlerFunc) {
|
||||
r.Filter(func(w http.ResponseWriter, req *http.Request) {
|
||||
c := context.Get(req)
|
||||
if len(c.Params.Get(param)) > 0 { filter(w, req) }
|
||||
})
|
||||
}
|
||||
|
||||
// FilterPath adds the middleware filter iff the path matches the request.
|
||||
func (r *Router) FilterPath(path string, filter http.HandlerFunc) {
|
||||
pattern := path
|
||||
pattern = strings.Replace(pattern, "*", "(.+)", -1)
|
||||
pattern = strings.Replace(pattern, "**", "([^/]+)", -1)
|
||||
regex := regexp.MustCompile(pattern)
|
||||
r.Filter(func(w http.ResponseWriter, req *http.Request) {
|
||||
if regex.MatchString(req.URL.Path) { filter(w, req) }
|
||||
})
|
||||
}
|
||||
|
||||
// Required by http.Handler interface. This method is invoked by the
|
||||
// http server and will handle all page routing
|
||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
//wrap the response writer in our custom interface
|
||||
w := &responseWriter{writer: rw, Router: r}
|
||||
|
||||
//find a matching Route
|
||||
for _, route := range r.routes {
|
||||
|
||||
//if the methods don't match, skip this handler
|
||||
//i.e if request.Method is 'PUT' Route.Method must be 'PUT'
|
||||
if req.Method != route.method {
|
||||
continue
|
||||
}
|
||||
|
||||
//check if Route pattern matches url
|
||||
if !route.regex.MatchString(req.URL.Path) {
|
||||
continue
|
||||
}
|
||||
|
||||
//get submatches (params)
|
||||
matches := route.regex.FindStringSubmatch(req.URL.Path)
|
||||
|
||||
//double check that the Route matches the URL pattern.
|
||||
if len(matches[0]) != len(req.URL.Path) {
|
||||
continue
|
||||
}
|
||||
|
||||
//create the http.Requests context
|
||||
c := context.Get(req)
|
||||
|
||||
//add url parameters to the context
|
||||
for i, match := range matches[1:] {
|
||||
c.Params.Set(route.params[i], match)
|
||||
}
|
||||
|
||||
//execute middleware filters
|
||||
for _, filter := range r.filters {
|
||||
filter(w, req)
|
||||
if w.started { return }
|
||||
}
|
||||
|
||||
//invoke the request handler
|
||||
route.handler(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
//if no matches to url, throw a not found exception
|
||||
if w.started == false {
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
// responseWriter is a wrapper for the http.ResponseWriter to track if
|
||||
// response was written to, and to store a reference to the router.
|
||||
type responseWriter struct {
|
||||
Router *Router
|
||||
writer http.ResponseWriter
|
||||
started bool
|
||||
status int
|
||||
}
|
||||
|
||||
// Header returns the header map that will be sent by WriteHeader.
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.writer.Header()
|
||||
}
|
||||
|
||||
// Write writes the data to the connection as part of an HTTP reply,
|
||||
// and sets `started` to true
|
||||
func (w *responseWriter) Write(p []byte) (int, error) {
|
||||
w.started = true
|
||||
return w.writer.Write(p)
|
||||
}
|
||||
|
||||
// WriteHeader sends an HTTP response header with status code,
|
||||
// and sets `started` to true
|
||||
func (w *responseWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
w.started = true
|
||||
w.writer.WriteHeader(code)
|
||||
}
|
||||
|
||||
// The Hijacker interface is implemented by ResponseWriters that allow an
|
||||
// HTTP handler to take over the connection.
|
||||
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.writer.(http.Hijacker).Hijack()
|
||||
}
|
|
@ -0,0 +1,227 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"github.com/drone/routes/exp/context"
|
||||
)
|
||||
|
||||
func HandlerOk(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "hello world")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func HandlerSetVar(w http.ResponseWriter, r *http.Request) {
|
||||
c := context.Get(r)
|
||||
c.Values.Set("password", "z1on")
|
||||
}
|
||||
|
||||
func HandlerErr(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// TestRouteOk tests that the route is correctly handled, and the URL parameters
|
||||
// are added to the Context.
|
||||
func TestRouteOk(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := New()
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
c := context.Get(r)
|
||||
lastNameParam := c.Params.Get("last")
|
||||
firstNameParam := c.Params.Get("first")
|
||||
|
||||
if lastNameParam != "anderson" {
|
||||
t.Errorf("url param set to [%s]; want [%s]", lastNameParam, "anderson")
|
||||
}
|
||||
if firstNameParam != "thomas" {
|
||||
t.Errorf("url param set to [%s]; want [%s]", firstNameParam, "thomas")
|
||||
}
|
||||
if w.Body.String() != "hello world" {
|
||||
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilter tests that a route is filtered prior to handling
|
||||
func TestRouteFilter(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := New()
|
||||
mux.Filter(HandlerSetVar)
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
c := context.Get(r)
|
||||
password := c.Values.Get("password")
|
||||
|
||||
if password != "z1on" {
|
||||
t.Errorf("session variable set to [%s]; want [%s]", password, "z1on")
|
||||
}
|
||||
if w.Body.String() != "hello world" {
|
||||
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterHalt tests that a route is filtered prior to handling, and then
|
||||
// halts execution (by writing to the response).
|
||||
func TestRouteFilterHalt(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := New()
|
||||
mux.Filter(HandlerErr)
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != 400 {
|
||||
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
if w.Body.String() == "hello world" {
|
||||
t.Errorf("Body set to [%s]; want empty", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouterFilterParam tests the Parameter filter, and ensures the
|
||||
// filter is only executed when the specified Parameter exists.
|
||||
func TestRouterFilterParam(t *testing.T) {
|
||||
// in the first test scenario, the Parameter filter should not
|
||||
// be triggered because the "codename" variab does not exist
|
||||
r, _ := http.NewRequest("GET", "/neo", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := New()
|
||||
mux.Filter(HandlerSetVar)
|
||||
mux.FilterParam("codename", HandlerErr)
|
||||
mux.Get("/:nickname", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Body.String() != "hello world" {
|
||||
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
|
||||
}
|
||||
|
||||
// in this second scenario, the Parameter filter SHOULD fire, and should
|
||||
// halt the request
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
mux = New()
|
||||
mux.Filter(HandlerSetVar)
|
||||
mux.FilterParam("codename", HandlerErr)
|
||||
mux.Get("/:codename", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Body.String() == "hello world" {
|
||||
t.Errorf("Body set to [%s]; want empty", w.Body.String())
|
||||
}
|
||||
if w.Code != 400 {
|
||||
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouterFilterPath tests the Path filter, and ensures the filter
|
||||
// is only executed when the Request Path matches the filter Path.
|
||||
func TestRouterFilterPath(t *testing.T) {
|
||||
// in the first test scenario, the Path filter should not fire
|
||||
// because it does not take the "first name" section of the URL
|
||||
// into account, and should therefore not match
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := New()
|
||||
mux.FilterPath("/person/*/anderson", HandlerErr)
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Body.String() != "hello world" {
|
||||
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
|
||||
}
|
||||
|
||||
// in this second scenario, the Parameter filter SHOULD fire because
|
||||
// we are filtering on all "last names", and the pattern should match
|
||||
// the first section of the URL (person) and the last section of the
|
||||
// url (:first)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
mux = New()
|
||||
mux.FilterPath("/person/*/thomas", HandlerErr)
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Body.String() == "hello world" {
|
||||
t.Errorf("Body set to [%s]; want empty", w.Body.String())
|
||||
}
|
||||
if w.Code != 400 {
|
||||
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotFound tests that a 404 code is returned in the
|
||||
// response if no route matches the request url.
|
||||
func TestNotFound(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := New()
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_Routes runs a benchmark against our custom Mux using the
|
||||
// default settings.
|
||||
func Benchmark_Routes(b *testing.B) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux := New()
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_Routes_x30 runs a benchmark against our custom Mux using the
|
||||
// default settings, but with 30 routes
|
||||
func Benchmark_Routes_x30(b *testing.B) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux := New()
|
||||
for i:=0;i<30;i++ {
|
||||
mux.Get(fmt.Sprintf("/%v/:last/:first",i), HandlerOk)
|
||||
}
|
||||
|
||||
// and we'll make the matching URL the LAST in the list
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_ServeMux runs a benchmark against the ServeMux Go function.
|
||||
// We use this to determine performance impact of our library, when compared
|
||||
// to the out-of-the-box Mux provided by Go.
|
||||
func Benchmark_ServeMux(b *testing.B) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.URL.Query().Get("learn")
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,107 @@
|
|||
# routes.go
|
||||
a simple http routing API for the Go programming language
|
||||
|
||||
go get github.com/drone/routes
|
||||
|
||||
for more information see:
|
||||
http://gopkgdoc.appspot.com/pkg/github.com/drone/routes
|
||||
|
||||
[](https://drone.io/drone/routes/latest)
|
||||
|
||||
## Getting Started
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/drone/routes"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func foobar (w http.ResponseWriter, r *http.Request) {
|
||||
c := routes.NewContext(r)
|
||||
foo := c.Params.Get(":foo")
|
||||
bar := c.Params.Get(":bar")
|
||||
fmt.Fprintf(w, "%s %s", foo, bar)
|
||||
}
|
||||
|
||||
func main() {
|
||||
r := routes.NewRouter()
|
||||
r.Get("/:bar/:foo", foobar)
|
||||
|
||||
http.Handle("/", r)
|
||||
http.ListenAndServe(":8088", nil)
|
||||
}
|
||||
|
||||
### Route Examples
|
||||
You can create routes for all http methods:
|
||||
|
||||
r.Get("/:param", handler)
|
||||
r.Put("/:param", handler)
|
||||
r.Post("/:param", handler)
|
||||
r.Patch("/:param", handler)
|
||||
r.Del("/:param", handler)
|
||||
|
||||
You can specify custom regular expressions for routes:
|
||||
|
||||
r.Get("/files/:param(.+)", handler)
|
||||
|
||||
You can also create routes for static files:
|
||||
|
||||
pwd, _ := os.Getwd()
|
||||
r.Static("/static", pwd)
|
||||
|
||||
this will serve any files in `/static`, including files in subdirectories. For
|
||||
example `/static/logo.gif` or `/static/style/main.css`.
|
||||
|
||||
## Filters / Middleware
|
||||
You can implement route filters to do things like enforce security, set session
|
||||
variables, etc
|
||||
|
||||
You can, for example, filter all request to enforce some type of security:
|
||||
|
||||
r.Filter(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.User != "admin" {
|
||||
http.Error(w, "", http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
|
||||
You can also apply filters only when certain REST URL Parameters exist:
|
||||
|
||||
r.Get("/:id", handler)
|
||||
r.Filter("id", func(rw http.ResponseWriter, r *http.Request) {
|
||||
c := routes.NewContext(r)
|
||||
id := c.Params.Get("id")
|
||||
|
||||
// verify the user has access to the specified resource id
|
||||
user := r.URL.User.Username()
|
||||
if HasAccess(user, id) == false {
|
||||
http.Error(w, "", http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
|
||||
## Helper Functions
|
||||
You can use helper functions for serializing to Json and Xml. I found myself
|
||||
constantly writing code to serialize, set content type, content length, etc.
|
||||
Feel free to use these functions to eliminate redundant code in your app.
|
||||
|
||||
Helper function for serving Json, sets content type to `application/json`:
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
mystruct := { ... }
|
||||
routes.ServeJson(w, &mystruct)
|
||||
}
|
||||
|
||||
Helper function for serving Xml, sets content type to `application/xml`:
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
mystruct := { ... }
|
||||
routes.ServeXml(w, &mystruct)
|
||||
}
|
||||
|
||||
Helper function to serve Xml OR Json, depending on the value of the `Accept` header:
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
mystruct := { ... }
|
||||
routes.ServeFormatted(w, r, &mystruct)
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package bench
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/drone/routes/exp/routes"
|
||||
gorilla "code.google.com/p/gorilla/mux"
|
||||
"github.com/bmizerany/pat"
|
||||
)
|
||||
|
||||
func HandlerOk(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "hello world")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// Benchmark_Routes runs a benchmark against our custom Mux using the
|
||||
// default settings.
|
||||
func Benchmark_Routes(b *testing.B) {
|
||||
|
||||
handler := routes.NewRouter()
|
||||
handler.Get("/person/:last/:first", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_Web runs a benchmark against the pat.go Mux using the
|
||||
// default settings.
|
||||
func Benchmark_Pat(b *testing.B) {
|
||||
|
||||
m := pat.New()
|
||||
m.Get("/person/:last/:first", http.HandlerFunc(HandlerOk))
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
m.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_Gorilla runs a benchmark against the Gorilla Mux using
|
||||
// the default settings.
|
||||
func Benchmark_GorillaHandler(b *testing.B) {
|
||||
|
||||
handler := gorilla.NewRouter()
|
||||
handler.HandleFunc("/person/{last}/{first}", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_ServeMux runs a benchmark against the ServeMux Go function.
|
||||
// We use this to determine performance impact of our library, when compared
|
||||
// to the out-of-the-box Mux provided by Go.
|
||||
func Benchmark_ServeMux(b *testing.B) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Context stores data for the duration of the http.Request
|
||||
type Context struct {
|
||||
// named parameters that are passed in via RESTful URL Parameters
|
||||
Params Params
|
||||
|
||||
// named attributes that persist for the lifetime of the request
|
||||
Values Values
|
||||
|
||||
// reference to the parent http.Request
|
||||
req *http.Request
|
||||
}
|
||||
|
||||
// Retruns the Context associated with the http.Request.
|
||||
func NewContext(r *http.Request) *Context {
|
||||
|
||||
// get the context bound to the http.Request
|
||||
if v, ok := r.Body.(*wrapper); ok {
|
||||
return v.context
|
||||
}
|
||||
|
||||
// create a new context
|
||||
c := Context{ }
|
||||
c.Params = make(Params)
|
||||
c.Values = make(Values)
|
||||
c.req = r
|
||||
|
||||
// wrap the request and bind the context
|
||||
wrapper := wrap(r)
|
||||
wrapper.context = &c
|
||||
return &c
|
||||
}
|
||||
|
||||
// Retruns the parent http.Request to which the context is bound.
|
||||
func (c *Context) Request() *http.Request {
|
||||
return c.req
|
||||
}
|
||||
|
||||
// wrapper decorates an http.Request's Body (io.ReadCloser) so that we can
|
||||
// bind a Context to the Request. This is obviously a hack that i'd rather
|
||||
// avoid, however, it is for the greater good ...
|
||||
//
|
||||
// NOTE: If this turns out to be a really stupid approach we can use this
|
||||
// approach from the go mailing list: http://goo.gl/Vw13f which I
|
||||
// avoided because I didn't want a global lock
|
||||
type wrapper struct {
|
||||
body io.ReadCloser // the original message body
|
||||
context *Context
|
||||
}
|
||||
|
||||
func wrap(r *http.Request) *wrapper {
|
||||
w := wrapper{ body: r.Body }
|
||||
r.Body = &w
|
||||
return &w
|
||||
}
|
||||
|
||||
func (w *wrapper) Read(p []byte) (n int, err error) {
|
||||
return w.body.Read(p)
|
||||
}
|
||||
|
||||
func (w *wrapper) Close() error {
|
||||
return w.body.Close()
|
||||
}
|
||||
|
||||
// Parameter Map ---------------------------------------------------------------
|
||||
|
||||
// Params maps a string key to a list of values.
|
||||
type Params map[string]string
|
||||
|
||||
// Get gets the first value associated with the given key. If there are
|
||||
// no values associated with the key, Get returns the empty string.
|
||||
func (p Params) Get(key string) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return p[key]
|
||||
}
|
||||
|
||||
// Set sets the key to value. It replaces any existing values.
|
||||
func (p Params) Set(key, value string) {
|
||||
p[key] = value
|
||||
}
|
||||
|
||||
// Del deletes the values associated with key.
|
||||
func (p Params) Del(key string) {
|
||||
delete(p, key)
|
||||
}
|
||||
|
||||
// Value Map -------------------------------------------------------------------
|
||||
|
||||
// Values maps a string key to a list of values.
|
||||
type Values map[interface{}]interface{}
|
||||
|
||||
// Get gets the value associated with the given key. If there are
|
||||
// no values associated with the key, Get returns nil.
|
||||
func (v Values) Get(key interface{}) interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v[key]
|
||||
}
|
||||
|
||||
// GetStr gets the value associated with the given key in string format.
|
||||
// If there are no values associated with the key, Get returns an
|
||||
// empty string.
|
||||
func (v Values) GetStr(key interface{}) interface{} {
|
||||
if v == nil { return "" }
|
||||
|
||||
val := v.Get(key)
|
||||
if val == nil { return "" }
|
||||
|
||||
str, ok := val.(string)
|
||||
if !ok { return "" }
|
||||
return str
|
||||
}
|
||||
|
||||
// Set sets the key to value. It replaces any existing values.
|
||||
func (v Values) Set(key, value interface{}) {
|
||||
v[key] = value
|
||||
}
|
||||
|
||||
// Del deletes the values associated with key.
|
||||
func (v Values) Del(key interface{}) {
|
||||
delete(v, key)
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
/*
|
||||
Package routes a simple http routing API for the Go programming language,
|
||||
compatible with the standard http.ListenAndServe function.
|
||||
|
||||
Create a new route multiplexer:
|
||||
|
||||
r := routes.NewRouter()
|
||||
|
||||
Define a simple route with a given method (ie Get, Put, Post ...), path and
|
||||
http.HandleFunc.
|
||||
|
||||
r.Get("/foo", fooHandler)
|
||||
|
||||
Define a route with restful parameters in the path:
|
||||
|
||||
r.Get("/:foo/:bar", func(rw http.ResponseWriter, req *http.Request) {
|
||||
c := routes.NewContext(req)
|
||||
foo := c.Params.Get("foo")
|
||||
bar := c.Params.Get("bar")
|
||||
fmt.Fprintf(rw, "%s %s", foo, bar)
|
||||
})
|
||||
|
||||
The parameters are parsed from the URL, and stored in the Request Context.
|
||||
|
||||
More control over the route's parameter matching is possible by providing
|
||||
a custom regular expression:
|
||||
|
||||
r.Get("/files/:file(.+)", handler)
|
||||
|
||||
To start the web server, use the standard http.ListenAndServe
|
||||
function, and provide the route multiplexer:
|
||||
|
||||
http.Handle("/", r)
|
||||
http.ListenAndServe(":8000", nil)
|
||||
|
||||
*/
|
||||
package routes
|
|
@ -0,0 +1,96 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Helper Functions to Read from the http.Request Body -------------------------
|
||||
|
||||
// ReadJson parses the JSON-encoded data in the http.Request object and
|
||||
// stores the result in the value pointed to by v.
|
||||
func ReadJson(r *http.Request, v interface{}) error {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(body, v)
|
||||
}
|
||||
|
||||
// ReadXml parses the XML-encoded data in the http.Request object and
|
||||
// stores the result in the value pointed to by v.
|
||||
func ReadXml(r *http.Request, v interface{}) error {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return xml.Unmarshal(body, v)
|
||||
}
|
||||
|
||||
// Helper Functions to Write to the http.ReponseWriter -------------------------
|
||||
|
||||
// ServeJson writes the JSON representation of resource v to the
|
||||
// http.ResponseWriter.
|
||||
func ServeJson(w http.ResponseWriter, v interface{}) {
|
||||
content, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(content)
|
||||
}
|
||||
|
||||
// ServeXml writes the XML representation of resource v to the
|
||||
// http.ResponseWriter.
|
||||
func ServeXml(w http.ResponseWriter, v interface{}) {
|
||||
content, err := xml.Marshal(v)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
|
||||
w.Header().Set("Content-Type", "text/xml; charset=utf-8")
|
||||
w.Write(content)
|
||||
}
|
||||
|
||||
// ServeTemplate applies the named template to the specified data map and
|
||||
// writes the output to the http.ResponseWriter.
|
||||
func ServeTemplate(w http.ResponseWriter, name string, data map[string]interface{}) {
|
||||
// cast the writer to the resposneWriter, get the router
|
||||
r := w.(*responseWriter).Router
|
||||
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
if data == nil {
|
||||
data = map[string]interface{}{}
|
||||
}
|
||||
|
||||
// append global params to the template
|
||||
for k, v := range r.params {
|
||||
data[k] = v
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := r.views.ExecuteTemplate(&buf, name, data); err != nil {
|
||||
panic(err)
|
||||
return
|
||||
}
|
||||
|
||||
// set the content length, type, etc
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
// Error will terminate the http Request with the specified error code.
|
||||
func Error(w http.ResponseWriter, code int) {
|
||||
http.Error(w, http.StatusText(code), code)
|
||||
}
|
|
@ -0,0 +1,237 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
const (
|
||||
DELETE = "DELETE"
|
||||
GET = "GET"
|
||||
HEAD = "HEAD"
|
||||
OPTIONS = "OPTIONS"
|
||||
PATCH = "PATCH"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
)
|
||||
|
||||
type route struct {
|
||||
method string
|
||||
regex *regexp.Regexp
|
||||
params map[int]string
|
||||
handler http.HandlerFunc
|
||||
}
|
||||
|
||||
type Router struct {
|
||||
sync.RWMutex
|
||||
routes []*route
|
||||
filters []http.HandlerFunc
|
||||
views *template.Template
|
||||
params map[string]interface{}
|
||||
}
|
||||
|
||||
func NewRouter() *Router {
|
||||
r := Router{}
|
||||
r.params = make(map[string]interface{})
|
||||
return &r
|
||||
}
|
||||
|
||||
// Get adds a new Route for GET requests.
|
||||
func (r *Router) Get(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(GET, pattern, handler)
|
||||
}
|
||||
|
||||
// Put adds a new Route for PUT requests.
|
||||
func (r *Router) Put(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(PUT, pattern, handler)
|
||||
}
|
||||
|
||||
// Del adds a new Route for DELETE requests.
|
||||
func (r *Router) Del(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(DELETE, pattern, handler)
|
||||
}
|
||||
|
||||
// Patch adds a new Route for PATCH requests.
|
||||
func (r *Router) Patch(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(PATCH, pattern, handler)
|
||||
}
|
||||
|
||||
// Post adds a new Route for POST requests.
|
||||
func (r *Router) Post(pattern string, handler http.HandlerFunc) {
|
||||
r.AddRoute(POST, pattern, handler)
|
||||
}
|
||||
|
||||
// Adds a new Route for Static http requests. Serves
|
||||
// static files from the specified directory
|
||||
func (r *Router) Static(pattern string, dir string) {
|
||||
//append a regex to the param to match everything
|
||||
// that comes after the prefix
|
||||
pattern = pattern + "(.+)"
|
||||
r.Get(pattern, func(w http.ResponseWriter, req *http.Request) {
|
||||
path := filepath.Clean(req.URL.Path)
|
||||
path = filepath.Join(dir, path)
|
||||
http.ServeFile(w, req, path)
|
||||
})
|
||||
}
|
||||
|
||||
// Adds a new Route to the Handler
|
||||
func (r *Router) AddRoute(method string, pattern string, handler http.HandlerFunc) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
//split the url into sections
|
||||
parts := strings.Split(pattern, "/")
|
||||
|
||||
//find params that start with ":"
|
||||
//replace with regular expressions
|
||||
j := 0
|
||||
params := make(map[int]string)
|
||||
for i, part := range parts {
|
||||
if strings.HasPrefix(part, ":") {
|
||||
expr := "([^/]+)"
|
||||
//a user may choose to override the defult expression
|
||||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||||
if index := strings.Index(part, "("); index != -1 {
|
||||
expr = part[index:]
|
||||
part = part[:index]
|
||||
}
|
||||
params[j] = part[1:]
|
||||
parts[i] = expr
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
//recreate the url pattern, with parameters replaced
|
||||
//by regular expressions. then compile the regex
|
||||
pattern = strings.Join(parts, "/")
|
||||
regex, regexErr := regexp.Compile(pattern)
|
||||
if regexErr != nil {
|
||||
panic(regexErr)
|
||||
}
|
||||
|
||||
route := &route{
|
||||
method : method,
|
||||
regex : regex,
|
||||
handler : handler,
|
||||
params : params,
|
||||
}
|
||||
|
||||
//append to the list of Routes
|
||||
r.routes = append(r.routes, route)
|
||||
}
|
||||
|
||||
// Filter adds the middleware filter.
|
||||
func (r *Router) Filter(filter http.HandlerFunc) {
|
||||
r.Lock()
|
||||
r.filters = append(r.filters, filter)
|
||||
r.Unlock()
|
||||
}
|
||||
|
||||
// FilterParam adds the middleware filter iff the URL parameter exists.
|
||||
func (r *Router) FilterParam(param string, filter http.HandlerFunc) {
|
||||
r.Filter(func(w http.ResponseWriter, req *http.Request) {
|
||||
c := NewContext(req)
|
||||
if len(c.Params.Get(param)) > 0 { filter(w, req) }
|
||||
})
|
||||
}
|
||||
|
||||
// Set stores the specified key / value pair.
|
||||
func (r *Router) Set(name string, value interface{}) {
|
||||
r.Lock()
|
||||
r.params[name] = value
|
||||
r.Unlock()
|
||||
}
|
||||
|
||||
// SetEnv stores the specified environment variable as a key / value pair. If
|
||||
// the environment variable is not set the default value will be used
|
||||
func (r *Router) SetEnv(name, value string) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
env := os.Getenv(name)
|
||||
if len(env) == 0 { env = value }
|
||||
r.Set(name, env)
|
||||
}
|
||||
|
||||
// Required by http.Handler interface. This method is invoked by the
|
||||
// http server and will handle all page routing
|
||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
//wrap the response writer in our custom interface
|
||||
w := &responseWriter{writer: rw, Router: r}
|
||||
|
||||
//find a matching Route
|
||||
for _, route := range r.routes {
|
||||
|
||||
//if the methods don't match, skip this handler
|
||||
//i.e if request.Method is 'PUT' Route.Method must be 'PUT'
|
||||
if req.Method != route.method {
|
||||
continue
|
||||
}
|
||||
|
||||
//check if Route pattern matches url
|
||||
if !route.regex.MatchString(req.URL.Path) {
|
||||
continue
|
||||
}
|
||||
|
||||
//get submatches (params)
|
||||
matches := route.regex.FindStringSubmatch(req.URL.Path)
|
||||
|
||||
//double check that the Route matches the URL pattern.
|
||||
if len(matches[0]) != len(req.URL.Path) {
|
||||
continue
|
||||
}
|
||||
|
||||
//create the http.Requests context
|
||||
c := NewContext(req)
|
||||
|
||||
//add url parameters to the context
|
||||
for i, match := range matches[1:] {
|
||||
c.Params.Set(route.params[i], match)
|
||||
}
|
||||
|
||||
//execute middleware filters
|
||||
for _, filter := range r.filters {
|
||||
filter(w, req)
|
||||
if w.started { return }
|
||||
}
|
||||
|
||||
//invoke the request handler
|
||||
route.handler(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
//if no matches to url, throw a not found exception
|
||||
if w.started == false {
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
// Template uses the provided template definitions.
|
||||
func (r *Router) Template(t *template.Template) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.views = template.Must(t.Clone())
|
||||
}
|
||||
|
||||
// TemplateFiles parses the template definitions from the named files.
|
||||
func (r *Router) TemplateFiles(filenames ...string) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.views = template.Must(template.ParseFiles(filenames...))
|
||||
}
|
||||
|
||||
// TemplateGlob parses the template definitions from the files identified
|
||||
// by the pattern, which must match at least one file.
|
||||
func (r *Router) TemplateGlob(pattern string) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.views = template.Must(template.ParseGlob(pattern))
|
||||
}
|
|
@ -0,0 +1,189 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func HandlerOk(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "hello world")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func HandlerSetVar(w http.ResponseWriter, r *http.Request) {
|
||||
c := NewContext(r)
|
||||
c.Values.Set("password", "z1on")
|
||||
}
|
||||
|
||||
func HandlerErr(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// TestRouteOk tests that the route is correctly handled, and the URL parameters
|
||||
// are added to the Context.
|
||||
func TestRouteOk(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := NewRouter()
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
c := NewContext(r)
|
||||
lastNameParam := c.Params.Get("last")
|
||||
firstNameParam := c.Params.Get("first")
|
||||
|
||||
if lastNameParam != "anderson" {
|
||||
t.Errorf("url param set to [%s]; want [%s]", lastNameParam, "anderson")
|
||||
}
|
||||
if firstNameParam != "thomas" {
|
||||
t.Errorf("url param set to [%s]; want [%s]", firstNameParam, "thomas")
|
||||
}
|
||||
if w.Body.String() != "hello world" {
|
||||
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilter tests that a route is filtered prior to handling
|
||||
func TestRouteFilter(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := NewRouter()
|
||||
mux.Filter(HandlerSetVar)
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
c := NewContext(r)
|
||||
password := c.Values.Get("password")
|
||||
|
||||
if password != "z1on" {
|
||||
t.Errorf("session variable set to [%s]; want [%s]", password, "z1on")
|
||||
}
|
||||
if w.Body.String() != "hello world" {
|
||||
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterHalt tests that a route is filtered prior to handling, and then
|
||||
// halts execution (by writing to the response).
|
||||
func TestRouteFilterHalt(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := NewRouter()
|
||||
mux.Filter(HandlerErr)
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != 400 {
|
||||
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
if w.Body.String() == "hello world" {
|
||||
t.Errorf("Body set to [%s]; want empty", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestParam tests the Parameter filter, and ensures the filter is only
|
||||
// executed when the specified Parameter exists.
|
||||
func TestParam(t *testing.T) {
|
||||
// in the first test scenario, the Parameter filter should not
|
||||
// be triggered because the "codename" variab does not exist
|
||||
r, _ := http.NewRequest("GET", "/neo", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := NewRouter()
|
||||
mux.Filter(HandlerSetVar)
|
||||
mux.FilterParam("codename", HandlerErr)
|
||||
mux.Get("/:nickname", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Body.String() != "hello world" {
|
||||
t.Errorf("Body set to [%s]; want [%s]", w.Body.String(), "hello world")
|
||||
}
|
||||
|
||||
// in this second scenario, the Parameter filter SHOULD fire, and should
|
||||
// halt the request
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
mux = NewRouter()
|
||||
mux.Filter(HandlerSetVar)
|
||||
mux.FilterParam("codename", HandlerErr)
|
||||
mux.Get("/:codename", HandlerOk)
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Body.String() == "hello world" {
|
||||
t.Errorf("Body set to [%s]; want empty", w.Body.String())
|
||||
}
|
||||
if w.Code != 400 {
|
||||
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
// TestTemplate tests template rendering
|
||||
func TestTemplate(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, _ := template.New("template.html").Parse("<html><head><title>{{ .Title }}</title><body>{{ .Name }}</body></html>")
|
||||
|
||||
mux := NewRouter()
|
||||
mux.Template(tmpl)
|
||||
mux.Set("Title", "Matrix")
|
||||
mux.ExecuteTemplate(w, "template.html", map[string]interface{}{ "Name" : "Morpheus" })
|
||||
|
||||
if w.Body.String() != "<html><head><title>Matrix</title><body>Morpheus</body></html>" {
|
||||
t.Errorf("template not rendered correctly [%s]", w.Body.String())
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// TestNotFound tests that a 404 code is returned in the
|
||||
// response if no route matches the request url.
|
||||
func TestNotFound(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux := NewRouter()
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Code set to [%s]; want [%s]", w.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_Routes runs a benchmark against our custom Mux using the
|
||||
// default settings.
|
||||
func Benchmark_Routes(b *testing.B) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux := NewRouter()
|
||||
mux.Get("/person/:last/:first", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_ServeMux runs a benchmark against the ServeMux Go function.
|
||||
// We use this to determine performance impact of our library, when compared
|
||||
// to the out-of-the-box Mux provided by Go.
|
||||
func Benchmark_ServeMux(b *testing.B) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.URL.Query().Get("learn")
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ResponseWriter is a wrapper for the http.ResponseWriter to track if
|
||||
// response was written to.
|
||||
type responseWriter struct {
|
||||
Router *Router
|
||||
writer http.ResponseWriter
|
||||
started bool
|
||||
status int
|
||||
}
|
||||
|
||||
// Header returns the header map that will be sent by WriteHeader.
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.writer.Header()
|
||||
}
|
||||
|
||||
// Write writes the data to the connection as part of an HTTP reply,
|
||||
// and sets `started` to true
|
||||
func (w *responseWriter) Write(p []byte) (int, error) {
|
||||
w.started = true
|
||||
return w.writer.Write(p)
|
||||
}
|
||||
|
||||
// WriteHeader sends an HTTP response header with status code,
|
||||
// and sets `started` to true
|
||||
func (w *responseWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
w.started = true
|
||||
w.writer.WriteHeader(code)
|
||||
}
|
||||
|
||||
// The Hijacker interface is implemented by ResponseWriters that allow an
|
||||
// HTTP handler to take over the connection.
|
||||
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.writer.(http.Hijacker).Hijack()
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"github.com/drone/routes/exp/context"
|
||||
)
|
||||
|
||||
// Key used to store the user in the session
|
||||
const userKey = "_user"
|
||||
|
||||
// User represents a user of the application.
|
||||
type User struct {
|
||||
Id string // the unique permanent ID of the user.
|
||||
Name string // the human-readable ID of the user.
|
||||
Email string
|
||||
Photo string
|
||||
|
||||
FederatedIdentity string
|
||||
FederatedProvider string
|
||||
|
||||
// additional, custom Attributes
|
||||
Attrs map[string]string
|
||||
}
|
||||
|
||||
// Decode will create a user from a URL Query string.
|
||||
func Decode(v string) *User {
|
||||
values, err := url.ParseQuery(v)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
attrs := map[string]string{}
|
||||
for key, _ := range values {
|
||||
attrs[key]=values.Get(key)
|
||||
}
|
||||
|
||||
return &User {
|
||||
Id : values.Get("id"),
|
||||
Name : values.Get("name"),
|
||||
Email : values.Get("email"),
|
||||
Photo : values.Get("photo"),
|
||||
Attrs : attrs,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode will encode a user as a URL query string.
|
||||
func (u *User) Encode() string {
|
||||
values := url.Values{}
|
||||
|
||||
// add custom attributes
|
||||
if u.Attrs != nil {
|
||||
for key, val := range u.Attrs {
|
||||
values.Set(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
values.Set("id", u.Id)
|
||||
values.Set("name", u.Name)
|
||||
values.Set("email", u.Email)
|
||||
values.Set("photo", u.Photo)
|
||||
return values.Encode()
|
||||
}
|
||||
|
||||
// Current returns the currently logged-in user, or nil if the user is not
|
||||
// signed in.
|
||||
func Current(c *context.Context) *User {
|
||||
v := c.Values.Get(userKey)
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
u, ok := v.(*User)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// Set sets the currently logged-in user. This is typically used by middleware
|
||||
// that handles user authentication.
|
||||
func Set(c *context.Context, u *User) {
|
||||
c.Values.Set(userKey, u)
|
||||
}
|
|
@ -0,0 +1,317 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
CONNECT = "CONNECT"
|
||||
DELETE = "DELETE"
|
||||
GET = "GET"
|
||||
HEAD = "HEAD"
|
||||
OPTIONS = "OPTIONS"
|
||||
PATCH = "PATCH"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
TRACE = "TRACE"
|
||||
)
|
||||
|
||||
//commonly used mime-types
|
||||
const (
|
||||
applicationJson = "application/json"
|
||||
applicationXml = "application/xml"
|
||||
textXml = "text/xml"
|
||||
)
|
||||
|
||||
type route struct {
|
||||
method string
|
||||
regex *regexp.Regexp
|
||||
params map[int]string
|
||||
handler http.HandlerFunc
|
||||
}
|
||||
|
||||
type RouteMux struct {
|
||||
routes []*route
|
||||
filters []http.HandlerFunc
|
||||
}
|
||||
|
||||
func New() *RouteMux {
|
||||
return &RouteMux{}
|
||||
}
|
||||
|
||||
// Get adds a new Route for GET requests.
|
||||
func (m *RouteMux) Get(pattern string, handler http.HandlerFunc) {
|
||||
m.AddRoute(GET, pattern, handler)
|
||||
}
|
||||
|
||||
// Put adds a new Route for PUT requests.
|
||||
func (m *RouteMux) Put(pattern string, handler http.HandlerFunc) {
|
||||
m.AddRoute(PUT, pattern, handler)
|
||||
}
|
||||
|
||||
// Del adds a new Route for DELETE requests.
|
||||
func (m *RouteMux) Del(pattern string, handler http.HandlerFunc) {
|
||||
m.AddRoute(DELETE, pattern, handler)
|
||||
}
|
||||
|
||||
// Patch adds a new Route for PATCH requests.
|
||||
func (m *RouteMux) Patch(pattern string, handler http.HandlerFunc) {
|
||||
m.AddRoute(PATCH, pattern, handler)
|
||||
}
|
||||
|
||||
// Post adds a new Route for POST requests.
|
||||
func (m *RouteMux) Post(pattern string, handler http.HandlerFunc) {
|
||||
m.AddRoute(POST, pattern, handler)
|
||||
}
|
||||
|
||||
// Adds a new Route for Static http requests. Serves
|
||||
// static files from the specified directory
|
||||
func (m *RouteMux) Static(pattern string, dir string) {
|
||||
//append a regex to the param to match everything
|
||||
// that comes after the prefix
|
||||
pattern = pattern + "(.+)"
|
||||
m.AddRoute(GET, pattern, func(w http.ResponseWriter, r *http.Request) {
|
||||
path := filepath.Clean(r.URL.Path)
|
||||
path = filepath.Join(dir, path)
|
||||
http.ServeFile(w, r, path)
|
||||
})
|
||||
}
|
||||
|
||||
// Adds a new Route to the Handler
|
||||
func (m *RouteMux) AddRoute(method string, pattern string, handler http.HandlerFunc) {
|
||||
|
||||
//split the url into sections
|
||||
parts := strings.Split(pattern, "/")
|
||||
|
||||
//find params that start with ":"
|
||||
//replace with regular expressions
|
||||
j := 0
|
||||
params := make(map[int]string)
|
||||
for i, part := range parts {
|
||||
if strings.HasPrefix(part, ":") {
|
||||
expr := "([^/]+)"
|
||||
//a user may choose to override the defult expression
|
||||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||||
if index := strings.Index(part, "("); index != -1 {
|
||||
expr = part[index:]
|
||||
part = part[:index]
|
||||
}
|
||||
params[j] = part
|
||||
parts[i] = expr
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
//recreate the url pattern, with parameters replaced
|
||||
//by regular expressions. then compile the regex
|
||||
pattern = strings.Join(parts, "/")
|
||||
regex, regexErr := regexp.Compile(pattern)
|
||||
if regexErr != nil {
|
||||
//TODO add error handling here to avoid panic
|
||||
panic(regexErr)
|
||||
return
|
||||
}
|
||||
|
||||
//now create the Route
|
||||
route := &route{}
|
||||
route.method = method
|
||||
route.regex = regex
|
||||
route.handler = handler
|
||||
route.params = params
|
||||
|
||||
//and finally append to the list of Routes
|
||||
m.routes = append(m.routes, route)
|
||||
}
|
||||
|
||||
// Filter adds the middleware filter.
|
||||
func (m *RouteMux) Filter(filter http.HandlerFunc) {
|
||||
m.filters = append(m.filters, filter)
|
||||
}
|
||||
|
||||
// FilterParam adds the middleware filter iff the REST URL parameter exists.
|
||||
func (m *RouteMux) FilterParam(param string, filter http.HandlerFunc) {
|
||||
if !strings.HasPrefix(param,":") {
|
||||
param = ":"+param
|
||||
}
|
||||
|
||||
m.Filter(func(w http.ResponseWriter, r *http.Request) {
|
||||
p := r.URL.Query().Get(param)
|
||||
if len(p) > 0 { filter(w, r) }
|
||||
})
|
||||
}
|
||||
|
||||
// Required by http.Handler interface. This method is invoked by the
|
||||
// http server and will handle all page routing
|
||||
func (m *RouteMux) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
requestPath := r.URL.Path
|
||||
|
||||
//wrap the response writer, in our custom interface
|
||||
w := &responseWriter{writer: rw}
|
||||
|
||||
//find a matching Route
|
||||
for _, route := range m.routes {
|
||||
|
||||
//if the methods don't match, skip this handler
|
||||
//i.e if request.Method is 'PUT' Route.Method must be 'PUT'
|
||||
if r.Method != route.method {
|
||||
continue
|
||||
}
|
||||
|
||||
//check if Route pattern matches url
|
||||
if !route.regex.MatchString(requestPath) {
|
||||
continue
|
||||
}
|
||||
|
||||
//get submatches (params)
|
||||
matches := route.regex.FindStringSubmatch(requestPath)
|
||||
|
||||
//double check that the Route matches the URL pattern.
|
||||
if len(matches[0]) != len(requestPath) {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(route.params) > 0 {
|
||||
//add url parameters to the query param map
|
||||
values := r.URL.Query()
|
||||
for i, match := range matches[1:] {
|
||||
values.Add(route.params[i], match)
|
||||
}
|
||||
|
||||
//reassemble query params and add to RawQuery
|
||||
r.URL.RawQuery = url.Values(values).Encode() + "&" + r.URL.RawQuery
|
||||
//r.URL.RawQuery = url.Values(values).Encode()
|
||||
}
|
||||
|
||||
//execute middleware filters
|
||||
for _, filter := range m.filters {
|
||||
filter(w, r)
|
||||
if w.started {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
//Invoke the request handler
|
||||
route.handler(w, r)
|
||||
break
|
||||
}
|
||||
|
||||
//if no matches to url, throw a not found exception
|
||||
if w.started == false {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Simple wrapper around a ResponseWriter
|
||||
|
||||
// responseWriter is a wrapper for the http.ResponseWriter
|
||||
// to track if response was written to. It also allows us
|
||||
// to automatically set certain headers, such as Content-Type,
|
||||
// Access-Control-Allow-Origin, etc.
|
||||
type responseWriter struct {
|
||||
writer http.ResponseWriter
|
||||
started bool
|
||||
status int
|
||||
}
|
||||
|
||||
// Header returns the header map that will be sent by WriteHeader.
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.writer.Header()
|
||||
}
|
||||
|
||||
// Write writes the data to the connection as part of an HTTP reply,
|
||||
// and sets `started` to true
|
||||
func (w *responseWriter) Write(p []byte) (int, error) {
|
||||
w.started = true
|
||||
return w.writer.Write(p)
|
||||
}
|
||||
|
||||
// WriteHeader sends an HTTP response header with status code,
|
||||
// and sets `started` to true
|
||||
func (w *responseWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
w.started = true
|
||||
w.writer.WriteHeader(code)
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Below are helper functions to replace boilerplate
|
||||
// code that serializes resources and writes to the
|
||||
// http response.
|
||||
|
||||
// ServeJson replies to the request with a JSON
|
||||
// representation of resource v.
|
||||
func ServeJson(w http.ResponseWriter, v interface{}) {
|
||||
content, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
|
||||
w.Header().Set("Content-Type", applicationJson)
|
||||
w.Write(content)
|
||||
}
|
||||
|
||||
// ReadJson will parses the JSON-encoded data in the http
|
||||
// Request object and stores the result in the value
|
||||
// pointed to by v.
|
||||
func ReadJson(r *http.Request, v interface{}) error {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(body, v)
|
||||
}
|
||||
|
||||
// ServeXml replies to the request with an XML
|
||||
// representation of resource v.
|
||||
func ServeXml(w http.ResponseWriter, v interface{}) {
|
||||
content, err := xml.Marshal(v)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
|
||||
w.Header().Set("Content-Type", "text/xml; charset=utf-8")
|
||||
w.Write(content)
|
||||
}
|
||||
|
||||
// ReadXml will parses the XML-encoded data in the http
|
||||
// Request object and stores the result in the value
|
||||
// pointed to by v.
|
||||
func ReadXml(r *http.Request, v interface{}) error {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return xml.Unmarshal(body, v)
|
||||
}
|
||||
|
||||
// ServeFormatted replies to the request with
|
||||
// a formatted representation of resource v, in the
|
||||
// format requested by the client specified in the
|
||||
// Accept header.
|
||||
func ServeFormatted(w http.ResponseWriter, r *http.Request, v interface{}) {
|
||||
accept := r.Header.Get("Accept")
|
||||
switch accept {
|
||||
case applicationJson:
|
||||
ServeJson(w, v)
|
||||
case applicationXml, textXml:
|
||||
ServeXml(w, v)
|
||||
default:
|
||||
ServeJson(w, v)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var HandlerOk = func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "hello world")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
var HandlerErr = func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
var FilterUser = func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.User == nil || r.URL.User.Username() != "admin" {
|
||||
http.Error(w, "", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
var FilterId = func(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.URL.Query().Get(":id")
|
||||
if id == "admin" {
|
||||
http.Error(w, "", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthOk tests that an Auth handler will append the
|
||||
// username and password to to the request URL, and will
|
||||
// continue processing the request by invoking the handler.
|
||||
func TestRouteOk(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := new(RouteMux)
|
||||
handler.Get("/person/:last/:first", HandlerOk)
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
lastNameParam := r.URL.Query().Get(":last")
|
||||
firstNameParam := r.URL.Query().Get(":first")
|
||||
learnParam := r.URL.Query().Get("learn")
|
||||
|
||||
if lastNameParam != "anderson" {
|
||||
t.Errorf("url param set to [%s]; want [%s]", lastNameParam, "anderson")
|
||||
}
|
||||
if firstNameParam != "thomas" {
|
||||
t.Errorf("url param set to [%s]; want [%s]", firstNameParam, "thomas")
|
||||
}
|
||||
if learnParam != "kungfu" {
|
||||
t.Errorf("url param set to [%s]; want [%s]", learnParam, "kungfu")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotFound tests that a 404 code is returned in the
|
||||
// response if no route matches the request url.
|
||||
func TestNotFound(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := new(RouteMux)
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStatic tests the ability to serve static
|
||||
// content from the filesystem
|
||||
func TestStatic(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/routes_test.go", nil)
|
||||
w := httptest.NewRecorder()
|
||||
pwd, _ := os.Getwd()
|
||||
|
||||
handler := new(RouteMux)
|
||||
handler.Static("/", pwd)
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
testFile, _ := ioutil.ReadFile(pwd + "/routes_test.go")
|
||||
if w.Body.String() != string(testFile) {
|
||||
t.Errorf("handler.Static failed to serve file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilter tests the ability to apply middleware function
|
||||
// to filter all routes
|
||||
func TestFilter(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := new(RouteMux)
|
||||
handler.Get("/", HandlerOk)
|
||||
handler.Filter(FilterUser)
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Did not apply Filter. Code set to [%v]; want [%v]", w.Code, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("GET", "/", nil)
|
||||
r.URL.User = url.User("admin")
|
||||
w = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterParam tests the ability to apply middleware
|
||||
// function to filter all routes with specified parameter
|
||||
// in the REST url
|
||||
func TestFilterParam(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/:id", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// first test that the param filter does not trigger
|
||||
handler := new(RouteMux)
|
||||
handler.Get("/", HandlerOk)
|
||||
handler.Get("/:id", HandlerOk)
|
||||
handler.FilterParam("id", FilterId)
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// now test the param filter does trigger
|
||||
r, _ = http.NewRequest("GET", "/admin", nil)
|
||||
w = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Did not apply Param Filter. Code set to [%v]; want [%v]", w.Code, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Benchmark_RoutedHandler runs a benchmark against
|
||||
// the RouteMux using the default settings.
|
||||
func Benchmark_RoutedHandler(b *testing.B) {
|
||||
handler := new(RouteMux)
|
||||
handler.Get("/", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_RoutedHandler runs a benchmark against
|
||||
// the RouteMux using the default settings with REST
|
||||
// URL params.
|
||||
func Benchmark_RoutedHandlerParams(b *testing.B) {
|
||||
|
||||
handler := new(RouteMux)
|
||||
handler.Get("/:user", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := http.NewRequest("GET", "/admin", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark_ServeMux runs a benchmark against
|
||||
// the ServeMux Go function. We use this to determine
|
||||
// performance impact of our library, when compared
|
||||
// to the out-of-the-box Mux provided by Go.
|
||||
func Benchmark_ServeMux(b *testing.B) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", HandlerOk)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,175 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
|
@ -0,0 +1,44 @@
|
|||
Redigo
|
||||
======
|
||||
|
||||
Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database.
|
||||
|
||||
Features
|
||||
-------
|
||||
|
||||
* A [Print-like](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Executing_Commands) API with support for all Redis commands.
|
||||
* [Pipelining](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Pipelining), including pipelined transactions.
|
||||
* [Publish/Subscribe](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Publish_and_Subscribe).
|
||||
* [Connection pooling](http://godoc.org/github.com/garyburd/redigo/redis#Pool).
|
||||
* [Script helper type](http://godoc.org/github.com/garyburd/redigo/redis#Script) with optimistic use of EVALSHA.
|
||||
* [Helper functions](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Reply_Helpers) for working with command replies.
|
||||
|
||||
Documentation
|
||||
-------------
|
||||
|
||||
- [API Reference](http://godoc.org/github.com/garyburd/redigo/redis)
|
||||
- [FAQ](https://github.com/garyburd/redigo/wiki/FAQ)
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
Install Redigo using the "go get" command:
|
||||
|
||||
go get github.com/garyburd/redigo/redis
|
||||
|
||||
The Go distribution is Redigo's only dependency.
|
||||
|
||||
Contributing
|
||||
------------
|
||||
|
||||
Contributions are welcome.
|
||||
|
||||
Before writing code, send mail to gary@beagledreams.com to discuss what you
|
||||
plan to do. This gives me a chance to validate the design, avoid duplication of
|
||||
effort and ensure that the changes fit the goals of the project. Do not start
|
||||
the discussion with a pull request.
|
||||
|
||||
License
|
||||
-------
|
||||
|
||||
Redigo is available under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0.html).
|
|
@ -0,0 +1,54 @@
|
|||
// Copyright 2014 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package internal // import "github.com/garyburd/redigo/internal"
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
WatchState = 1 << iota
|
||||
MultiState
|
||||
SubscribeState
|
||||
MonitorState
|
||||
)
|
||||
|
||||
type CommandInfo struct {
|
||||
Set, Clear int
|
||||
}
|
||||
|
||||
var commandInfos = map[string]CommandInfo{
|
||||
"WATCH": {Set: WatchState},
|
||||
"UNWATCH": {Clear: WatchState},
|
||||
"MULTI": {Set: MultiState},
|
||||
"EXEC": {Clear: WatchState | MultiState},
|
||||
"DISCARD": {Clear: WatchState | MultiState},
|
||||
"PSUBSCRIBE": {Set: SubscribeState},
|
||||
"SUBSCRIBE": {Set: SubscribeState},
|
||||
"MONITOR": {Set: MonitorState},
|
||||
}
|
||||
|
||||
func init() {
|
||||
for n, ci := range commandInfos {
|
||||
commandInfos[strings.ToLower(n)] = ci
|
||||
}
|
||||
}
|
||||
|
||||
func LookupCommandInfo(commandName string) CommandInfo {
|
||||
if ci, ok := commandInfos[commandName]; ok {
|
||||
return ci
|
||||
}
|
||||
return commandInfos[strings.ToUpper(commandName)]
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package internal
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLookupCommandInfo(t *testing.T) {
|
||||
for _, n := range []string{"watch", "WATCH", "wAtch"} {
|
||||
if LookupCommandInfo(n) == (CommandInfo{}) {
|
||||
t.Errorf("LookupCommandInfo(%q) = CommandInfo{}, expected non-zero value", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkLookupCommandInfo(b *testing.B, names ...string) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, c := range names {
|
||||
LookupCommandInfo(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLookupCommandInfoCorrectCase(b *testing.B) {
|
||||
benchmarkLookupCommandInfo(b, "watch", "WATCH", "monitor", "MONITOR")
|
||||
}
|
||||
|
||||
func BenchmarkLookupCommandInfoMixedCase(b *testing.B) {
|
||||
benchmarkLookupCommandInfo(b, "wAtch", "WeTCH", "monItor", "MONiTOR")
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright 2014 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Package redistest contains utilities for writing Redigo tests.
|
||||
package redistest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/garyburd/redigo/redis"
|
||||
)
|
||||
|
||||
type testConn struct {
|
||||
redis.Conn
|
||||
}
|
||||
|
||||
func (t testConn) Close() error {
|
||||
_, err := t.Conn.Do("SELECT", "9")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
_, err = t.Conn.Do("FLUSHDB")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return t.Conn.Close()
|
||||
}
|
||||
|
||||
// Dial dials the local Redis server and selects database 9. To prevent
|
||||
// stomping on real data, DialTestDB fails if database 9 contains data. The
|
||||
// returned connection flushes database 9 on close.
|
||||
func Dial() (redis.Conn, error) {
|
||||
c, err := redis.DialTimeout("tcp", ":6379", 0, 1*time.Second, 1*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = c.Do("SELECT", "9")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n, err := redis.Int(c.Do("DBSIZE"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n != 0 {
|
||||
return nil, errors.New("database #9 is not empty, test can not continue")
|
||||
}
|
||||
|
||||
return testConn{c}, nil
|
||||
}
|
|
@ -0,0 +1,455 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// conn is the low-level implementation of Conn
|
||||
type conn struct {
|
||||
|
||||
// Shared
|
||||
mu sync.Mutex
|
||||
pending int
|
||||
err error
|
||||
conn net.Conn
|
||||
|
||||
// Read
|
||||
readTimeout time.Duration
|
||||
br *bufio.Reader
|
||||
|
||||
// Write
|
||||
writeTimeout time.Duration
|
||||
bw *bufio.Writer
|
||||
|
||||
// Scratch space for formatting argument length.
|
||||
// '*' or '$', length, "\r\n"
|
||||
lenScratch [32]byte
|
||||
|
||||
// Scratch space for formatting integers and floats.
|
||||
numScratch [40]byte
|
||||
}
|
||||
|
||||
// Dial connects to the Redis server at the given network and address.
|
||||
func Dial(network, address string) (Conn, error) {
|
||||
dialer := xDialer{}
|
||||
return dialer.Dial(network, address)
|
||||
}
|
||||
|
||||
// DialTimeout acts like Dial but takes timeouts for establishing the
|
||||
// connection to the server, writing a command and reading a reply.
|
||||
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
|
||||
netDialer := net.Dialer{Timeout: connectTimeout}
|
||||
dialer := xDialer{
|
||||
NetDial: netDialer.Dial,
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
}
|
||||
return dialer.Dial(network, address)
|
||||
}
|
||||
|
||||
// A Dialer specifies options for connecting to a Redis server.
|
||||
type xDialer struct {
|
||||
// NetDial specifies the dial function for creating TCP connections. If
|
||||
// NetDial is nil, then net.Dial is used.
|
||||
NetDial func(network, addr string) (net.Conn, error)
|
||||
|
||||
// ReadTimeout specifies the timeout for reading a single command
|
||||
// reply. If ReadTimeout is zero, then no timeout is used.
|
||||
ReadTimeout time.Duration
|
||||
|
||||
// WriteTimeout specifies the timeout for writing a single command. If
|
||||
// WriteTimeout is zero, then no timeout is used.
|
||||
WriteTimeout time.Duration
|
||||
}
|
||||
|
||||
// Dial connects to the Redis server at address on the named network.
|
||||
func (d *xDialer) Dial(network, address string) (Conn, error) {
|
||||
dial := d.NetDial
|
||||
if dial == nil {
|
||||
dial = net.Dial
|
||||
}
|
||||
netConn, err := dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &conn{
|
||||
conn: netConn,
|
||||
bw: bufio.NewWriter(netConn),
|
||||
br: bufio.NewReader(netConn),
|
||||
readTimeout: d.ReadTimeout,
|
||||
writeTimeout: d.WriteTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewConn returns a new Redigo connection for the given net connection.
|
||||
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
|
||||
return &conn{
|
||||
conn: netConn,
|
||||
bw: bufio.NewWriter(netConn),
|
||||
br: bufio.NewReader(netConn),
|
||||
readTimeout: readTimeout,
|
||||
writeTimeout: writeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
c.mu.Lock()
|
||||
err := c.err
|
||||
if c.err == nil {
|
||||
c.err = errors.New("redigo: closed")
|
||||
err = c.conn.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) fatal(err error) error {
|
||||
c.mu.Lock()
|
||||
if c.err == nil {
|
||||
c.err = err
|
||||
// Close connection to force errors on subsequent calls and to unblock
|
||||
// other reader or writer.
|
||||
c.conn.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Err() error {
|
||||
c.mu.Lock()
|
||||
err := c.err
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) writeLen(prefix byte, n int) error {
|
||||
c.lenScratch[len(c.lenScratch)-1] = '\n'
|
||||
c.lenScratch[len(c.lenScratch)-2] = '\r'
|
||||
i := len(c.lenScratch) - 3
|
||||
for {
|
||||
c.lenScratch[i] = byte('0' + n%10)
|
||||
i -= 1
|
||||
n = n / 10
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
c.lenScratch[i] = prefix
|
||||
_, err := c.bw.Write(c.lenScratch[i:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) writeString(s string) error {
|
||||
c.writeLen('$', len(s))
|
||||
c.bw.WriteString(s)
|
||||
_, err := c.bw.WriteString("\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) writeBytes(p []byte) error {
|
||||
c.writeLen('$', len(p))
|
||||
c.bw.Write(p)
|
||||
_, err := c.bw.WriteString("\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) writeInt64(n int64) error {
|
||||
return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
|
||||
}
|
||||
|
||||
func (c *conn) writeFloat64(n float64) error {
|
||||
return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
|
||||
}
|
||||
|
||||
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
|
||||
c.writeLen('*', 1+len(args))
|
||||
err = c.writeString(cmd)
|
||||
for _, arg := range args {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
switch arg := arg.(type) {
|
||||
case string:
|
||||
err = c.writeString(arg)
|
||||
case []byte:
|
||||
err = c.writeBytes(arg)
|
||||
case int:
|
||||
err = c.writeInt64(int64(arg))
|
||||
case int64:
|
||||
err = c.writeInt64(arg)
|
||||
case float64:
|
||||
err = c.writeFloat64(arg)
|
||||
case bool:
|
||||
if arg {
|
||||
err = c.writeString("1")
|
||||
} else {
|
||||
err = c.writeString("0")
|
||||
}
|
||||
case nil:
|
||||
err = c.writeString("")
|
||||
default:
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprint(&buf, arg)
|
||||
err = c.writeBytes(buf.Bytes())
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type protocolError string
|
||||
|
||||
func (pe protocolError) Error() string {
|
||||
return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe))
|
||||
}
|
||||
|
||||
func (c *conn) readLine() ([]byte, error) {
|
||||
p, err := c.br.ReadSlice('\n')
|
||||
if err == bufio.ErrBufferFull {
|
||||
return nil, protocolError("long response line")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i := len(p) - 2
|
||||
if i < 0 || p[i] != '\r' {
|
||||
return nil, protocolError("bad response line terminator")
|
||||
}
|
||||
return p[:i], nil
|
||||
}
|
||||
|
||||
// parseLen parses bulk string and array lengths.
|
||||
func parseLen(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return -1, protocolError("malformed length")
|
||||
}
|
||||
|
||||
if p[0] == '-' && len(p) == 2 && p[1] == '1' {
|
||||
// handle $-1 and $-1 null replies.
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
var n int
|
||||
for _, b := range p {
|
||||
n *= 10
|
||||
if b < '0' || b > '9' {
|
||||
return -1, protocolError("illegal bytes in length")
|
||||
}
|
||||
n += int(b - '0')
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// parseInt parses an integer reply.
|
||||
func parseInt(p []byte) (interface{}, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, protocolError("malformed integer")
|
||||
}
|
||||
|
||||
var negate bool
|
||||
if p[0] == '-' {
|
||||
negate = true
|
||||
p = p[1:]
|
||||
if len(p) == 0 {
|
||||
return 0, protocolError("malformed integer")
|
||||
}
|
||||
}
|
||||
|
||||
var n int64
|
||||
for _, b := range p {
|
||||
n *= 10
|
||||
if b < '0' || b > '9' {
|
||||
return 0, protocolError("illegal bytes in length")
|
||||
}
|
||||
n += int64(b - '0')
|
||||
}
|
||||
|
||||
if negate {
|
||||
n = -n
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var (
|
||||
okReply interface{} = "OK"
|
||||
pongReply interface{} = "PONG"
|
||||
)
|
||||
|
||||
func (c *conn) readReply() (interface{}, error) {
|
||||
line, err := c.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(line) == 0 {
|
||||
return nil, protocolError("short response line")
|
||||
}
|
||||
switch line[0] {
|
||||
case '+':
|
||||
switch {
|
||||
case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
|
||||
// Avoid allocation for frequent "+OK" response.
|
||||
return okReply, nil
|
||||
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
|
||||
// Avoid allocation in PING command benchmarks :)
|
||||
return pongReply, nil
|
||||
default:
|
||||
return string(line[1:]), nil
|
||||
}
|
||||
case '-':
|
||||
return Error(string(line[1:])), nil
|
||||
case ':':
|
||||
return parseInt(line[1:])
|
||||
case '$':
|
||||
n, err := parseLen(line[1:])
|
||||
if n < 0 || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := make([]byte, n)
|
||||
_, err = io.ReadFull(c.br, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if line, err := c.readLine(); err != nil {
|
||||
return nil, err
|
||||
} else if len(line) != 0 {
|
||||
return nil, protocolError("bad bulk string format")
|
||||
}
|
||||
return p, nil
|
||||
case '*':
|
||||
n, err := parseLen(line[1:])
|
||||
if n < 0 || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := make([]interface{}, n)
|
||||
for i := range r {
|
||||
r[i], err = c.readReply()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
return nil, protocolError("unexpected response line")
|
||||
}
|
||||
|
||||
func (c *conn) Send(cmd string, args ...interface{}) error {
|
||||
c.mu.Lock()
|
||||
c.pending += 1
|
||||
c.mu.Unlock()
|
||||
if c.writeTimeout != 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
if err := c.writeCommand(cmd, args); err != nil {
|
||||
return c.fatal(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) Flush() error {
|
||||
if c.writeTimeout != 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
if err := c.bw.Flush(); err != nil {
|
||||
return c.fatal(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) Receive() (reply interface{}, err error) {
|
||||
if c.readTimeout != 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
if reply, err = c.readReply(); err != nil {
|
||||
return nil, c.fatal(err)
|
||||
}
|
||||
// When using pub/sub, the number of receives can be greater than the
|
||||
// number of sends. To enable normal use of the connection after
|
||||
// unsubscribing from all channels, we do not decrement pending to a
|
||||
// negative value.
|
||||
//
|
||||
// The pending field is decremented after the reply is read to handle the
|
||||
// case where Receive is called before Send.
|
||||
c.mu.Lock()
|
||||
if c.pending > 0 {
|
||||
c.pending -= 1
|
||||
}
|
||||
c.mu.Unlock()
|
||||
if err, ok := reply.(Error); ok {
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
|
||||
c.mu.Lock()
|
||||
pending := c.pending
|
||||
c.pending = 0
|
||||
c.mu.Unlock()
|
||||
|
||||
if cmd == "" && pending == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if c.writeTimeout != 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
if cmd != "" {
|
||||
c.writeCommand(cmd, args)
|
||||
}
|
||||
|
||||
if err := c.bw.Flush(); err != nil {
|
||||
return nil, c.fatal(err)
|
||||
}
|
||||
|
||||
if c.readTimeout != 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
if cmd == "" {
|
||||
reply := make([]interface{}, pending)
|
||||
for i := range reply {
|
||||
r, e := c.readReply()
|
||||
if e != nil {
|
||||
return nil, c.fatal(e)
|
||||
}
|
||||
reply[i] = r
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
var reply interface{}
|
||||
for i := 0; i <= pending; i++ {
|
||||
var e error
|
||||
if reply, e = c.readReply(); e != nil {
|
||||
return nil, c.fatal(e)
|
||||
}
|
||||
if e, ok := reply.(Error); ok && err == nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
return reply, err
|
||||
}
|
|
@ -0,0 +1,542 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"math"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/garyburd/redigo/internal/redistest"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
)
|
||||
|
||||
var writeTests = []struct {
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
[]interface{}{"SET", "key", "value"},
|
||||
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
|
||||
},
|
||||
{
|
||||
[]interface{}{"SET", "key", "value"},
|
||||
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
|
||||
},
|
||||
{
|
||||
[]interface{}{"SET", "key", byte(100)},
|
||||
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
|
||||
},
|
||||
{
|
||||
[]interface{}{"SET", "key", 100},
|
||||
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
|
||||
},
|
||||
{
|
||||
[]interface{}{"SET", "key", int64(math.MinInt64)},
|
||||
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$20\r\n-9223372036854775808\r\n",
|
||||
},
|
||||
{
|
||||
[]interface{}{"SET", "key", float64(1349673917.939762)},
|
||||
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$21\r\n1.349673917939762e+09\r\n",
|
||||
},
|
||||
{
|
||||
[]interface{}{"SET", "key", ""},
|
||||
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
|
||||
},
|
||||
{
|
||||
[]interface{}{"SET", "key", nil},
|
||||
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
|
||||
},
|
||||
{
|
||||
[]interface{}{"ECHO", true, false},
|
||||
"*3\r\n$4\r\nECHO\r\n$1\r\n1\r\n$1\r\n0\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
for _, tt := range writeTests {
|
||||
var buf bytes.Buffer
|
||||
rw := bufio.ReadWriter{Writer: bufio.NewWriter(&buf)}
|
||||
c := redis.NewConnBufio(rw)
|
||||
err := c.Send(tt.args[0].(string), tt.args[1:]...)
|
||||
if err != nil {
|
||||
t.Errorf("Send(%v) returned error %v", tt.args, err)
|
||||
continue
|
||||
}
|
||||
rw.Flush()
|
||||
actual := buf.String()
|
||||
if actual != tt.expected {
|
||||
t.Errorf("Send(%v) = %q, want %q", tt.args, actual, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var errorSentinel = &struct{}{}
|
||||
|
||||
var readTests = []struct {
|
||||
reply string
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
"+OK\r\n",
|
||||
"OK",
|
||||
},
|
||||
{
|
||||
"+PONG\r\n",
|
||||
"PONG",
|
||||
},
|
||||
{
|
||||
"@OK\r\n",
|
||||
errorSentinel,
|
||||
},
|
||||
{
|
||||
"$6\r\nfoobar\r\n",
|
||||
[]byte("foobar"),
|
||||
},
|
||||
{
|
||||
"$-1\r\n",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
":1\r\n",
|
||||
int64(1),
|
||||
},
|
||||
{
|
||||
":-2\r\n",
|
||||
int64(-2),
|
||||
},
|
||||
{
|
||||
"*0\r\n",
|
||||
[]interface{}{},
|
||||
},
|
||||
{
|
||||
"*-1\r\n",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"*4\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$5\r\nHello\r\n$5\r\nWorld\r\n",
|
||||
[]interface{}{[]byte("foo"), []byte("bar"), []byte("Hello"), []byte("World")},
|
||||
},
|
||||
{
|
||||
"*3\r\n$3\r\nfoo\r\n$-1\r\n$3\r\nbar\r\n",
|
||||
[]interface{}{[]byte("foo"), nil, []byte("bar")},
|
||||
},
|
||||
|
||||
{
|
||||
// "x" is not a valid length
|
||||
"$x\r\nfoobar\r\n",
|
||||
errorSentinel,
|
||||
},
|
||||
{
|
||||
// -2 is not a valid length
|
||||
"$-2\r\n",
|
||||
errorSentinel,
|
||||
},
|
||||
{
|
||||
// "x" is not a valid integer
|
||||
":x\r\n",
|
||||
errorSentinel,
|
||||
},
|
||||
{
|
||||
// missing \r\n following value
|
||||
"$6\r\nfoobar",
|
||||
errorSentinel,
|
||||
},
|
||||
{
|
||||
// short value
|
||||
"$6\r\nxx",
|
||||
errorSentinel,
|
||||
},
|
||||
{
|
||||
// long value
|
||||
"$6\r\nfoobarx\r\n",
|
||||
errorSentinel,
|
||||
},
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
for _, tt := range readTests {
|
||||
rw := bufio.ReadWriter{
|
||||
Reader: bufio.NewReader(strings.NewReader(tt.reply)),
|
||||
Writer: bufio.NewWriter(nil), // writer need to support Flush
|
||||
}
|
||||
c := redis.NewConnBufio(rw)
|
||||
actual, err := c.Receive()
|
||||
if tt.expected == errorSentinel {
|
||||
if err == nil {
|
||||
t.Errorf("Receive(%q) did not return expected error", tt.reply)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Receive(%q) returned error %v", tt.reply, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(actual, tt.expected) {
|
||||
t.Errorf("Receive(%q) = %v, want %v", tt.reply, actual, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var testCommands = []struct {
|
||||
args []interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
[]interface{}{"PING"},
|
||||
"PONG",
|
||||
},
|
||||
{
|
||||
[]interface{}{"SET", "foo", "bar"},
|
||||
"OK",
|
||||
},
|
||||
{
|
||||
[]interface{}{"GET", "foo"},
|
||||
[]byte("bar"),
|
||||
},
|
||||
{
|
||||
[]interface{}{"GET", "nokey"},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]interface{}{"MGET", "nokey", "foo"},
|
||||
[]interface{}{nil, []byte("bar")},
|
||||
},
|
||||
{
|
||||
[]interface{}{"INCR", "mycounter"},
|
||||
int64(1),
|
||||
},
|
||||
{
|
||||
[]interface{}{"LPUSH", "mylist", "foo"},
|
||||
int64(1),
|
||||
},
|
||||
{
|
||||
[]interface{}{"LPUSH", "mylist", "bar"},
|
||||
int64(2),
|
||||
},
|
||||
{
|
||||
[]interface{}{"LRANGE", "mylist", 0, -1},
|
||||
[]interface{}{[]byte("bar"), []byte("foo")},
|
||||
},
|
||||
{
|
||||
[]interface{}{"MULTI"},
|
||||
"OK",
|
||||
},
|
||||
{
|
||||
[]interface{}{"LRANGE", "mylist", 0, -1},
|
||||
"QUEUED",
|
||||
},
|
||||
{
|
||||
[]interface{}{"PING"},
|
||||
"QUEUED",
|
||||
},
|
||||
{
|
||||
[]interface{}{"EXEC"},
|
||||
[]interface{}{
|
||||
[]interface{}{[]byte("bar"), []byte("foo")},
|
||||
"PONG",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestDoCommands(t *testing.T) {
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
for _, cmd := range testCommands {
|
||||
actual, err := c.Do(cmd.args[0].(string), cmd.args[1:]...)
|
||||
if err != nil {
|
||||
t.Errorf("Do(%v) returned error %v", cmd.args, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(actual, cmd.expected) {
|
||||
t.Errorf("Do(%v) = %v, want %v", cmd.args, actual, cmd.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineCommands(t *testing.T) {
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
for _, cmd := range testCommands {
|
||||
if err := c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
|
||||
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
|
||||
}
|
||||
}
|
||||
if err := c.Flush(); err != nil {
|
||||
t.Errorf("Flush() returned error %v", err)
|
||||
}
|
||||
for _, cmd := range testCommands {
|
||||
actual, err := c.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Receive(%v) returned error %v", cmd.args, err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, cmd.expected) {
|
||||
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlankCommmand(t *testing.T) {
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
for _, cmd := range testCommands {
|
||||
if err := c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
|
||||
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
|
||||
}
|
||||
}
|
||||
reply, err := redis.Values(c.Do(""))
|
||||
if err != nil {
|
||||
t.Fatalf("Do() returned error %v", err)
|
||||
}
|
||||
if len(reply) != len(testCommands) {
|
||||
t.Fatalf("len(reply)=%d, want %d", len(reply), len(testCommands))
|
||||
}
|
||||
for i, cmd := range testCommands {
|
||||
actual := reply[i]
|
||||
if !reflect.DeepEqual(actual, cmd.expected) {
|
||||
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvBeforeSend(t *testing.T) {
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.Receive()
|
||||
close(done)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
c.Send("PING")
|
||||
c.Flush()
|
||||
<-done
|
||||
_, err = c.Do("")
|
||||
if err != nil {
|
||||
t.Fatalf("error=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c.Do("SET", "key", "val")
|
||||
_, err = c.Do("HSET", "key", "fld", "val")
|
||||
if err == nil {
|
||||
t.Errorf("Expected err for HSET on string key.")
|
||||
}
|
||||
if c.Err() != nil {
|
||||
t.Errorf("Conn has Err()=%v, expect nil", c.Err())
|
||||
}
|
||||
_, err = c.Do("SET", "key", "val")
|
||||
if err != nil {
|
||||
t.Errorf("Do(SET, key, val) returned error %v, expected nil.", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDeadline(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen returned %v", err)
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
c.Write([]byte("+OK\r\n"))
|
||||
c.Close()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
c1, err := redis.DialTimeout(l.Addr().Network(), l.Addr().String(), 0, time.Millisecond, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("redis.Dial returned %v", err)
|
||||
}
|
||||
defer c1.Close()
|
||||
|
||||
_, err = c1.Do("PING")
|
||||
if err == nil {
|
||||
t.Fatalf("c1.Do() returned nil, expect error")
|
||||
}
|
||||
if c1.Err() == nil {
|
||||
t.Fatalf("c1.Err() = nil, expect error")
|
||||
}
|
||||
|
||||
c2, err := redis.DialTimeout(l.Addr().Network(), l.Addr().String(), 0, time.Millisecond, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("redis.Dial returned %v", err)
|
||||
}
|
||||
defer c2.Close()
|
||||
|
||||
c2.Send("PING")
|
||||
c2.Flush()
|
||||
_, err = c2.Receive()
|
||||
if err == nil {
|
||||
t.Fatalf("c2.Receive() returned nil, expect error")
|
||||
}
|
||||
if c2.Err() == nil {
|
||||
t.Fatalf("c2.Err() = nil, expect error")
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to local instance of Redis running on the default port.
|
||||
func ExampleDial(x int) {
|
||||
c, err := redis.Dial("tcp", ":6379")
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
defer c.Close()
|
||||
}
|
||||
|
||||
// TextExecError tests handling of errors in a transaction. See
|
||||
// http://redis.io/topics/transactions for information on how Redis handles
|
||||
// errors in a transaction.
|
||||
func TestExecError(t *testing.T) {
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// Execute commands that fail before EXEC is called.
|
||||
|
||||
c.Do("ZADD", "k0", 0, 0)
|
||||
c.Send("MULTI")
|
||||
c.Send("NOTACOMMAND", "k0", 0, 0)
|
||||
c.Send("ZINCRBY", "k0", 0, 0)
|
||||
v, err := c.Do("EXEC")
|
||||
if err == nil {
|
||||
t.Fatalf("EXEC returned values %v, expected error", v)
|
||||
}
|
||||
|
||||
// Execute commands that fail after EXEC is called. The first command
|
||||
// returns an error.
|
||||
|
||||
c.Do("ZADD", "k1", 0, 0)
|
||||
c.Send("MULTI")
|
||||
c.Send("HSET", "k1", 0, 0)
|
||||
c.Send("ZINCRBY", "k1", 0, 0)
|
||||
v, err = c.Do("EXEC")
|
||||
if err != nil {
|
||||
t.Fatalf("EXEC returned error %v", err)
|
||||
}
|
||||
|
||||
vs, err := redis.Values(v, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Values(v) returned error %v", err)
|
||||
}
|
||||
|
||||
if len(vs) != 2 {
|
||||
t.Fatalf("len(vs) == %d, want 2", len(vs))
|
||||
}
|
||||
|
||||
if _, ok := vs[0].(error); !ok {
|
||||
t.Fatalf("first result is type %T, expected error", vs[0])
|
||||
}
|
||||
|
||||
if _, ok := vs[1].([]byte); !ok {
|
||||
t.Fatalf("second result is type %T, expected []byte", vs[2])
|
||||
}
|
||||
|
||||
// Execute commands that fail after EXEC is called. The second command
|
||||
// returns an error.
|
||||
|
||||
c.Do("ZADD", "k2", 0, 0)
|
||||
c.Send("MULTI")
|
||||
c.Send("ZINCRBY", "k2", 0, 0)
|
||||
c.Send("HSET", "k2", 0, 0)
|
||||
v, err = c.Do("EXEC")
|
||||
if err != nil {
|
||||
t.Fatalf("EXEC returned error %v", err)
|
||||
}
|
||||
|
||||
vs, err = redis.Values(v, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Values(v) returned error %v", err)
|
||||
}
|
||||
|
||||
if len(vs) != 2 {
|
||||
t.Fatalf("len(vs) == %d, want 2", len(vs))
|
||||
}
|
||||
|
||||
if _, ok := vs[0].([]byte); !ok {
|
||||
t.Fatalf("first result is type %T, expected []byte", vs[0])
|
||||
}
|
||||
|
||||
if _, ok := vs[1].(error); !ok {
|
||||
t.Fatalf("second result is type %T, expected error", vs[2])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDoEmpty(b *testing.B) {
|
||||
b.StopTimer()
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := c.Do(""); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDoPing(b *testing.B) {
|
||||
b.StopTimer()
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := c.Do("PING"); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Package redis is a client for the Redis database.
|
||||
//
|
||||
// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more
|
||||
// documentation about this package.
|
||||
//
|
||||
// Connections
|
||||
//
|
||||
// The Conn interface is the primary interface for working with Redis.
|
||||
// Applications create connections by calling the Dial, DialWithTimeout or
|
||||
// NewConn functions. In the future, functions will be added for creating
|
||||
// sharded and other types of connections.
|
||||
//
|
||||
// The application must call the connection Close method when the application
|
||||
// is done with the connection.
|
||||
//
|
||||
// Executing Commands
|
||||
//
|
||||
// The Conn interface has a generic method for executing Redis commands:
|
||||
//
|
||||
// Do(commandName string, args ...interface{}) (reply interface{}, err error)
|
||||
//
|
||||
// The Redis command reference (http://redis.io/commands) lists the available
|
||||
// commands. An example of using the Redis APPEND command is:
|
||||
//
|
||||
// n, err := conn.Do("APPEND", "key", "value")
|
||||
//
|
||||
// The Do method converts command arguments to binary strings for transmission
|
||||
// to the server as follows:
|
||||
//
|
||||
// Go Type Conversion
|
||||
// []byte Sent as is
|
||||
// string Sent as is
|
||||
// int, int64 strconv.FormatInt(v)
|
||||
// float64 strconv.FormatFloat(v, 'g', -1, 64)
|
||||
// bool true -> "1", false -> "0"
|
||||
// nil ""
|
||||
// all other types fmt.Print(v)
|
||||
//
|
||||
// Redis command reply types are represented using the following Go types:
|
||||
//
|
||||
// Redis type Go type
|
||||
// error redis.Error
|
||||
// integer int64
|
||||
// simple string string
|
||||
// bulk string []byte or nil if value not present.
|
||||
// array []interface{} or nil if value not present.
|
||||
//
|
||||
// Use type assertions or the reply helper functions to convert from
|
||||
// interface{} to the specific Go type for the command result.
|
||||
//
|
||||
// Pipelining
|
||||
//
|
||||
// Connections support pipelining using the Send, Flush and Receive methods.
|
||||
//
|
||||
// Send(commandName string, args ...interface{}) error
|
||||
// Flush() error
|
||||
// Receive() (reply interface{}, err error)
|
||||
//
|
||||
// Send writes the command to the connection's output buffer. Flush flushes the
|
||||
// connection's output buffer to the server. Receive reads a single reply from
|
||||
// the server. The following example shows a simple pipeline.
|
||||
//
|
||||
// c.Send("SET", "foo", "bar")
|
||||
// c.Send("GET", "foo")
|
||||
// c.Flush()
|
||||
// c.Receive() // reply from SET
|
||||
// v, err = c.Receive() // reply from GET
|
||||
//
|
||||
// The Do method combines the functionality of the Send, Flush and Receive
|
||||
// methods. The Do method starts by writing the command and flushing the output
|
||||
// buffer. Next, the Do method receives all pending replies including the reply
|
||||
// for the command just sent by Do. If any of the received replies is an error,
|
||||
// then Do returns the error. If there are no errors, then Do returns the last
|
||||
// reply. If the command argument to the Do method is "", then the Do method
|
||||
// will flush the output buffer and receive pending replies without sending a
|
||||
// command.
|
||||
//
|
||||
// Use the Send and Do methods to implement pipelined transactions.
|
||||
//
|
||||
// c.Send("MULTI")
|
||||
// c.Send("INCR", "foo")
|
||||
// c.Send("INCR", "bar")
|
||||
// r, err := c.Do("EXEC")
|
||||
// fmt.Println(r) // prints [1, 1]
|
||||
//
|
||||
// Concurrency
|
||||
//
|
||||
// Connections do not support concurrent calls to the write methods (Send,
|
||||
// Flush) or concurrent calls to the read method (Receive). Connections do
|
||||
// allow a concurrent reader and writer.
|
||||
//
|
||||
// Because the Do method combines the functionality of Send, Flush and Receive,
|
||||
// the Do method cannot be called concurrently with the other methods.
|
||||
//
|
||||
// For full concurrent access to Redis, use the thread-safe Pool to get and
|
||||
// release connections from within a goroutine.
|
||||
//
|
||||
// Publish and Subscribe
|
||||
//
|
||||
// Use the Send, Flush and Receive methods to implement Pub/Sub subscribers.
|
||||
//
|
||||
// c.Send("SUBSCRIBE", "example")
|
||||
// c.Flush()
|
||||
// for {
|
||||
// reply, err := c.Receive()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// // process pushed message
|
||||
// }
|
||||
//
|
||||
// The PubSubConn type wraps a Conn with convenience methods for implementing
|
||||
// subscribers. The Subscribe, PSubscribe, Unsubscribe and PUnsubscribe methods
|
||||
// send and flush a subscription management command. The receive method
|
||||
// converts a pushed message to convenient types for use in a type switch.
|
||||
//
|
||||
// psc := redis.PubSubConn{c}
|
||||
// psc.Subscribe("example")
|
||||
// for {
|
||||
// switch v := psc.Receive().(type) {
|
||||
// case redis.Message:
|
||||
// fmt.Printf("%s: message: %s\n", v.Channel, v.Data)
|
||||
// case redis.Subscription:
|
||||
// fmt.Printf("%s: %s %d\n", v.Channel, v.Kind, v.Count)
|
||||
// case error:
|
||||
// return v
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Reply Helpers
|
||||
//
|
||||
// The Bool, Int, Bytes, String, Strings and Values functions convert a reply
|
||||
// to a value of a specific type. To allow convenient wrapping of calls to the
|
||||
// connection Do and Receive methods, the functions take a second argument of
|
||||
// type error. If the error is non-nil, then the helper function returns the
|
||||
// error. If the error is nil, the function converts the reply to the specified
|
||||
// type:
|
||||
//
|
||||
// exists, err := redis.Bool(c.Do("EXISTS", "foo"))
|
||||
// if err != nil {
|
||||
// // handle error return from c.Do or type conversion error.
|
||||
// }
|
||||
//
|
||||
// The Scan function converts elements of a array reply to Go types:
|
||||
//
|
||||
// var value1 int
|
||||
// var value2 string
|
||||
// reply, err := redis.Values(c.Do("MGET", "key1", "key2"))
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
// if _, err := redis.Scan(reply, &value1, &value2); err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
package redis // import "github.com/garyburd/redigo/redis"
|
|
@ -0,0 +1,117 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// NewLoggingConn returns a logging wrapper around a connection.
|
||||
func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn {
|
||||
if prefix != "" {
|
||||
prefix = prefix + "."
|
||||
}
|
||||
return &loggingConn{conn, logger, prefix}
|
||||
}
|
||||
|
||||
type loggingConn struct {
|
||||
Conn
|
||||
logger *log.Logger
|
||||
prefix string
|
||||
}
|
||||
|
||||
func (c *loggingConn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err)
|
||||
c.logger.Output(2, buf.String())
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) {
|
||||
const chop = 32
|
||||
switch v := v.(type) {
|
||||
case []byte:
|
||||
if len(v) > chop {
|
||||
fmt.Fprintf(buf, "%q...", v[:chop])
|
||||
} else {
|
||||
fmt.Fprintf(buf, "%q", v)
|
||||
}
|
||||
case string:
|
||||
if len(v) > chop {
|
||||
fmt.Fprintf(buf, "%q...", v[:chop])
|
||||
} else {
|
||||
fmt.Fprintf(buf, "%q", v)
|
||||
}
|
||||
case []interface{}:
|
||||
if len(v) == 0 {
|
||||
buf.WriteString("[]")
|
||||
} else {
|
||||
sep := "["
|
||||
fin := "]"
|
||||
if len(v) > chop {
|
||||
v = v[:chop]
|
||||
fin = "...]"
|
||||
}
|
||||
for _, vv := range v {
|
||||
buf.WriteString(sep)
|
||||
c.printValue(buf, vv)
|
||||
sep = ", "
|
||||
}
|
||||
buf.WriteString(fin)
|
||||
}
|
||||
default:
|
||||
fmt.Fprint(buf, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) {
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "%s%s(", c.prefix, method)
|
||||
if method != "Receive" {
|
||||
buf.WriteString(commandName)
|
||||
for _, arg := range args {
|
||||
buf.WriteString(", ")
|
||||
c.printValue(&buf, arg)
|
||||
}
|
||||
}
|
||||
buf.WriteString(") -> (")
|
||||
if method != "Send" {
|
||||
c.printValue(&buf, reply)
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
fmt.Fprintf(&buf, "%v)", err)
|
||||
c.logger.Output(3, buf.String())
|
||||
}
|
||||
|
||||
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) {
|
||||
reply, err := c.Conn.Do(commandName, args...)
|
||||
c.print("Do", commandName, args, reply, err)
|
||||
return reply, err
|
||||
}
|
||||
|
||||
func (c *loggingConn) Send(commandName string, args ...interface{}) error {
|
||||
err := c.Conn.Send(commandName, args...)
|
||||
c.print("Send", commandName, args, nil, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *loggingConn) Receive() (interface{}, error) {
|
||||
reply, err := c.Conn.Receive()
|
||||
c.print("Receive", "", nil, reply, err)
|
||||
return reply, err
|
||||
}
|
|
@ -0,0 +1,389 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"container/list"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/garyburd/redigo/internal"
|
||||
)
|
||||
|
||||
var nowFunc = time.Now // for testing
|
||||
|
||||
// ErrPoolExhausted is returned from a pool connection method (Do, Send,
|
||||
// Receive, Flush, Err) when the maximum number of database connections in the
|
||||
// pool has been reached.
|
||||
var ErrPoolExhausted = errors.New("redigo: connection pool exhausted")
|
||||
|
||||
var (
|
||||
errPoolClosed = errors.New("redigo: connection pool closed")
|
||||
errConnClosed = errors.New("redigo: connection closed")
|
||||
)
|
||||
|
||||
// Pool maintains a pool of connections. The application calls the Get method
|
||||
// to get a connection from the pool and the connection's Close method to
|
||||
// return the connection's resources to the pool.
|
||||
//
|
||||
// The following example shows how to use a pool in a web application. The
|
||||
// application creates a pool at application startup and makes it available to
|
||||
// request handlers using a global variable.
|
||||
//
|
||||
// func newPool(server, password string) *redis.Pool {
|
||||
// return &redis.Pool{
|
||||
// MaxIdle: 3,
|
||||
// IdleTimeout: 240 * time.Second,
|
||||
// Dial: func () (redis.Conn, error) {
|
||||
// c, err := redis.Dial("tcp", server)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// if _, err := c.Do("AUTH", password); err != nil {
|
||||
// c.Close()
|
||||
// return nil, err
|
||||
// }
|
||||
// return c, err
|
||||
// },
|
||||
// TestOnBorrow: func(c redis.Conn, t time.Time) error {
|
||||
// _, err := c.Do("PING")
|
||||
// return err
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// var (
|
||||
// pool *redis.Pool
|
||||
// redisServer = flag.String("redisServer", ":6379", "")
|
||||
// redisPassword = flag.String("redisPassword", "", "")
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// flag.Parse()
|
||||
// pool = newPool(*redisServer, *redisPassword)
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
// A request handler gets a connection from the pool and closes the connection
|
||||
// when the handler is done:
|
||||
//
|
||||
// func serveHome(w http.ResponseWriter, r *http.Request) {
|
||||
// conn := pool.Get()
|
||||
// defer conn.Close()
|
||||
// ....
|
||||
// }
|
||||
//
|
||||
type Pool struct {
|
||||
|
||||
// Dial is an application supplied function for creating and configuring a
|
||||
// connection
|
||||
Dial func() (Conn, error)
|
||||
|
||||
// TestOnBorrow is an optional application supplied function for checking
|
||||
// the health of an idle connection before the connection is used again by
|
||||
// the application. Argument t is the time that the connection was returned
|
||||
// to the pool. If the function returns an error, then the connection is
|
||||
// closed.
|
||||
TestOnBorrow func(c Conn, t time.Time) error
|
||||
|
||||
// Maximum number of idle connections in the pool.
|
||||
MaxIdle int
|
||||
|
||||
// Maximum number of connections allocated by the pool at a given time.
|
||||
// When zero, there is no limit on the number of connections in the pool.
|
||||
MaxActive int
|
||||
|
||||
// Close connections after remaining idle for this duration. If the value
|
||||
// is zero, then idle connections are not closed. Applications should set
|
||||
// the timeout to a value less than the server's timeout.
|
||||
IdleTimeout time.Duration
|
||||
|
||||
// If Wait is true and the pool is at the MaxIdle limit, then Get() waits
|
||||
// for a connection to be returned to the pool before returning.
|
||||
Wait bool
|
||||
|
||||
// mu protects fields defined below.
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
closed bool
|
||||
active int
|
||||
|
||||
// Stack of idleConn with most recently used at the front.
|
||||
idle list.List
|
||||
}
|
||||
|
||||
type idleConn struct {
|
||||
c Conn
|
||||
t time.Time
|
||||
}
|
||||
|
||||
// NewPool creates a new pool. This function is deprecated. Applications should
|
||||
// initialize the Pool fields directly as shown in example.
|
||||
func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
|
||||
return &Pool{Dial: newFn, MaxIdle: maxIdle}
|
||||
}
|
||||
|
||||
// Get gets a connection. The application must close the returned connection.
|
||||
// This method always returns a valid connection so that applications can defer
|
||||
// error handling to the first use of the connection. If there is an error
|
||||
// getting an underlying connection, then the connection Err, Do, Send, Flush
|
||||
// and Receive methods return that error.
|
||||
func (p *Pool) Get() Conn {
|
||||
c, err := p.get()
|
||||
if err != nil {
|
||||
return errorConnection{err}
|
||||
}
|
||||
return &pooledConnection{p: p, c: c}
|
||||
}
|
||||
|
||||
// ActiveCount returns the number of active connections in the pool.
|
||||
func (p *Pool) ActiveCount() int {
|
||||
p.mu.Lock()
|
||||
active := p.active
|
||||
p.mu.Unlock()
|
||||
return active
|
||||
}
|
||||
|
||||
// Close releases the resources used by the pool.
|
||||
func (p *Pool) Close() error {
|
||||
p.mu.Lock()
|
||||
idle := p.idle
|
||||
p.idle.Init()
|
||||
p.closed = true
|
||||
p.active -= idle.Len()
|
||||
if p.cond != nil {
|
||||
p.cond.Broadcast()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
for e := idle.Front(); e != nil; e = e.Next() {
|
||||
e.Value.(idleConn).c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// release decrements the active count and signals waiters. The caller must
|
||||
// hold p.mu during the call.
|
||||
func (p *Pool) release() {
|
||||
p.active -= 1
|
||||
if p.cond != nil {
|
||||
p.cond.Signal()
|
||||
}
|
||||
}
|
||||
|
||||
// get prunes stale connections and returns a connection from the idle list or
|
||||
// creates a new connection.
|
||||
func (p *Pool) get() (Conn, error) {
|
||||
p.mu.Lock()
|
||||
|
||||
// Prune stale connections.
|
||||
|
||||
if timeout := p.IdleTimeout; timeout > 0 {
|
||||
for i, n := 0, p.idle.Len(); i < n; i++ {
|
||||
e := p.idle.Back()
|
||||
if e == nil {
|
||||
break
|
||||
}
|
||||
ic := e.Value.(idleConn)
|
||||
if ic.t.Add(timeout).After(nowFunc()) {
|
||||
break
|
||||
}
|
||||
p.idle.Remove(e)
|
||||
p.release()
|
||||
p.mu.Unlock()
|
||||
ic.c.Close()
|
||||
p.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
|
||||
// Get idle connection.
|
||||
|
||||
for i, n := 0, p.idle.Len(); i < n; i++ {
|
||||
e := p.idle.Front()
|
||||
if e == nil {
|
||||
break
|
||||
}
|
||||
ic := e.Value.(idleConn)
|
||||
p.idle.Remove(e)
|
||||
test := p.TestOnBorrow
|
||||
p.mu.Unlock()
|
||||
if test == nil || test(ic.c, ic.t) == nil {
|
||||
return ic.c, nil
|
||||
}
|
||||
ic.c.Close()
|
||||
p.mu.Lock()
|
||||
p.release()
|
||||
}
|
||||
|
||||
// Check for pool closed before dialing a new connection.
|
||||
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return nil, errors.New("redigo: get on closed pool")
|
||||
}
|
||||
|
||||
// Dial new connection if under limit.
|
||||
|
||||
if p.MaxActive == 0 || p.active < p.MaxActive {
|
||||
dial := p.Dial
|
||||
p.active += 1
|
||||
p.mu.Unlock()
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
p.mu.Lock()
|
||||
p.release()
|
||||
p.mu.Unlock()
|
||||
c = nil
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
if !p.Wait {
|
||||
p.mu.Unlock()
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
|
||||
if p.cond == nil {
|
||||
p.cond = sync.NewCond(&p.mu)
|
||||
}
|
||||
p.cond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) put(c Conn, forceClose bool) error {
|
||||
err := c.Err()
|
||||
p.mu.Lock()
|
||||
if !p.closed && err == nil && !forceClose {
|
||||
p.idle.PushFront(idleConn{t: nowFunc(), c: c})
|
||||
if p.idle.Len() > p.MaxIdle {
|
||||
c = p.idle.Remove(p.idle.Back()).(idleConn).c
|
||||
} else {
|
||||
c = nil
|
||||
}
|
||||
}
|
||||
|
||||
if c == nil {
|
||||
if p.cond != nil {
|
||||
p.cond.Signal()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
p.release()
|
||||
p.mu.Unlock()
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
type pooledConnection struct {
|
||||
p *Pool
|
||||
c Conn
|
||||
state int
|
||||
}
|
||||
|
||||
var (
|
||||
sentinel []byte
|
||||
sentinelOnce sync.Once
|
||||
)
|
||||
|
||||
func initSentinel() {
|
||||
p := make([]byte, 64)
|
||||
if _, err := rand.Read(p); err == nil {
|
||||
sentinel = p
|
||||
} else {
|
||||
h := sha1.New()
|
||||
io.WriteString(h, "Oops, rand failed. Use time instead.")
|
||||
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10))
|
||||
sentinel = h.Sum(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (pc *pooledConnection) Close() error {
|
||||
c := pc.c
|
||||
if _, ok := c.(errorConnection); ok {
|
||||
return nil
|
||||
}
|
||||
pc.c = errorConnection{errConnClosed}
|
||||
|
||||
if pc.state&internal.MultiState != 0 {
|
||||
c.Send("DISCARD")
|
||||
pc.state &^= (internal.MultiState | internal.WatchState)
|
||||
} else if pc.state&internal.WatchState != 0 {
|
||||
c.Send("UNWATCH")
|
||||
pc.state &^= internal.WatchState
|
||||
}
|
||||
if pc.state&internal.SubscribeState != 0 {
|
||||
c.Send("UNSUBSCRIBE")
|
||||
c.Send("PUNSUBSCRIBE")
|
||||
// To detect the end of the message stream, ask the server to echo
|
||||
// a sentinel value and read until we see that value.
|
||||
sentinelOnce.Do(initSentinel)
|
||||
c.Send("ECHO", sentinel)
|
||||
c.Flush()
|
||||
for {
|
||||
p, err := c.Receive()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
|
||||
pc.state &^= internal.SubscribeState
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Do("")
|
||||
pc.p.put(c, pc.state != 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pc *pooledConnection) Err() error {
|
||||
return pc.c.Err()
|
||||
}
|
||||
|
||||
func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
|
||||
ci := internal.LookupCommandInfo(commandName)
|
||||
pc.state = (pc.state | ci.Set) &^ ci.Clear
|
||||
return pc.c.Do(commandName, args...)
|
||||
}
|
||||
|
||||
func (pc *pooledConnection) Send(commandName string, args ...interface{}) error {
|
||||
ci := internal.LookupCommandInfo(commandName)
|
||||
pc.state = (pc.state | ci.Set) &^ ci.Clear
|
||||
return pc.c.Send(commandName, args...)
|
||||
}
|
||||
|
||||
func (pc *pooledConnection) Flush() error {
|
||||
return pc.c.Flush()
|
||||
}
|
||||
|
||||
func (pc *pooledConnection) Receive() (reply interface{}, err error) {
|
||||
return pc.c.Receive()
|
||||
}
|
||||
|
||||
type errorConnection struct{ err error }
|
||||
|
||||
func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
|
||||
func (ec errorConnection) Send(string, ...interface{}) error { return ec.err }
|
||||
func (ec errorConnection) Err() error { return ec.err }
|
||||
func (ec errorConnection) Close() error { return ec.err }
|
||||
func (ec errorConnection) Flush() error { return ec.err }
|
||||
func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err }
|
|
@ -0,0 +1,674 @@
|
|||
// Copyright 2011 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/garyburd/redigo/internal/redistest"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
)
|
||||
|
||||
type poolTestConn struct {
|
||||
d *poolDialer
|
||||
err error
|
||||
redis.Conn
|
||||
}
|
||||
|
||||
func (c *poolTestConn) Close() error { c.d.open -= 1; return nil }
|
||||
func (c *poolTestConn) Err() error { return c.err }
|
||||
|
||||
func (c *poolTestConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
|
||||
if commandName == "ERR" {
|
||||
c.err = args[0].(error)
|
||||
commandName = "PING"
|
||||
}
|
||||
if commandName != "" {
|
||||
c.d.commands = append(c.d.commands, commandName)
|
||||
}
|
||||
return c.Conn.Do(commandName, args...)
|
||||
}
|
||||
|
||||
func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
|
||||
c.d.commands = append(c.d.commands, commandName)
|
||||
return c.Conn.Send(commandName, args...)
|
||||
}
|
||||
|
||||
type poolDialer struct {
|
||||
t *testing.T
|
||||
dialed int
|
||||
open int
|
||||
commands []string
|
||||
dialErr error
|
||||
}
|
||||
|
||||
func (d *poolDialer) dial() (redis.Conn, error) {
|
||||
d.dialed += 1
|
||||
if d.dialErr != nil {
|
||||
return nil, d.dialErr
|
||||
}
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.open += 1
|
||||
return &poolTestConn{d: d, Conn: c}, nil
|
||||
}
|
||||
|
||||
func (d *poolDialer) check(message string, p *redis.Pool, dialed, open int) {
|
||||
if d.dialed != dialed {
|
||||
d.t.Errorf("%s: dialed=%d, want %d", message, d.dialed, dialed)
|
||||
}
|
||||
if d.open != open {
|
||||
d.t.Errorf("%s: open=%d, want %d", message, d.open, open)
|
||||
}
|
||||
if active := p.ActiveCount(); active != open {
|
||||
d.t.Errorf("%s: active=%d, want %d", message, active, open)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolReuse(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
Dial: d.dial,
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
c1 := p.Get()
|
||||
c1.Do("PING")
|
||||
c2 := p.Get()
|
||||
c2.Do("PING")
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
}
|
||||
|
||||
d.check("before close", p, 2, 2)
|
||||
p.Close()
|
||||
d.check("after close", p, 2, 0)
|
||||
}
|
||||
|
||||
func TestPoolMaxIdle(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
Dial: d.dial,
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
c1 := p.Get()
|
||||
c1.Do("PING")
|
||||
c2 := p.Get()
|
||||
c2.Do("PING")
|
||||
c3 := p.Get()
|
||||
c3.Do("PING")
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
c3.Close()
|
||||
}
|
||||
d.check("before close", p, 12, 2)
|
||||
p.Close()
|
||||
d.check("after close", p, 12, 0)
|
||||
}
|
||||
|
||||
func TestPoolError(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
Dial: d.dial,
|
||||
}
|
||||
|
||||
c := p.Get()
|
||||
c.Do("ERR", io.EOF)
|
||||
if c.Err() == nil {
|
||||
t.Errorf("expected c.Err() != nil")
|
||||
}
|
||||
c.Close()
|
||||
|
||||
c = p.Get()
|
||||
c.Do("ERR", io.EOF)
|
||||
c.Close()
|
||||
|
||||
d.check(".", p, 2, 0)
|
||||
}
|
||||
|
||||
func TestPoolClose(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
Dial: d.dial,
|
||||
}
|
||||
|
||||
c1 := p.Get()
|
||||
c1.Do("PING")
|
||||
c2 := p.Get()
|
||||
c2.Do("PING")
|
||||
c3 := p.Get()
|
||||
c3.Do("PING")
|
||||
|
||||
c1.Close()
|
||||
if _, err := c1.Do("PING"); err == nil {
|
||||
t.Errorf("expected error after connection closed")
|
||||
}
|
||||
|
||||
c2.Close()
|
||||
c2.Close()
|
||||
|
||||
p.Close()
|
||||
|
||||
d.check("after pool close", p, 3, 1)
|
||||
|
||||
if _, err := c1.Do("PING"); err == nil {
|
||||
t.Errorf("expected error after connection and pool closed")
|
||||
}
|
||||
|
||||
c3.Close()
|
||||
|
||||
d.check("after conn close", p, 3, 0)
|
||||
|
||||
c1 = p.Get()
|
||||
if _, err := c1.Do("PING"); err == nil {
|
||||
t.Errorf("expected error after pool closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolTimeout(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
IdleTimeout: 300 * time.Second,
|
||||
Dial: d.dial,
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
redis.SetNowFunc(func() time.Time { return now })
|
||||
defer redis.SetNowFunc(time.Now)
|
||||
|
||||
c := p.Get()
|
||||
c.Do("PING")
|
||||
c.Close()
|
||||
|
||||
d.check("1", p, 1, 1)
|
||||
|
||||
now = now.Add(p.IdleTimeout)
|
||||
|
||||
c = p.Get()
|
||||
c.Do("PING")
|
||||
c.Close()
|
||||
|
||||
d.check("2", p, 2, 1)
|
||||
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func TestPoolConcurrenSendReceive(t *testing.T) {
|
||||
p := &redis.Pool{
|
||||
Dial: redistest.Dial,
|
||||
}
|
||||
c := p.Get()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := c.Receive()
|
||||
done <- err
|
||||
}()
|
||||
c.Send("PING")
|
||||
c.Flush()
|
||||
err := <-done
|
||||
if err != nil {
|
||||
t.Fatalf("Receive() returned error %v", err)
|
||||
}
|
||||
_, err = c.Do("")
|
||||
if err != nil {
|
||||
t.Fatalf("Do() returned error %v", err)
|
||||
}
|
||||
c.Close()
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func TestPoolBorrowCheck(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
Dial: d.dial,
|
||||
TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") },
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
c := p.Get()
|
||||
c.Do("PING")
|
||||
c.Close()
|
||||
}
|
||||
d.check("1", p, 10, 1)
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func TestPoolMaxActive(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
MaxActive: 2,
|
||||
Dial: d.dial,
|
||||
}
|
||||
c1 := p.Get()
|
||||
c1.Do("PING")
|
||||
c2 := p.Get()
|
||||
c2.Do("PING")
|
||||
|
||||
d.check("1", p, 2, 2)
|
||||
|
||||
c3 := p.Get()
|
||||
if _, err := c3.Do("PING"); err != redis.ErrPoolExhausted {
|
||||
t.Errorf("expected pool exhausted")
|
||||
}
|
||||
|
||||
c3.Close()
|
||||
d.check("2", p, 2, 2)
|
||||
c2.Close()
|
||||
d.check("3", p, 2, 2)
|
||||
|
||||
c3 = p.Get()
|
||||
if _, err := c3.Do("PING"); err != nil {
|
||||
t.Errorf("expected good channel, err=%v", err)
|
||||
}
|
||||
c3.Close()
|
||||
|
||||
d.check("4", p, 2, 2)
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func TestPoolMonitorCleanup(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
MaxActive: 2,
|
||||
Dial: d.dial,
|
||||
}
|
||||
c := p.Get()
|
||||
c.Send("MONITOR")
|
||||
c.Close()
|
||||
|
||||
d.check("", p, 1, 0)
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func TestPoolPubSubCleanup(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
MaxActive: 2,
|
||||
Dial: d.dial,
|
||||
}
|
||||
|
||||
c := p.Get()
|
||||
c.Send("SUBSCRIBE", "x")
|
||||
c.Close()
|
||||
|
||||
want := []string{"SUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
|
||||
if !reflect.DeepEqual(d.commands, want) {
|
||||
t.Errorf("got commands %v, want %v", d.commands, want)
|
||||
}
|
||||
d.commands = nil
|
||||
|
||||
c = p.Get()
|
||||
c.Send("PSUBSCRIBE", "x*")
|
||||
c.Close()
|
||||
|
||||
want = []string{"PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
|
||||
if !reflect.DeepEqual(d.commands, want) {
|
||||
t.Errorf("got commands %v, want %v", d.commands, want)
|
||||
}
|
||||
d.commands = nil
|
||||
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func TestPoolTransactionCleanup(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 2,
|
||||
MaxActive: 2,
|
||||
Dial: d.dial,
|
||||
}
|
||||
|
||||
c := p.Get()
|
||||
c.Do("WATCH", "key")
|
||||
c.Do("PING")
|
||||
c.Close()
|
||||
|
||||
want := []string{"WATCH", "PING", "UNWATCH"}
|
||||
if !reflect.DeepEqual(d.commands, want) {
|
||||
t.Errorf("got commands %v, want %v", d.commands, want)
|
||||
}
|
||||
d.commands = nil
|
||||
|
||||
c = p.Get()
|
||||
c.Do("WATCH", "key")
|
||||
c.Do("UNWATCH")
|
||||
c.Do("PING")
|
||||
c.Close()
|
||||
|
||||
want = []string{"WATCH", "UNWATCH", "PING"}
|
||||
if !reflect.DeepEqual(d.commands, want) {
|
||||
t.Errorf("got commands %v, want %v", d.commands, want)
|
||||
}
|
||||
d.commands = nil
|
||||
|
||||
c = p.Get()
|
||||
c.Do("WATCH", "key")
|
||||
c.Do("MULTI")
|
||||
c.Do("PING")
|
||||
c.Close()
|
||||
|
||||
want = []string{"WATCH", "MULTI", "PING", "DISCARD"}
|
||||
if !reflect.DeepEqual(d.commands, want) {
|
||||
t.Errorf("got commands %v, want %v", d.commands, want)
|
||||
}
|
||||
d.commands = nil
|
||||
|
||||
c = p.Get()
|
||||
c.Do("WATCH", "key")
|
||||
c.Do("MULTI")
|
||||
c.Do("DISCARD")
|
||||
c.Do("PING")
|
||||
c.Close()
|
||||
|
||||
want = []string{"WATCH", "MULTI", "DISCARD", "PING"}
|
||||
if !reflect.DeepEqual(d.commands, want) {
|
||||
t.Errorf("got commands %v, want %v", d.commands, want)
|
||||
}
|
||||
d.commands = nil
|
||||
|
||||
c = p.Get()
|
||||
c.Do("WATCH", "key")
|
||||
c.Do("MULTI")
|
||||
c.Do("EXEC")
|
||||
c.Do("PING")
|
||||
c.Close()
|
||||
|
||||
want = []string{"WATCH", "MULTI", "EXEC", "PING"}
|
||||
if !reflect.DeepEqual(d.commands, want) {
|
||||
t.Errorf("got commands %v, want %v", d.commands, want)
|
||||
}
|
||||
d.commands = nil
|
||||
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func startGoroutines(p *redis.Pool, cmd string, args ...interface{}) chan error {
|
||||
errs := make(chan error, 10)
|
||||
for i := 0; i < cap(errs); i++ {
|
||||
go func() {
|
||||
c := p.Get()
|
||||
_, err := c.Do(cmd, args...)
|
||||
errs <- err
|
||||
c.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for goroutines to block.
|
||||
time.Sleep(time.Second / 4)
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
func TestWaitPool(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 1,
|
||||
MaxActive: 1,
|
||||
Dial: d.dial,
|
||||
Wait: true,
|
||||
}
|
||||
defer p.Close()
|
||||
c := p.Get()
|
||||
errs := startGoroutines(p, "PING")
|
||||
d.check("before close", p, 1, 1)
|
||||
c.Close()
|
||||
timeout := time.After(2 * time.Second)
|
||||
for i := 0; i < cap(errs); i++ {
|
||||
select {
|
||||
case err := <-errs:
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout waiting for blocked goroutine %d", i)
|
||||
}
|
||||
}
|
||||
d.check("done", p, 1, 1)
|
||||
}
|
||||
|
||||
func TestWaitPoolClose(t *testing.T) {
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 1,
|
||||
MaxActive: 1,
|
||||
Dial: d.dial,
|
||||
Wait: true,
|
||||
}
|
||||
c := p.Get()
|
||||
if _, err := c.Do("PING"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
errs := startGoroutines(p, "PING")
|
||||
d.check("before close", p, 1, 1)
|
||||
p.Close()
|
||||
timeout := time.After(2 * time.Second)
|
||||
for i := 0; i < cap(errs); i++ {
|
||||
select {
|
||||
case err := <-errs:
|
||||
switch err {
|
||||
case nil:
|
||||
t.Fatal("blocked goroutine did not get error")
|
||||
case redis.ErrPoolExhausted:
|
||||
t.Fatal("blocked goroutine got pool exhausted error")
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatal("timeout waiting for blocked goroutine")
|
||||
}
|
||||
}
|
||||
c.Close()
|
||||
d.check("done", p, 1, 0)
|
||||
}
|
||||
|
||||
func TestWaitPoolCommandError(t *testing.T) {
|
||||
testErr := errors.New("test")
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 1,
|
||||
MaxActive: 1,
|
||||
Dial: d.dial,
|
||||
Wait: true,
|
||||
}
|
||||
defer p.Close()
|
||||
c := p.Get()
|
||||
errs := startGoroutines(p, "ERR", testErr)
|
||||
d.check("before close", p, 1, 1)
|
||||
c.Close()
|
||||
timeout := time.After(2 * time.Second)
|
||||
for i := 0; i < cap(errs); i++ {
|
||||
select {
|
||||
case err := <-errs:
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout waiting for blocked goroutine %d", i)
|
||||
}
|
||||
}
|
||||
d.check("done", p, cap(errs), 0)
|
||||
}
|
||||
|
||||
func TestWaitPoolDialError(t *testing.T) {
|
||||
testErr := errors.New("test")
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: 1,
|
||||
MaxActive: 1,
|
||||
Dial: d.dial,
|
||||
Wait: true,
|
||||
}
|
||||
defer p.Close()
|
||||
c := p.Get()
|
||||
errs := startGoroutines(p, "ERR", testErr)
|
||||
d.check("before close", p, 1, 1)
|
||||
|
||||
d.dialErr = errors.New("dial")
|
||||
c.Close()
|
||||
|
||||
nilCount := 0
|
||||
errCount := 0
|
||||
timeout := time.After(2 * time.Second)
|
||||
for i := 0; i < cap(errs); i++ {
|
||||
select {
|
||||
case err := <-errs:
|
||||
switch err {
|
||||
case nil:
|
||||
nilCount++
|
||||
case d.dialErr:
|
||||
errCount++
|
||||
default:
|
||||
t.Fatalf("expected dial error or nil, got %v", err)
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout waiting for blocked goroutine %d", i)
|
||||
}
|
||||
}
|
||||
if nilCount != 1 {
|
||||
t.Errorf("expected one nil error, got %d", nilCount)
|
||||
}
|
||||
if errCount != cap(errs)-1 {
|
||||
t.Errorf("expected %d dial erors, got %d", cap(errs)-1, errCount)
|
||||
}
|
||||
d.check("done", p, cap(errs), 0)
|
||||
}
|
||||
|
||||
// Borrowing requires us to iterate over the idle connections, unlock the pool,
|
||||
// and perform a blocking operation to check the connection still works. If
|
||||
// TestOnBorrow fails, we must reacquire the lock and continue iteration. This
|
||||
// test ensures that iteration will work correctly if multiple threads are
|
||||
// iterating simultaneously.
|
||||
func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) {
|
||||
count := 100
|
||||
|
||||
// First we'll Create a pool where the pilfering of idle connections fails.
|
||||
d := poolDialer{t: t}
|
||||
p := &redis.Pool{
|
||||
MaxIdle: count,
|
||||
MaxActive: count,
|
||||
Dial: d.dial,
|
||||
TestOnBorrow: func(c redis.Conn, t time.Time) error {
|
||||
return errors.New("No way back into the real world.")
|
||||
},
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
// Fill the pool with idle connections.
|
||||
b1 := sync.WaitGroup{}
|
||||
b1.Add(count)
|
||||
b2 := sync.WaitGroup{}
|
||||
b2.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
c := p.Get()
|
||||
if c.Err() != nil {
|
||||
t.Errorf("pool get failed: %v", c.Err())
|
||||
}
|
||||
b1.Done()
|
||||
b1.Wait()
|
||||
c.Close()
|
||||
b2.Done()
|
||||
}()
|
||||
}
|
||||
b2.Wait()
|
||||
if d.dialed != count {
|
||||
t.Errorf("Expected %d dials, got %d", count, d.dialed)
|
||||
}
|
||||
|
||||
// Spawn a bunch of goroutines to thrash the pool.
|
||||
b2.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
c := p.Get()
|
||||
if c.Err() != nil {
|
||||
t.Errorf("pool get failed: %v", c.Err())
|
||||
}
|
||||
c.Close()
|
||||
b2.Done()
|
||||
}()
|
||||
}
|
||||
b2.Wait()
|
||||
if d.dialed != count*2 {
|
||||
t.Errorf("Expected %d dials, got %d", count*2, d.dialed)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPoolGet(b *testing.B) {
|
||||
b.StopTimer()
|
||||
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
|
||||
c := p.Get()
|
||||
if err := c.Err(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
defer p.Close()
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
c = p.Get()
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPoolGetErr(b *testing.B) {
|
||||
b.StopTimer()
|
||||
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
|
||||
c := p.Get()
|
||||
if err := c.Err(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
defer p.Close()
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
c = p.Get()
|
||||
if err := c.Err(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPoolGetPing(b *testing.B) {
|
||||
b.StopTimer()
|
||||
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
|
||||
c := p.Get()
|
||||
if err := c.Err(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
defer p.Close()
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
c = p.Get()
|
||||
if _, err := c.Do("PING"); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import "errors"
|
||||
|
||||
// Subscription represents a subscribe or unsubscribe notification.
|
||||
type Subscription struct {
|
||||
|
||||
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
|
||||
Kind string
|
||||
|
||||
// The channel that was changed.
|
||||
Channel string
|
||||
|
||||
// The current number of subscriptions for connection.
|
||||
Count int
|
||||
}
|
||||
|
||||
// Message represents a message notification.
|
||||
type Message struct {
|
||||
|
||||
// The originating channel.
|
||||
Channel string
|
||||
|
||||
// The message data.
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// PMessage represents a pmessage notification.
|
||||
type PMessage struct {
|
||||
|
||||
// The matched pattern.
|
||||
Pattern string
|
||||
|
||||
// The originating channel.
|
||||
Channel string
|
||||
|
||||
// The message data.
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Pong represents a pubsub pong notification.
|
||||
type Pong struct {
|
||||
Data string
|
||||
}
|
||||
|
||||
// PubSubConn wraps a Conn with convenience methods for subscribers.
|
||||
type PubSubConn struct {
|
||||
Conn Conn
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c PubSubConn) Close() error {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// Subscribe subscribes the connection to the specified channels.
|
||||
func (c PubSubConn) Subscribe(channel ...interface{}) error {
|
||||
c.Conn.Send("SUBSCRIBE", channel...)
|
||||
return c.Conn.Flush()
|
||||
}
|
||||
|
||||
// PSubscribe subscribes the connection to the given patterns.
|
||||
func (c PubSubConn) PSubscribe(channel ...interface{}) error {
|
||||
c.Conn.Send("PSUBSCRIBE", channel...)
|
||||
return c.Conn.Flush()
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes the connection from the given channels, or from all
|
||||
// of them if none is given.
|
||||
func (c PubSubConn) Unsubscribe(channel ...interface{}) error {
|
||||
c.Conn.Send("UNSUBSCRIBE", channel...)
|
||||
return c.Conn.Flush()
|
||||
}
|
||||
|
||||
// PUnsubscribe unsubscribes the connection from the given patterns, or from all
|
||||
// of them if none is given.
|
||||
func (c PubSubConn) PUnsubscribe(channel ...interface{}) error {
|
||||
c.Conn.Send("PUNSUBSCRIBE", channel...)
|
||||
return c.Conn.Flush()
|
||||
}
|
||||
|
||||
// Ping sends a PING to the server with the specified data.
|
||||
func (c PubSubConn) Ping(data string) error {
|
||||
c.Conn.Send("PING", data)
|
||||
return c.Conn.Flush()
|
||||
}
|
||||
|
||||
// Receive returns a pushed message as a Subscription, Message, PMessage, Pong
|
||||
// or error. The return value is intended to be used directly in a type switch
|
||||
// as illustrated in the PubSubConn example.
|
||||
func (c PubSubConn) Receive() interface{} {
|
||||
reply, err := Values(c.Conn.Receive())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var kind string
|
||||
reply, err = Scan(reply, &kind)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case "message":
|
||||
var m Message
|
||||
if _, err := Scan(reply, &m.Channel, &m.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
return m
|
||||
case "pmessage":
|
||||
var pm PMessage
|
||||
if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
return pm
|
||||
case "subscribe", "psubscribe", "unsubscribe", "punsubscribe":
|
||||
s := Subscription{Kind: kind}
|
||||
if _, err := Scan(reply, &s.Channel, &s.Count); err != nil {
|
||||
return err
|
||||
}
|
||||
return s
|
||||
case "pong":
|
||||
var p Pong
|
||||
if _, err := Scan(reply, &p.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
return p
|
||||
}
|
||||
return errors.New("redigo: unknown pubsub notification")
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/garyburd/redigo/internal/redistest"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
)
|
||||
|
||||
func publish(channel, value interface{}) {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
c.Do("PUBLISH", channel, value)
|
||||
}
|
||||
|
||||
// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine.
|
||||
func ExamplePubSubConn() {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
psc := redis.PubSubConn{Conn: c}
|
||||
|
||||
// This goroutine receives and prints pushed notifications from the server.
|
||||
// The goroutine exits when the connection is unsubscribed from all
|
||||
// channels or there is an error.
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
switch n := psc.Receive().(type) {
|
||||
case redis.Message:
|
||||
fmt.Printf("Message: %s %s\n", n.Channel, n.Data)
|
||||
case redis.PMessage:
|
||||
fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data)
|
||||
case redis.Subscription:
|
||||
fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count)
|
||||
if n.Count == 0 {
|
||||
return
|
||||
}
|
||||
case error:
|
||||
fmt.Printf("error: %v\n", n)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// This goroutine manages subscriptions for the connection.
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
psc.Subscribe("example")
|
||||
psc.PSubscribe("p*")
|
||||
|
||||
// The following function calls publish a message using another
|
||||
// connection to the Redis server.
|
||||
publish("example", "hello")
|
||||
publish("example", "world")
|
||||
publish("pexample", "foo")
|
||||
publish("pexample", "bar")
|
||||
|
||||
// Unsubscribe from all connections. This will cause the receiving
|
||||
// goroutine to exit.
|
||||
psc.Unsubscribe()
|
||||
psc.PUnsubscribe()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Output:
|
||||
// Subscription: subscribe example 1
|
||||
// Subscription: psubscribe p* 2
|
||||
// Message: example hello
|
||||
// Message: example world
|
||||
// PMessage: p* pexample foo
|
||||
// PMessage: p* pexample bar
|
||||
// Subscription: unsubscribe example 1
|
||||
// Subscription: punsubscribe p* 0
|
||||
}
|
||||
|
||||
func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected interface{}) {
|
||||
actual := c.Receive()
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("%s = %v, want %v", message, actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushed(t *testing.T) {
|
||||
pc, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
nc, err := net.Dial("tcp", ":6379")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer nc.Close()
|
||||
nc.SetReadDeadline(time.Now().Add(4 * time.Second))
|
||||
|
||||
c := redis.PubSubConn{Conn: redis.NewConn(nc, 0, 0)}
|
||||
|
||||
c.Subscribe("c1")
|
||||
expectPushed(t, c, "Subscribe(c1)", redis.Subscription{Kind: "subscribe", Channel: "c1", Count: 1})
|
||||
c.Subscribe("c2")
|
||||
expectPushed(t, c, "Subscribe(c2)", redis.Subscription{Kind: "subscribe", Channel: "c2", Count: 2})
|
||||
c.PSubscribe("p1")
|
||||
expectPushed(t, c, "PSubscribe(p1)", redis.Subscription{Kind: "psubscribe", Channel: "p1", Count: 3})
|
||||
c.PSubscribe("p2")
|
||||
expectPushed(t, c, "PSubscribe(p2)", redis.Subscription{Kind: "psubscribe", Channel: "p2", Count: 4})
|
||||
c.PUnsubscribe()
|
||||
expectPushed(t, c, "Punsubscribe(p1)", redis.Subscription{Kind: "punsubscribe", Channel: "p1", Count: 3})
|
||||
expectPushed(t, c, "Punsubscribe()", redis.Subscription{Kind: "punsubscribe", Channel: "p2", Count: 2})
|
||||
|
||||
pc.Do("PUBLISH", "c1", "hello")
|
||||
expectPushed(t, c, "PUBLISH c1 hello", redis.Message{Channel: "c1", Data: []byte("hello")})
|
||||
|
||||
c.Ping("hello")
|
||||
expectPushed(t, c, `Ping("hello")`, redis.Pong{"hello"})
|
||||
|
||||
c.Conn.Send("PING")
|
||||
c.Conn.Flush()
|
||||
expectPushed(t, c, `Send("PING")`, redis.Pong{})
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
// Error represents an error returned in a command reply.
|
||||
type Error string
|
||||
|
||||
func (err Error) Error() string { return string(err) }
|
||||
|
||||
// Conn represents a connection to a Redis server.
|
||||
type Conn interface {
|
||||
// Close closes the connection.
|
||||
Close() error
|
||||
|
||||
// Err returns a non-nil value if the connection is broken. The returned
|
||||
// value is either the first non-nil value returned from the underlying
|
||||
// network connection or a protocol parsing error. Applications should
|
||||
// close broken connections.
|
||||
Err() error
|
||||
|
||||
// Do sends a command to the server and returns the received reply.
|
||||
Do(commandName string, args ...interface{}) (reply interface{}, err error)
|
||||
|
||||
// Send writes the command to the client's output buffer.
|
||||
Send(commandName string, args ...interface{}) error
|
||||
|
||||
// Flush flushes the output buffer to the Redis server.
|
||||
Flush() error
|
||||
|
||||
// Receive receives a single reply from the Redis server
|
||||
Receive() (reply interface{}, err error)
|
||||
}
|
|
@ -0,0 +1,364 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ErrNil indicates that a reply value is nil.
|
||||
var ErrNil = errors.New("redigo: nil returned")
|
||||
|
||||
// Int is a helper that converts a command reply to an integer. If err is not
|
||||
// equal to nil, then Int returns 0, err. Otherwise, Int converts the
|
||||
// reply to an int as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer int(reply), nil
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Int(reply interface{}, err error) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case int64:
|
||||
x := int(reply)
|
||||
if int64(x) != reply {
|
||||
return 0, strconv.ErrRange
|
||||
}
|
||||
return x, nil
|
||||
case []byte:
|
||||
n, err := strconv.ParseInt(string(reply), 10, 0)
|
||||
return int(n), err
|
||||
case nil:
|
||||
return 0, ErrNil
|
||||
case Error:
|
||||
return 0, reply
|
||||
}
|
||||
return 0, fmt.Errorf("redigo: unexpected type for Int, got type %T", reply)
|
||||
}
|
||||
|
||||
// Int64 is a helper that converts a command reply to 64 bit integer. If err is
|
||||
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
|
||||
// reply to an int64 as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer reply, nil
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Int64(reply interface{}, err error) (int64, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case int64:
|
||||
return reply, nil
|
||||
case []byte:
|
||||
n, err := strconv.ParseInt(string(reply), 10, 64)
|
||||
return n, err
|
||||
case nil:
|
||||
return 0, ErrNil
|
||||
case Error:
|
||||
return 0, reply
|
||||
}
|
||||
return 0, fmt.Errorf("redigo: unexpected type for Int64, got type %T", reply)
|
||||
}
|
||||
|
||||
var errNegativeInt = errors.New("redigo: unexpected value for Uint64")
|
||||
|
||||
// Uint64 is a helper that converts a command reply to 64 bit integer. If err is
|
||||
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
|
||||
// reply to an int64 as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer reply, nil
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Uint64(reply interface{}, err error) (uint64, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case int64:
|
||||
if reply < 0 {
|
||||
return 0, errNegativeInt
|
||||
}
|
||||
return uint64(reply), nil
|
||||
case []byte:
|
||||
n, err := strconv.ParseUint(string(reply), 10, 64)
|
||||
return n, err
|
||||
case nil:
|
||||
return 0, ErrNil
|
||||
case Error:
|
||||
return 0, reply
|
||||
}
|
||||
return 0, fmt.Errorf("redigo: unexpected type for Uint64, got type %T", reply)
|
||||
}
|
||||
|
||||
// Float64 is a helper that converts a command reply to 64 bit float. If err is
|
||||
// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts
|
||||
// the reply to an int as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Float64(reply interface{}, err error) (float64, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []byte:
|
||||
n, err := strconv.ParseFloat(string(reply), 64)
|
||||
return n, err
|
||||
case nil:
|
||||
return 0, ErrNil
|
||||
case Error:
|
||||
return 0, reply
|
||||
}
|
||||
return 0, fmt.Errorf("redigo: unexpected type for Float64, got type %T", reply)
|
||||
}
|
||||
|
||||
// String is a helper that converts a command reply to a string. If err is not
|
||||
// equal to nil, then String returns "", err. Otherwise String converts the
|
||||
// reply to a string as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// bulk string string(reply), nil
|
||||
// simple string reply, nil
|
||||
// nil "", ErrNil
|
||||
// other "", error
|
||||
func String(reply interface{}, err error) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []byte:
|
||||
return string(reply), nil
|
||||
case string:
|
||||
return reply, nil
|
||||
case nil:
|
||||
return "", ErrNil
|
||||
case Error:
|
||||
return "", reply
|
||||
}
|
||||
return "", fmt.Errorf("redigo: unexpected type for String, got type %T", reply)
|
||||
}
|
||||
|
||||
// Bytes is a helper that converts a command reply to a slice of bytes. If err
|
||||
// is not equal to nil, then Bytes returns nil, err. Otherwise Bytes converts
|
||||
// the reply to a slice of bytes as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// bulk string reply, nil
|
||||
// simple string []byte(reply), nil
|
||||
// nil nil, ErrNil
|
||||
// other nil, error
|
||||
func Bytes(reply interface{}, err error) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []byte:
|
||||
return reply, nil
|
||||
case string:
|
||||
return []byte(reply), nil
|
||||
case nil:
|
||||
return nil, ErrNil
|
||||
case Error:
|
||||
return nil, reply
|
||||
}
|
||||
return nil, fmt.Errorf("redigo: unexpected type for Bytes, got type %T", reply)
|
||||
}
|
||||
|
||||
// Bool is a helper that converts a command reply to a boolean. If err is not
|
||||
// equal to nil, then Bool returns false, err. Otherwise Bool converts the
|
||||
// reply to boolean as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer value != 0, nil
|
||||
// bulk string strconv.ParseBool(reply)
|
||||
// nil false, ErrNil
|
||||
// other false, error
|
||||
func Bool(reply interface{}, err error) (bool, error) {
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case int64:
|
||||
return reply != 0, nil
|
||||
case []byte:
|
||||
return strconv.ParseBool(string(reply))
|
||||
case nil:
|
||||
return false, ErrNil
|
||||
case Error:
|
||||
return false, reply
|
||||
}
|
||||
return false, fmt.Errorf("redigo: unexpected type for Bool, got type %T", reply)
|
||||
}
|
||||
|
||||
// MultiBulk is deprecated. Use Values.
|
||||
func MultiBulk(reply interface{}, err error) ([]interface{}, error) { return Values(reply, err) }
|
||||
|
||||
// Values is a helper that converts an array command reply to a []interface{}.
|
||||
// If err is not equal to nil, then Values returns nil, err. Otherwise, Values
|
||||
// converts the reply as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// array reply, nil
|
||||
// nil nil, ErrNil
|
||||
// other nil, error
|
||||
func Values(reply interface{}, err error) ([]interface{}, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []interface{}:
|
||||
return reply, nil
|
||||
case nil:
|
||||
return nil, ErrNil
|
||||
case Error:
|
||||
return nil, reply
|
||||
}
|
||||
return nil, fmt.Errorf("redigo: unexpected type for Values, got type %T", reply)
|
||||
}
|
||||
|
||||
// Strings is a helper that converts an array command reply to a []string. If
|
||||
// err is not equal to nil, then Strings returns nil, err. Nil array items are
|
||||
// converted to "" in the output slice. Strings returns an error if an array
|
||||
// item is not a bulk string or nil.
|
||||
func Strings(reply interface{}, err error) ([]string, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch reply := reply.(type) {
|
||||
case []interface{}:
|
||||
result := make([]string, len(reply))
|
||||
for i := range reply {
|
||||
if reply[i] == nil {
|
||||
continue
|
||||
}
|
||||
p, ok := reply[i].([]byte)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i])
|
||||
}
|
||||
result[i] = string(p)
|
||||
}
|
||||
return result, nil
|
||||
case nil:
|
||||
return nil, ErrNil
|
||||
case Error:
|
||||
return nil, reply
|
||||
}
|
||||
return nil, fmt.Errorf("redigo: unexpected type for Strings, got type %T", reply)
|
||||
}
|
||||
|
||||
// Ints is a helper that converts an array command reply to a []int. If
|
||||
// err is not equal to nil, then Ints returns nil, err.
|
||||
func Ints(reply interface{}, err error) ([]int, error) {
|
||||
var ints []int
|
||||
if reply == nil {
|
||||
return ints, ErrNil
|
||||
}
|
||||
values, err := Values(reply, err)
|
||||
if err != nil {
|
||||
return ints, err
|
||||
}
|
||||
if err := ScanSlice(values, &ints); err != nil {
|
||||
return ints, err
|
||||
}
|
||||
return ints, nil
|
||||
}
|
||||
|
||||
// StringMap is a helper that converts an array of strings (alternating key, value)
|
||||
// into a map[string]string. The HGETALL and CONFIG GET commands return replies in this format.
|
||||
// Requires an even number of values in result.
|
||||
func StringMap(result interface{}, err error) (map[string]string, error) {
|
||||
values, err := Values(result, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(values)%2 != 0 {
|
||||
return nil, errors.New("redigo: StringMap expects even number of values result")
|
||||
}
|
||||
m := make(map[string]string, len(values)/2)
|
||||
for i := 0; i < len(values); i += 2 {
|
||||
key, okKey := values[i].([]byte)
|
||||
value, okValue := values[i+1].([]byte)
|
||||
if !okKey || !okValue {
|
||||
return nil, errors.New("redigo: ScanMap key not a bulk string value")
|
||||
}
|
||||
m[string(key)] = string(value)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// IntMap is a helper that converts an array of strings (alternating key, value)
|
||||
// into a map[string]int. The HGETALL commands return replies in this format.
|
||||
// Requires an even number of values in result.
|
||||
func IntMap(result interface{}, err error) (map[string]int, error) {
|
||||
values, err := Values(result, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(values)%2 != 0 {
|
||||
return nil, errors.New("redigo: IntMap expects even number of values result")
|
||||
}
|
||||
m := make(map[string]int, len(values)/2)
|
||||
for i := 0; i < len(values); i += 2 {
|
||||
key, ok := values[i].([]byte)
|
||||
if !ok {
|
||||
return nil, errors.New("redigo: ScanMap key not a bulk string value")
|
||||
}
|
||||
value, err := Int(values[i+1], nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m[string(key)] = value
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Int64Map is a helper that converts an array of strings (alternating key, value)
|
||||
// into a map[string]int64. The HGETALL commands return replies in this format.
|
||||
// Requires an even number of values in result.
|
||||
func Int64Map(result interface{}, err error) (map[string]int64, error) {
|
||||
values, err := Values(result, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(values)%2 != 0 {
|
||||
return nil, errors.New("redigo: Int64Map expects even number of values result")
|
||||
}
|
||||
m := make(map[string]int64, len(values)/2)
|
||||
for i := 0; i < len(values); i += 2 {
|
||||
key, ok := values[i].([]byte)
|
||||
if !ok {
|
||||
return nil, errors.New("redigo: ScanMap key not a bulk string value")
|
||||
}
|
||||
value, err := Int64(values[i+1], nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m[string(key)] = value
|
||||
}
|
||||
return m, nil
|
||||
}
|
|
@ -0,0 +1,166 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/garyburd/redigo/internal/redistest"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
)
|
||||
|
||||
type valueError struct {
|
||||
v interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
func ve(v interface{}, err error) valueError {
|
||||
return valueError{v, err}
|
||||
}
|
||||
|
||||
var replyTests = []struct {
|
||||
name interface{}
|
||||
actual valueError
|
||||
expected valueError
|
||||
}{
|
||||
{
|
||||
"ints([v1, v2])",
|
||||
ve(redis.Ints([]interface{}{[]byte("4"), []byte("5")}, nil)),
|
||||
ve([]int{4, 5}, nil),
|
||||
},
|
||||
{
|
||||
"ints(nil)",
|
||||
ve(redis.Ints(nil, nil)),
|
||||
ve([]int(nil), redis.ErrNil),
|
||||
},
|
||||
{
|
||||
"strings([v1, v2])",
|
||||
ve(redis.Strings([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
|
||||
ve([]string{"v1", "v2"}, nil),
|
||||
},
|
||||
{
|
||||
"strings(nil)",
|
||||
ve(redis.Strings(nil, nil)),
|
||||
ve([]string(nil), redis.ErrNil),
|
||||
},
|
||||
{
|
||||
"values([v1, v2])",
|
||||
ve(redis.Values([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
|
||||
ve([]interface{}{[]byte("v1"), []byte("v2")}, nil),
|
||||
},
|
||||
{
|
||||
"values(nil)",
|
||||
ve(redis.Values(nil, nil)),
|
||||
ve([]interface{}(nil), redis.ErrNil),
|
||||
},
|
||||
{
|
||||
"float64(1.0)",
|
||||
ve(redis.Float64([]byte("1.0"), nil)),
|
||||
ve(float64(1.0), nil),
|
||||
},
|
||||
{
|
||||
"float64(nil)",
|
||||
ve(redis.Float64(nil, nil)),
|
||||
ve(float64(0.0), redis.ErrNil),
|
||||
},
|
||||
{
|
||||
"uint64(1)",
|
||||
ve(redis.Uint64(int64(1), nil)),
|
||||
ve(uint64(1), nil),
|
||||
},
|
||||
{
|
||||
"uint64(-1)",
|
||||
ve(redis.Uint64(int64(-1), nil)),
|
||||
ve(uint64(0), redis.ErrNegativeInt),
|
||||
},
|
||||
}
|
||||
|
||||
func TestReply(t *testing.T) {
|
||||
for _, rt := range replyTests {
|
||||
if rt.actual.err != rt.expected.err {
|
||||
t.Errorf("%s returned err %v, want %v", rt.name, rt.actual.err, rt.expected.err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(rt.actual.v, rt.expected.v) {
|
||||
t.Errorf("%s=%+v, want %+v", rt.name, rt.actual.v, rt.expected.v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dial wraps DialTestDB() with a more suitable function name for examples.
|
||||
func dial() (redis.Conn, error) {
|
||||
return redistest.Dial()
|
||||
}
|
||||
|
||||
func ExampleBool() {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c.Do("SET", "foo", 1)
|
||||
exists, _ := redis.Bool(c.Do("EXISTS", "foo"))
|
||||
fmt.Printf("%#v\n", exists)
|
||||
// Output:
|
||||
// true
|
||||
}
|
||||
|
||||
func ExampleInt() {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c.Do("SET", "k1", 1)
|
||||
n, _ := redis.Int(c.Do("GET", "k1"))
|
||||
fmt.Printf("%#v\n", n)
|
||||
n, _ = redis.Int(c.Do("INCR", "k1"))
|
||||
fmt.Printf("%#v\n", n)
|
||||
// Output:
|
||||
// 1
|
||||
// 2
|
||||
}
|
||||
|
||||
func ExampleInts() {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c.Do("SADD", "set_with_integers", 4, 5, 6)
|
||||
ints, _ := redis.Ints(c.Do("SMEMBERS", "set_with_integers"))
|
||||
fmt.Printf("%#v\n", ints)
|
||||
// Output:
|
||||
// []int{4, 5, 6}
|
||||
}
|
||||
|
||||
func ExampleString() {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c.Do("SET", "hello", "world")
|
||||
s, err := redis.String(c.Do("GET", "hello"))
|
||||
fmt.Printf("%#v\n", s)
|
||||
// Output:
|
||||
// "world"
|
||||
}
|
|
@ -0,0 +1,513 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func ensureLen(d reflect.Value, n int) {
|
||||
if n > d.Cap() {
|
||||
d.Set(reflect.MakeSlice(d.Type(), n, n))
|
||||
} else {
|
||||
d.SetLen(n)
|
||||
}
|
||||
}
|
||||
|
||||
func cannotConvert(d reflect.Value, s interface{}) error {
|
||||
return fmt.Errorf("redigo: Scan cannot convert from %s to %s",
|
||||
reflect.TypeOf(s), d.Type())
|
||||
}
|
||||
|
||||
func convertAssignBytes(d reflect.Value, s []byte) (err error) {
|
||||
switch d.Type().Kind() {
|
||||
case reflect.Float32, reflect.Float64:
|
||||
var x float64
|
||||
x, err = strconv.ParseFloat(string(s), d.Type().Bits())
|
||||
d.SetFloat(x)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
var x int64
|
||||
x, err = strconv.ParseInt(string(s), 10, d.Type().Bits())
|
||||
d.SetInt(x)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
var x uint64
|
||||
x, err = strconv.ParseUint(string(s), 10, d.Type().Bits())
|
||||
d.SetUint(x)
|
||||
case reflect.Bool:
|
||||
var x bool
|
||||
x, err = strconv.ParseBool(string(s))
|
||||
d.SetBool(x)
|
||||
case reflect.String:
|
||||
d.SetString(string(s))
|
||||
case reflect.Slice:
|
||||
if d.Type().Elem().Kind() != reflect.Uint8 {
|
||||
err = cannotConvert(d, s)
|
||||
} else {
|
||||
d.SetBytes(s)
|
||||
}
|
||||
default:
|
||||
err = cannotConvert(d, s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertAssignInt(d reflect.Value, s int64) (err error) {
|
||||
switch d.Type().Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
d.SetInt(s)
|
||||
if d.Int() != s {
|
||||
err = strconv.ErrRange
|
||||
d.SetInt(0)
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if s < 0 {
|
||||
err = strconv.ErrRange
|
||||
} else {
|
||||
x := uint64(s)
|
||||
d.SetUint(x)
|
||||
if d.Uint() != x {
|
||||
err = strconv.ErrRange
|
||||
d.SetUint(0)
|
||||
}
|
||||
}
|
||||
case reflect.Bool:
|
||||
d.SetBool(s != 0)
|
||||
default:
|
||||
err = cannotConvert(d, s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertAssignValue(d reflect.Value, s interface{}) (err error) {
|
||||
switch s := s.(type) {
|
||||
case []byte:
|
||||
err = convertAssignBytes(d, s)
|
||||
case int64:
|
||||
err = convertAssignInt(d, s)
|
||||
default:
|
||||
err = cannotConvert(d, s)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func convertAssignValues(d reflect.Value, s []interface{}) error {
|
||||
if d.Type().Kind() != reflect.Slice {
|
||||
return cannotConvert(d, s)
|
||||
}
|
||||
ensureLen(d, len(s))
|
||||
for i := 0; i < len(s); i++ {
|
||||
if err := convertAssignValue(d.Index(i), s[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertAssign(d interface{}, s interface{}) (err error) {
|
||||
// Handle the most common destination types using type switches and
|
||||
// fall back to reflection for all other types.
|
||||
switch s := s.(type) {
|
||||
case nil:
|
||||
// ingore
|
||||
case []byte:
|
||||
switch d := d.(type) {
|
||||
case *string:
|
||||
*d = string(s)
|
||||
case *int:
|
||||
*d, err = strconv.Atoi(string(s))
|
||||
case *bool:
|
||||
*d, err = strconv.ParseBool(string(s))
|
||||
case *[]byte:
|
||||
*d = s
|
||||
case *interface{}:
|
||||
*d = s
|
||||
case nil:
|
||||
// skip value
|
||||
default:
|
||||
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
|
||||
err = cannotConvert(d, s)
|
||||
} else {
|
||||
err = convertAssignBytes(d.Elem(), s)
|
||||
}
|
||||
}
|
||||
case int64:
|
||||
switch d := d.(type) {
|
||||
case *int:
|
||||
x := int(s)
|
||||
if int64(x) != s {
|
||||
err = strconv.ErrRange
|
||||
x = 0
|
||||
}
|
||||
*d = x
|
||||
case *bool:
|
||||
*d = s != 0
|
||||
case *interface{}:
|
||||
*d = s
|
||||
case nil:
|
||||
// skip value
|
||||
default:
|
||||
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
|
||||
err = cannotConvert(d, s)
|
||||
} else {
|
||||
err = convertAssignInt(d.Elem(), s)
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
switch d := d.(type) {
|
||||
case *[]interface{}:
|
||||
*d = s
|
||||
case *interface{}:
|
||||
*d = s
|
||||
case nil:
|
||||
// skip value
|
||||
default:
|
||||
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
|
||||
err = cannotConvert(d, s)
|
||||
} else {
|
||||
err = convertAssignValues(d.Elem(), s)
|
||||
}
|
||||
}
|
||||
case Error:
|
||||
err = s
|
||||
default:
|
||||
err = cannotConvert(reflect.ValueOf(d), s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Scan copies from src to the values pointed at by dest.
|
||||
//
|
||||
// The values pointed at by dest must be an integer, float, boolean, string,
|
||||
// []byte, interface{} or slices of these types. Scan uses the standard strconv
|
||||
// package to convert bulk strings to numeric and boolean types.
|
||||
//
|
||||
// If a dest value is nil, then the corresponding src value is skipped.
|
||||
//
|
||||
// If a src element is nil, then the corresponding dest value is not modified.
|
||||
//
|
||||
// To enable easy use of Scan in a loop, Scan returns the slice of src
|
||||
// following the copied values.
|
||||
func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) {
|
||||
if len(src) < len(dest) {
|
||||
return nil, errors.New("redigo: Scan array short")
|
||||
}
|
||||
var err error
|
||||
for i, d := range dest {
|
||||
err = convertAssign(d, src[i])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return src[len(dest):], err
|
||||
}
|
||||
|
||||
type fieldSpec struct {
|
||||
name string
|
||||
index []int
|
||||
//omitEmpty bool
|
||||
}
|
||||
|
||||
type structSpec struct {
|
||||
m map[string]*fieldSpec
|
||||
l []*fieldSpec
|
||||
}
|
||||
|
||||
func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
|
||||
return ss.m[string(name)]
|
||||
}
|
||||
|
||||
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
switch {
|
||||
case f.PkgPath != "":
|
||||
// Ignore unexported fields.
|
||||
case f.Anonymous:
|
||||
// TODO: Handle pointers. Requires change to decoder and
|
||||
// protection against infinite recursion.
|
||||
if f.Type.Kind() == reflect.Struct {
|
||||
compileStructSpec(f.Type, depth, append(index, i), ss)
|
||||
}
|
||||
default:
|
||||
fs := &fieldSpec{name: f.Name}
|
||||
tag := f.Tag.Get("redis")
|
||||
p := strings.Split(tag, ",")
|
||||
if len(p) > 0 {
|
||||
if p[0] == "-" {
|
||||
continue
|
||||
}
|
||||
if len(p[0]) > 0 {
|
||||
fs.name = p[0]
|
||||
}
|
||||
for _, s := range p[1:] {
|
||||
switch s {
|
||||
//case "omitempty":
|
||||
// fs.omitempty = true
|
||||
default:
|
||||
panic(errors.New("redigo: unknown field flag " + s + " for type " + t.Name()))
|
||||
}
|
||||
}
|
||||
}
|
||||
d, found := depth[fs.name]
|
||||
if !found {
|
||||
d = 1 << 30
|
||||
}
|
||||
switch {
|
||||
case len(index) == d:
|
||||
// At same depth, remove from result.
|
||||
delete(ss.m, fs.name)
|
||||
j := 0
|
||||
for i := 0; i < len(ss.l); i++ {
|
||||
if fs.name != ss.l[i].name {
|
||||
ss.l[j] = ss.l[i]
|
||||
j += 1
|
||||
}
|
||||
}
|
||||
ss.l = ss.l[:j]
|
||||
case len(index) < d:
|
||||
fs.index = make([]int, len(index)+1)
|
||||
copy(fs.index, index)
|
||||
fs.index[len(index)] = i
|
||||
depth[fs.name] = len(index)
|
||||
ss.m[fs.name] = fs
|
||||
ss.l = append(ss.l, fs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
structSpecMutex sync.RWMutex
|
||||
structSpecCache = make(map[reflect.Type]*structSpec)
|
||||
defaultFieldSpec = &fieldSpec{}
|
||||
)
|
||||
|
||||
func structSpecForType(t reflect.Type) *structSpec {
|
||||
|
||||
structSpecMutex.RLock()
|
||||
ss, found := structSpecCache[t]
|
||||
structSpecMutex.RUnlock()
|
||||
if found {
|
||||
return ss
|
||||
}
|
||||
|
||||
structSpecMutex.Lock()
|
||||
defer structSpecMutex.Unlock()
|
||||
ss, found = structSpecCache[t]
|
||||
if found {
|
||||
return ss
|
||||
}
|
||||
|
||||
ss = &structSpec{m: make(map[string]*fieldSpec)}
|
||||
compileStructSpec(t, make(map[string]int), nil, ss)
|
||||
structSpecCache[t] = ss
|
||||
return ss
|
||||
}
|
||||
|
||||
var errScanStructValue = errors.New("redigo: ScanStruct value must be non-nil pointer to a struct")
|
||||
|
||||
// ScanStruct scans alternating names and values from src to a struct. The
|
||||
// HGETALL and CONFIG GET commands return replies in this format.
|
||||
//
|
||||
// ScanStruct uses exported field names to match values in the response. Use
|
||||
// 'redis' field tag to override the name:
|
||||
//
|
||||
// Field int `redis:"myName"`
|
||||
//
|
||||
// Fields with the tag redis:"-" are ignored.
|
||||
//
|
||||
// Integer, float, boolean, string and []byte fields are supported. Scan uses the
|
||||
// standard strconv package to convert bulk string values to numeric and
|
||||
// boolean types.
|
||||
//
|
||||
// If a src element is nil, then the corresponding field is not modified.
|
||||
func ScanStruct(src []interface{}, dest interface{}) error {
|
||||
d := reflect.ValueOf(dest)
|
||||
if d.Kind() != reflect.Ptr || d.IsNil() {
|
||||
return errScanStructValue
|
||||
}
|
||||
d = d.Elem()
|
||||
if d.Kind() != reflect.Struct {
|
||||
return errScanStructValue
|
||||
}
|
||||
ss := structSpecForType(d.Type())
|
||||
|
||||
if len(src)%2 != 0 {
|
||||
return errors.New("redigo: ScanStruct expects even number of values in values")
|
||||
}
|
||||
|
||||
for i := 0; i < len(src); i += 2 {
|
||||
s := src[i+1]
|
||||
if s == nil {
|
||||
continue
|
||||
}
|
||||
name, ok := src[i].([]byte)
|
||||
if !ok {
|
||||
return errors.New("redigo: ScanStruct key not a bulk string value")
|
||||
}
|
||||
fs := ss.fieldSpec(name)
|
||||
if fs == nil {
|
||||
continue
|
||||
}
|
||||
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
errScanSliceValue = errors.New("redigo: ScanSlice dest must be non-nil pointer to a struct")
|
||||
)
|
||||
|
||||
// ScanSlice scans src to the slice pointed to by dest. The elements the dest
|
||||
// slice must be integer, float, boolean, string, struct or pointer to struct
|
||||
// values.
|
||||
//
|
||||
// Struct fields must be integer, float, boolean or string values. All struct
|
||||
// fields are used unless a subset is specified using fieldNames.
|
||||
func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error {
|
||||
d := reflect.ValueOf(dest)
|
||||
if d.Kind() != reflect.Ptr || d.IsNil() {
|
||||
return errScanSliceValue
|
||||
}
|
||||
d = d.Elem()
|
||||
if d.Kind() != reflect.Slice {
|
||||
return errScanSliceValue
|
||||
}
|
||||
|
||||
isPtr := false
|
||||
t := d.Type().Elem()
|
||||
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
|
||||
isPtr = true
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
ensureLen(d, len(src))
|
||||
for i, s := range src {
|
||||
if s == nil {
|
||||
continue
|
||||
}
|
||||
if err := convertAssignValue(d.Index(i), s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ss := structSpecForType(t)
|
||||
fss := ss.l
|
||||
if len(fieldNames) > 0 {
|
||||
fss = make([]*fieldSpec, len(fieldNames))
|
||||
for i, name := range fieldNames {
|
||||
fss[i] = ss.m[name]
|
||||
if fss[i] == nil {
|
||||
return errors.New("redigo: ScanSlice bad field name " + name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(fss) == 0 {
|
||||
return errors.New("redigo: ScanSlice no struct fields")
|
||||
}
|
||||
|
||||
n := len(src) / len(fss)
|
||||
if n*len(fss) != len(src) {
|
||||
return errors.New("redigo: ScanSlice length not a multiple of struct field count")
|
||||
}
|
||||
|
||||
ensureLen(d, n)
|
||||
for i := 0; i < n; i++ {
|
||||
d := d.Index(i)
|
||||
if isPtr {
|
||||
if d.IsNil() {
|
||||
d.Set(reflect.New(t))
|
||||
}
|
||||
d = d.Elem()
|
||||
}
|
||||
for j, fs := range fss {
|
||||
s := src[i*len(fss)+j]
|
||||
if s == nil {
|
||||
continue
|
||||
}
|
||||
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Args is a helper for constructing command arguments from structured values.
|
||||
type Args []interface{}
|
||||
|
||||
// Add returns the result of appending value to args.
|
||||
func (args Args) Add(value ...interface{}) Args {
|
||||
return append(args, value...)
|
||||
}
|
||||
|
||||
// AddFlat returns the result of appending the flattened value of v to args.
|
||||
//
|
||||
// Maps are flattened by appending the alternating keys and map values to args.
|
||||
//
|
||||
// Slices are flattened by appending the slice elements to args.
|
||||
//
|
||||
// Structs are flattened by appending the alternating names and values of
|
||||
// exported fields to args. If v is a nil struct pointer, then nothing is
|
||||
// appended. The 'redis' field tag overrides struct field names. See ScanStruct
|
||||
// for more information on the use of the 'redis' field tag.
|
||||
//
|
||||
// Other types are appended to args as is.
|
||||
func (args Args) AddFlat(v interface{}) Args {
|
||||
rv := reflect.ValueOf(v)
|
||||
switch rv.Kind() {
|
||||
case reflect.Struct:
|
||||
args = flattenStruct(args, rv)
|
||||
case reflect.Slice:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
args = append(args, rv.Index(i).Interface())
|
||||
}
|
||||
case reflect.Map:
|
||||
for _, k := range rv.MapKeys() {
|
||||
args = append(args, k.Interface(), rv.MapIndex(k).Interface())
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if rv.Type().Elem().Kind() == reflect.Struct {
|
||||
if !rv.IsNil() {
|
||||
args = flattenStruct(args, rv.Elem())
|
||||
}
|
||||
} else {
|
||||
args = append(args, v)
|
||||
}
|
||||
default:
|
||||
args = append(args, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func flattenStruct(args Args, v reflect.Value) Args {
|
||||
ss := structSpecForType(v.Type())
|
||||
for _, fs := range ss.l {
|
||||
fv := v.FieldByIndex(fs.index)
|
||||
args = append(args, fs.name, fv.Interface())
|
||||
}
|
||||
return args
|
||||
}
|
|
@ -0,0 +1,412 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var scanConversionTests = []struct {
|
||||
src interface{}
|
||||
dest interface{}
|
||||
}{
|
||||
{[]byte("-inf"), math.Inf(-1)},
|
||||
{[]byte("+inf"), math.Inf(1)},
|
||||
{[]byte("0"), float64(0)},
|
||||
{[]byte("3.14159"), float64(3.14159)},
|
||||
{[]byte("3.14"), float32(3.14)},
|
||||
{[]byte("-100"), int(-100)},
|
||||
{[]byte("101"), int(101)},
|
||||
{int64(102), int(102)},
|
||||
{[]byte("103"), uint(103)},
|
||||
{int64(104), uint(104)},
|
||||
{[]byte("105"), int8(105)},
|
||||
{int64(106), int8(106)},
|
||||
{[]byte("107"), uint8(107)},
|
||||
{int64(108), uint8(108)},
|
||||
{[]byte("0"), false},
|
||||
{int64(0), false},
|
||||
{[]byte("f"), false},
|
||||
{[]byte("1"), true},
|
||||
{int64(1), true},
|
||||
{[]byte("t"), true},
|
||||
{[]byte("hello"), "hello"},
|
||||
{[]byte("world"), []byte("world")},
|
||||
{[]interface{}{[]byte("foo")}, []interface{}{[]byte("foo")}},
|
||||
{[]interface{}{[]byte("foo")}, []string{"foo"}},
|
||||
{[]interface{}{[]byte("hello"), []byte("world")}, []string{"hello", "world"}},
|
||||
{[]interface{}{[]byte("bar")}, [][]byte{[]byte("bar")}},
|
||||
{[]interface{}{[]byte("1")}, []int{1}},
|
||||
{[]interface{}{[]byte("1"), []byte("2")}, []int{1, 2}},
|
||||
{[]interface{}{[]byte("1"), []byte("2")}, []float64{1, 2}},
|
||||
{[]interface{}{[]byte("1")}, []byte{1}},
|
||||
{[]interface{}{[]byte("1")}, []bool{true}},
|
||||
}
|
||||
|
||||
func TestScanConversion(t *testing.T) {
|
||||
for _, tt := range scanConversionTests {
|
||||
values := []interface{}{tt.src}
|
||||
dest := reflect.New(reflect.TypeOf(tt.dest))
|
||||
values, err := redis.Scan(values, dest.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("Scan(%v) returned error %v", tt, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(tt.dest, dest.Elem().Interface()) {
|
||||
t.Errorf("Scan(%v) returned %v, want %v", tt, dest.Elem().Interface(), tt.dest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var scanConversionErrorTests = []struct {
|
||||
src interface{}
|
||||
dest interface{}
|
||||
}{
|
||||
{[]byte("1234"), byte(0)},
|
||||
{int64(1234), byte(0)},
|
||||
{[]byte("-1"), byte(0)},
|
||||
{int64(-1), byte(0)},
|
||||
{[]byte("junk"), false},
|
||||
{redis.Error("blah"), false},
|
||||
}
|
||||
|
||||
func TestScanConversionError(t *testing.T) {
|
||||
for _, tt := range scanConversionErrorTests {
|
||||
values := []interface{}{tt.src}
|
||||
dest := reflect.New(reflect.TypeOf(tt.dest))
|
||||
values, err := redis.Scan(values, dest.Interface())
|
||||
if err == nil {
|
||||
t.Errorf("Scan(%v) did not return error", tt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleScan() {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c.Send("HMSET", "album:1", "title", "Red", "rating", 5)
|
||||
c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1)
|
||||
c.Send("HMSET", "album:3", "title", "Beat")
|
||||
c.Send("LPUSH", "albums", "1")
|
||||
c.Send("LPUSH", "albums", "2")
|
||||
c.Send("LPUSH", "albums", "3")
|
||||
values, err := redis.Values(c.Do("SORT", "albums",
|
||||
"BY", "album:*->rating",
|
||||
"GET", "album:*->title",
|
||||
"GET", "album:*->rating"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for len(values) > 0 {
|
||||
var title string
|
||||
rating := -1 // initialize to illegal value to detect nil.
|
||||
values, err = redis.Scan(values, &title, &rating)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if rating == -1 {
|
||||
fmt.Println(title, "not-rated")
|
||||
} else {
|
||||
fmt.Println(title, rating)
|
||||
}
|
||||
}
|
||||
// Output:
|
||||
// Beat not-rated
|
||||
// Earthbound 1
|
||||
// Red 5
|
||||
}
|
||||
|
||||
type s0 struct {
|
||||
X int
|
||||
Y int `redis:"y"`
|
||||
Bt bool
|
||||
}
|
||||
|
||||
type s1 struct {
|
||||
X int `redis:"-"`
|
||||
I int `redis:"i"`
|
||||
U uint `redis:"u"`
|
||||
S string `redis:"s"`
|
||||
P []byte `redis:"p"`
|
||||
B bool `redis:"b"`
|
||||
Bt bool
|
||||
Bf bool
|
||||
s0
|
||||
}
|
||||
|
||||
var scanStructTests = []struct {
|
||||
title string
|
||||
reply []string
|
||||
value interface{}
|
||||
}{
|
||||
{"basic",
|
||||
[]string{"i", "-1234", "u", "5678", "s", "hello", "p", "world", "b", "t", "Bt", "1", "Bf", "0", "X", "123", "y", "456"},
|
||||
&s1{I: -1234, U: 5678, S: "hello", P: []byte("world"), B: true, Bt: true, Bf: false, s0: s0{X: 123, Y: 456}},
|
||||
},
|
||||
}
|
||||
|
||||
func TestScanStruct(t *testing.T) {
|
||||
for _, tt := range scanStructTests {
|
||||
|
||||
var reply []interface{}
|
||||
for _, v := range tt.reply {
|
||||
reply = append(reply, []byte(v))
|
||||
}
|
||||
|
||||
value := reflect.New(reflect.ValueOf(tt.value).Type().Elem())
|
||||
|
||||
if err := redis.ScanStruct(reply, value.Interface()); err != nil {
|
||||
t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(value.Interface(), tt.value) {
|
||||
t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadScanStructArgs(t *testing.T) {
|
||||
x := []interface{}{"A", "b"}
|
||||
test := func(v interface{}) {
|
||||
if err := redis.ScanStruct(x, v); err == nil {
|
||||
t.Errorf("Expect error for ScanStruct(%T, %T)", x, v)
|
||||
}
|
||||
}
|
||||
|
||||
test(nil)
|
||||
|
||||
var v0 *struct{}
|
||||
test(v0)
|
||||
|
||||
var v1 int
|
||||
test(&v1)
|
||||
|
||||
x = x[:1]
|
||||
v2 := struct{ A string }{}
|
||||
test(&v2)
|
||||
}
|
||||
|
||||
var scanSliceTests = []struct {
|
||||
src []interface{}
|
||||
fieldNames []string
|
||||
ok bool
|
||||
dest interface{}
|
||||
}{
|
||||
{
|
||||
[]interface{}{[]byte("1"), nil, []byte("-1")},
|
||||
nil,
|
||||
true,
|
||||
[]int{1, 0, -1},
|
||||
},
|
||||
{
|
||||
[]interface{}{[]byte("1"), nil, []byte("2")},
|
||||
nil,
|
||||
true,
|
||||
[]uint{1, 0, 2},
|
||||
},
|
||||
{
|
||||
[]interface{}{[]byte("-1")},
|
||||
nil,
|
||||
false,
|
||||
[]uint{1},
|
||||
},
|
||||
{
|
||||
[]interface{}{[]byte("hello"), nil, []byte("world")},
|
||||
nil,
|
||||
true,
|
||||
[][]byte{[]byte("hello"), nil, []byte("world")},
|
||||
},
|
||||
{
|
||||
[]interface{}{[]byte("hello"), nil, []byte("world")},
|
||||
nil,
|
||||
true,
|
||||
[]string{"hello", "", "world"},
|
||||
},
|
||||
{
|
||||
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
|
||||
nil,
|
||||
true,
|
||||
[]struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
|
||||
},
|
||||
{
|
||||
[]interface{}{[]byte("a1"), []byte("b1")},
|
||||
nil,
|
||||
false,
|
||||
[]struct{ A, B, C string }{{"a1", "b1", ""}},
|
||||
},
|
||||
{
|
||||
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
|
||||
nil,
|
||||
true,
|
||||
[]*struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
|
||||
},
|
||||
{
|
||||
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
|
||||
[]string{"A", "B"},
|
||||
true,
|
||||
[]struct{ A, C, B string }{{"a1", "", "b1"}, {"a2", "", "b2"}},
|
||||
},
|
||||
{
|
||||
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
|
||||
nil,
|
||||
false,
|
||||
[]struct{}{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestScanSlice(t *testing.T) {
|
||||
for _, tt := range scanSliceTests {
|
||||
|
||||
typ := reflect.ValueOf(tt.dest).Type()
|
||||
dest := reflect.New(typ)
|
||||
|
||||
err := redis.ScanSlice(tt.src, dest.Interface(), tt.fieldNames...)
|
||||
if tt.ok != (err == nil) {
|
||||
t.Errorf("ScanSlice(%v, []%s, %v) returned error %v", tt.src, typ, tt.fieldNames, err)
|
||||
continue
|
||||
}
|
||||
if tt.ok && !reflect.DeepEqual(dest.Elem().Interface(), tt.dest) {
|
||||
t.Errorf("ScanSlice(src, []%s) returned %#v, want %#v", typ, dest.Elem().Interface(), tt.dest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleScanSlice() {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c.Send("HMSET", "album:1", "title", "Red", "rating", 5)
|
||||
c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1)
|
||||
c.Send("HMSET", "album:3", "title", "Beat", "rating", 4)
|
||||
c.Send("LPUSH", "albums", "1")
|
||||
c.Send("LPUSH", "albums", "2")
|
||||
c.Send("LPUSH", "albums", "3")
|
||||
values, err := redis.Values(c.Do("SORT", "albums",
|
||||
"BY", "album:*->rating",
|
||||
"GET", "album:*->title",
|
||||
"GET", "album:*->rating"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var albums []struct {
|
||||
Title string
|
||||
Rating int
|
||||
}
|
||||
if err := redis.ScanSlice(values, &albums); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("%v\n", albums)
|
||||
// Output:
|
||||
// [{Earthbound 1} {Beat 4} {Red 5}]
|
||||
}
|
||||
|
||||
var argsTests = []struct {
|
||||
title string
|
||||
actual redis.Args
|
||||
expected redis.Args
|
||||
}{
|
||||
{"struct ptr",
|
||||
redis.Args{}.AddFlat(&struct {
|
||||
I int `redis:"i"`
|
||||
U uint `redis:"u"`
|
||||
S string `redis:"s"`
|
||||
P []byte `redis:"p"`
|
||||
Bt bool
|
||||
Bf bool
|
||||
}{
|
||||
-1234, 5678, "hello", []byte("world"), true, false,
|
||||
}),
|
||||
redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "Bt", true, "Bf", false},
|
||||
},
|
||||
{"struct",
|
||||
redis.Args{}.AddFlat(struct{ I int }{123}),
|
||||
redis.Args{"I", 123},
|
||||
},
|
||||
{"slice",
|
||||
redis.Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2),
|
||||
redis.Args{1, "a", "b", "c", 2},
|
||||
},
|
||||
}
|
||||
|
||||
func TestArgs(t *testing.T) {
|
||||
for _, tt := range argsTests {
|
||||
if !reflect.DeepEqual(tt.actual, tt.expected) {
|
||||
t.Fatalf("%s is %v, want %v", tt.title, tt.actual, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleArgs() {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
var p1, p2 struct {
|
||||
Title string `redis:"title"`
|
||||
Author string `redis:"author"`
|
||||
Body string `redis:"body"`
|
||||
}
|
||||
|
||||
p1.Title = "Example"
|
||||
p1.Author = "Gary"
|
||||
p1.Body = "Hello"
|
||||
|
||||
if _, err := c.Do("HMSET", redis.Args{}.Add("id1").AddFlat(&p1)...); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
m := map[string]string{
|
||||
"title": "Example2",
|
||||
"author": "Steve",
|
||||
"body": "Map",
|
||||
}
|
||||
|
||||
if _, err := c.Do("HMSET", redis.Args{}.Add("id2").AddFlat(m)...); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for _, id := range []string{"id1", "id2"} {
|
||||
|
||||
v, err := redis.Values(c.Do("HGETALL", id))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := redis.ScanStruct(v, &p2); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%+v\n", p2)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// {Title:Example Author:Gary Body:Hello}
|
||||
// {Title:Example2 Author:Steve Body:Map}
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Script encapsulates the source, hash and key count for a Lua script. See
|
||||
// http://redis.io/commands/eval for information on scripts in Redis.
|
||||
type Script struct {
|
||||
keyCount int
|
||||
src string
|
||||
hash string
|
||||
}
|
||||
|
||||
// NewScript returns a new script object. If keyCount is greater than or equal
|
||||
// to zero, then the count is automatically inserted in the EVAL command
|
||||
// argument list. If keyCount is less than zero, then the application supplies
|
||||
// the count as the first value in the keysAndArgs argument to the Do, Send and
|
||||
// SendHash methods.
|
||||
func NewScript(keyCount int, src string) *Script {
|
||||
h := sha1.New()
|
||||
io.WriteString(h, src)
|
||||
return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))}
|
||||
}
|
||||
|
||||
func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} {
|
||||
var args []interface{}
|
||||
if s.keyCount < 0 {
|
||||
args = make([]interface{}, 1+len(keysAndArgs))
|
||||
args[0] = spec
|
||||
copy(args[1:], keysAndArgs)
|
||||
} else {
|
||||
args = make([]interface{}, 2+len(keysAndArgs))
|
||||
args[0] = spec
|
||||
args[1] = s.keyCount
|
||||
copy(args[2:], keysAndArgs)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// Do evaluates the script. Under the covers, Do optimistically evaluates the
|
||||
// script using the EVALSHA command. If the command fails because the script is
|
||||
// not loaded, then Do evaluates the script using the EVAL command (thus
|
||||
// causing the script to load).
|
||||
func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) {
|
||||
v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...)
|
||||
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
|
||||
v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...)
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// SendHash evaluates the script without waiting for the reply. The script is
|
||||
// evaluated with the EVALSHA command. The application must ensure that the
|
||||
// script is loaded by a previous call to Send, Do or Load methods.
|
||||
func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error {
|
||||
return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...)
|
||||
}
|
||||
|
||||
// Send evaluates the script without waiting for the reply.
|
||||
func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error {
|
||||
return c.Send("EVAL", s.args(s.src, keysAndArgs)...)
|
||||
}
|
||||
|
||||
// Load loads the script without evaluating it.
|
||||
func (s *Script) Load(c Conn) error {
|
||||
_, err := c.Do("SCRIPT", "LOAD", s.src)
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/garyburd/redigo/internal/redistest"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
)
|
||||
|
||||
func ExampleScript(c redis.Conn, reply interface{}, err error) {
|
||||
// Initialize a package-level variable with a script.
|
||||
var getScript = redis.NewScript(1, `return redis.call('get', KEYS[1])`)
|
||||
|
||||
// In a function, use the script Do method to evaluate the script. The Do
|
||||
// method optimistically uses the EVALSHA command. If the script is not
|
||||
// loaded, then the Do method falls back to the EVAL command.
|
||||
reply, err = getScript.Do(c, "foo")
|
||||
}
|
||||
|
||||
func TestScript(t *testing.T) {
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// To test fall back in Do, we make script unique by adding comment with current time.
|
||||
script := fmt.Sprintf("--%d\nreturn {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", time.Now().UnixNano())
|
||||
s := redis.NewScript(2, script)
|
||||
reply := []interface{}{[]byte("key1"), []byte("key2"), []byte("arg1"), []byte("arg2")}
|
||||
|
||||
v, err := s.Do(c, "key1", "key2", "arg1", "arg2")
|
||||
if err != nil {
|
||||
t.Errorf("s.Do(c, ...) returned %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(v, reply) {
|
||||
t.Errorf("s.Do(c, ..); = %v, want %v", v, reply)
|
||||
}
|
||||
|
||||
err = s.Load(c)
|
||||
if err != nil {
|
||||
t.Errorf("s.Load(c) returned %v", err)
|
||||
}
|
||||
|
||||
err = s.SendHash(c, "key1", "key2", "arg1", "arg2")
|
||||
if err != nil {
|
||||
t.Errorf("s.SendHash(c, ...) returned %v", err)
|
||||
}
|
||||
|
||||
err = c.Flush()
|
||||
if err != nil {
|
||||
t.Errorf("c.Flush() returned %v", err)
|
||||
}
|
||||
|
||||
v, err = c.Receive()
|
||||
if !reflect.DeepEqual(v, reply) {
|
||||
t.Errorf("s.SendHash(c, ..); c.Receive() = %v, want %v", v, reply)
|
||||
}
|
||||
|
||||
err = s.Send(c, "key1", "key2", "arg1", "arg2")
|
||||
if err != nil {
|
||||
t.Errorf("s.Send(c, ...) returned %v", err)
|
||||
}
|
||||
|
||||
err = c.Flush()
|
||||
if err != nil {
|
||||
t.Errorf("c.Flush() returned %v", err)
|
||||
}
|
||||
|
||||
v, err = c.Receive()
|
||||
if !reflect.DeepEqual(v, reply) {
|
||||
t.Errorf("s.Send(c, ..); c.Receive() = %v, want %v", v, reply)
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func SetNowFunc(f func() time.Time) {
|
||||
nowFunc = f
|
||||
}
|
||||
|
||||
type nopCloser struct{ net.Conn }
|
||||
|
||||
func (nopCloser) Close() error { return nil }
|
||||
|
||||
// NewConnBufio is a hook for tests.
|
||||
func NewConnBufio(rw bufio.ReadWriter) Conn {
|
||||
return &conn{br: rw.Reader, bw: rw.Writer, conn: nopCloser{}}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrNegativeInt = errNegativeInt
|
||||
)
|
|
@ -0,0 +1,113 @@
|
|||
// Copyright 2013 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
)
|
||||
|
||||
// zpop pops a value from the ZSET key using WATCH/MULTI/EXEC commands.
|
||||
func zpop(c redis.Conn, key string) (result string, err error) {
|
||||
|
||||
defer func() {
|
||||
// Return connection to normal state on error.
|
||||
if err != nil {
|
||||
c.Do("DISCARD")
|
||||
}
|
||||
}()
|
||||
|
||||
// Loop until transaction is successful.
|
||||
for {
|
||||
if _, err := c.Do("WATCH", key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
members, err := redis.Strings(c.Do("ZRANGE", key, 0, 0))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(members) != 1 {
|
||||
return "", redis.ErrNil
|
||||
}
|
||||
|
||||
c.Send("MULTI")
|
||||
c.Send("ZREM", key, members[0])
|
||||
queued, err := c.Do("EXEC")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if queued != nil {
|
||||
result = members[0]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// zpopScript pops a value from a ZSET.
|
||||
var zpopScript = redis.NewScript(1, `
|
||||
local r = redis.call('ZRANGE', KEYS[1], 0, 0)
|
||||
if r ~= nil then
|
||||
r = r[1]
|
||||
redis.call('ZREM', KEYS[1], r)
|
||||
end
|
||||
return r
|
||||
`)
|
||||
|
||||
// This example implements ZPOP as described at
|
||||
// http://redis.io/topics/transactions using WATCH/MULTI/EXEC and scripting.
|
||||
func Example_zpop() {
|
||||
c, err := dial()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// Add test data using a pipeline.
|
||||
|
||||
for i, member := range []string{"red", "blue", "green"} {
|
||||
c.Send("ZADD", "zset", i, member)
|
||||
}
|
||||
if _, err := c.Do(""); err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Pop using WATCH/MULTI/EXEC
|
||||
|
||||
v, err := zpop(c, "zset")
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(v)
|
||||
|
||||
// Pop using a script.
|
||||
|
||||
v, err = redis.String(zpopScript.Do(c, "zset"))
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(v)
|
||||
|
||||
// Output:
|
||||
// red
|
||||
// blue
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
// Copyright 2014 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redisx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/garyburd/redigo/internal"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
)
|
||||
|
||||
// ConnMux multiplexes one or more connections to a single underlying
|
||||
// connection. The ConnMux connections do not support concurrency, commands
|
||||
// that associate server side state with the connection or commands that put
|
||||
// the connection in a special mode.
|
||||
type ConnMux struct {
|
||||
c redis.Conn
|
||||
|
||||
sendMu sync.Mutex
|
||||
sendID uint
|
||||
|
||||
recvMu sync.Mutex
|
||||
recvID uint
|
||||
recvWait map[uint]chan struct{}
|
||||
}
|
||||
|
||||
func NewConnMux(c redis.Conn) *ConnMux {
|
||||
return &ConnMux{c: c, recvWait: make(map[uint]chan struct{})}
|
||||
}
|
||||
|
||||
// Get gets a connection. The application must close the returned connection.
|
||||
func (p *ConnMux) Get() redis.Conn {
|
||||
c := &muxConn{p: p}
|
||||
c.ids = c.buf[:0]
|
||||
return c
|
||||
}
|
||||
|
||||
// Close closes the underlying connection.
|
||||
func (p *ConnMux) Close() error {
|
||||
return p.c.Close()
|
||||
}
|
||||
|
||||
type muxConn struct {
|
||||
p *ConnMux
|
||||
ids []uint
|
||||
buf [8]uint
|
||||
}
|
||||
|
||||
func (c *muxConn) send(flush bool, cmd string, args ...interface{}) error {
|
||||
if internal.LookupCommandInfo(cmd).Set != 0 {
|
||||
return errors.New("command not supported by mux pool")
|
||||
}
|
||||
p := c.p
|
||||
p.sendMu.Lock()
|
||||
id := p.sendID
|
||||
c.ids = append(c.ids, id)
|
||||
p.sendID++
|
||||
err := p.c.Send(cmd, args...)
|
||||
if flush {
|
||||
err = p.c.Flush()
|
||||
}
|
||||
p.sendMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *muxConn) Send(cmd string, args ...interface{}) error {
|
||||
return c.send(false, cmd, args...)
|
||||
}
|
||||
|
||||
func (c *muxConn) Flush() error {
|
||||
p := c.p
|
||||
p.sendMu.Lock()
|
||||
err := p.c.Flush()
|
||||
p.sendMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *muxConn) Receive() (interface{}, error) {
|
||||
if len(c.ids) == 0 {
|
||||
return nil, errors.New("mux pool underflow")
|
||||
}
|
||||
|
||||
id := c.ids[0]
|
||||
c.ids = c.ids[1:]
|
||||
if len(c.ids) == 0 {
|
||||
c.ids = c.buf[:0]
|
||||
}
|
||||
|
||||
p := c.p
|
||||
p.recvMu.Lock()
|
||||
if p.recvID != id {
|
||||
ch := make(chan struct{})
|
||||
p.recvWait[id] = ch
|
||||
p.recvMu.Unlock()
|
||||
<-ch
|
||||
p.recvMu.Lock()
|
||||
if p.recvID != id {
|
||||
panic("out of sync")
|
||||
}
|
||||
}
|
||||
|
||||
v, err := p.c.Receive()
|
||||
|
||||
id++
|
||||
p.recvID = id
|
||||
ch, ok := p.recvWait[id]
|
||||
if ok {
|
||||
delete(p.recvWait, id)
|
||||
}
|
||||
p.recvMu.Unlock()
|
||||
if ok {
|
||||
ch <- struct{}{}
|
||||
}
|
||||
|
||||
return v, err
|
||||
}
|
||||
|
||||
func (c *muxConn) Close() error {
|
||||
var err error
|
||||
if len(c.ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
c.Flush()
|
||||
for _ = range c.ids {
|
||||
_, err = c.Receive()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *muxConn) Do(cmd string, args ...interface{}) (interface{}, error) {
|
||||
if err := c.send(true, cmd, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Receive()
|
||||
}
|
||||
|
||||
func (c *muxConn) Err() error {
|
||||
return c.p.c.Err()
|
||||
}
|
|
@ -0,0 +1,259 @@
|
|||
// Copyright 2014 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redisx_test
|
||||
|
||||
import (
|
||||
"net/textproto"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/garyburd/redigo/internal/redistest"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
"github.com/garyburd/redigo/redisx"
|
||||
)
|
||||
|
||||
func TestConnMux(t *testing.T) {
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
m := redisx.NewConnMux(c)
|
||||
defer m.Close()
|
||||
|
||||
c1 := m.Get()
|
||||
c2 := m.Get()
|
||||
c1.Send("ECHO", "hello")
|
||||
c2.Send("ECHO", "world")
|
||||
c1.Flush()
|
||||
c2.Flush()
|
||||
s, err := redis.String(c1.Receive())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if s != "hello" {
|
||||
t.Fatalf("echo returned %q, want %q", s, "hello")
|
||||
}
|
||||
s, err = redis.String(c2.Receive())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if s != "world" {
|
||||
t.Fatalf("echo returned %q, want %q", s, "world")
|
||||
}
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
}
|
||||
|
||||
func TestConnMuxClose(t *testing.T) {
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
m := redisx.NewConnMux(c)
|
||||
defer m.Close()
|
||||
|
||||
c1 := m.Get()
|
||||
c2 := m.Get()
|
||||
|
||||
if err := c1.Send("ECHO", "hello"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c1.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := c2.Send("ECHO", "world"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c2.Flush(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s, err := redis.String(c2.Receive())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if s != "world" {
|
||||
t.Fatalf("echo returned %q, want %q", s, "world")
|
||||
}
|
||||
c2.Close()
|
||||
}
|
||||
|
||||
func BenchmarkConn(b *testing.B) {
|
||||
b.StopTimer()
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
b.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := c.Do("PING"); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConnMux(b *testing.B) {
|
||||
b.StopTimer()
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
b.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
m := redisx.NewConnMux(c)
|
||||
defer m.Close()
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
c := m.Get()
|
||||
if _, err := c.Do("PING"); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPool(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 1}
|
||||
defer p.Close()
|
||||
|
||||
// Fill the pool.
|
||||
c := p.Get()
|
||||
if err := c.Err(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
c := p.Get()
|
||||
if _, err := c.Do("PING"); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
const numConcurrent = 10
|
||||
|
||||
func BenchmarkConnMuxConcurrent(b *testing.B) {
|
||||
b.StopTimer()
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
b.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
m := redisx.NewConnMux(c)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numConcurrent)
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < numConcurrent; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < b.N; i++ {
|
||||
c := m.Get()
|
||||
if _, err := c.Do("PING"); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkPoolConcurrent(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
p := redis.Pool{Dial: redistest.Dial, MaxIdle: numConcurrent}
|
||||
defer p.Close()
|
||||
|
||||
// Fill the pool.
|
||||
conns := make([]redis.Conn, numConcurrent)
|
||||
for i := range conns {
|
||||
c := p.Get()
|
||||
if err := c.Err(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
conns[i] = c
|
||||
}
|
||||
for _, c := range conns {
|
||||
c.Close()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numConcurrent)
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < numConcurrent; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < b.N; i++ {
|
||||
c := p.Get()
|
||||
if _, err := c.Do("PING"); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkPipelineConcurrency(b *testing.B) {
|
||||
b.StopTimer()
|
||||
c, err := redistest.Dial()
|
||||
if err != nil {
|
||||
b.Fatalf("error connection to database, %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numConcurrent)
|
||||
|
||||
var pipeline textproto.Pipeline
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < numConcurrent; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < b.N; i++ {
|
||||
id := pipeline.Next()
|
||||
pipeline.StartRequest(id)
|
||||
c.Send("PING")
|
||||
c.Flush()
|
||||
pipeline.EndRequest(id)
|
||||
pipeline.StartResponse(id)
|
||||
_, err := c.Receive()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
pipeline.EndResponse(id)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Package redisx contains experimental features for Redigo. Features in this
|
||||
// package may be modified or deleted at any time.
|
||||
package redisx // import "github.com/garyburd/redigo/redisx"
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,7 @@
|
|||
context
|
||||
=======
|
||||
[](https://travis-ci.org/gorilla/context)
|
||||
|
||||
gorilla/context is a general purpose registry for global request variables.
|
||||
|
||||
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context
|
|
@ -0,0 +1,143 @@
|
|||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
mutex sync.RWMutex
|
||||
data = make(map[*http.Request]map[interface{}]interface{})
|
||||
datat = make(map[*http.Request]int64)
|
||||
)
|
||||
|
||||
// Set stores a value for a given key in a given request.
|
||||
func Set(r *http.Request, key, val interface{}) {
|
||||
mutex.Lock()
|
||||
if data[r] == nil {
|
||||
data[r] = make(map[interface{}]interface{})
|
||||
datat[r] = time.Now().Unix()
|
||||
}
|
||||
data[r][key] = val
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
// Get returns a value stored for a given key in a given request.
|
||||
func Get(r *http.Request, key interface{}) interface{} {
|
||||
mutex.RLock()
|
||||
if ctx := data[r]; ctx != nil {
|
||||
value := ctx[key]
|
||||
mutex.RUnlock()
|
||||
return value
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOk returns stored value and presence state like multi-value return of map access.
|
||||
func GetOk(r *http.Request, key interface{}) (interface{}, bool) {
|
||||
mutex.RLock()
|
||||
if _, ok := data[r]; ok {
|
||||
value, ok := data[r][key]
|
||||
mutex.RUnlock()
|
||||
return value, ok
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// GetAll returns all stored values for the request as a map. Nil is returned for invalid requests.
|
||||
func GetAll(r *http.Request) map[interface{}]interface{} {
|
||||
mutex.RLock()
|
||||
if context, ok := data[r]; ok {
|
||||
result := make(map[interface{}]interface{}, len(context))
|
||||
for k, v := range context {
|
||||
result[k] = v
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return result
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllOk returns all stored values for the request as a map and a boolean value that indicates if
|
||||
// the request was registered.
|
||||
func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) {
|
||||
mutex.RLock()
|
||||
context, ok := data[r]
|
||||
result := make(map[interface{}]interface{}, len(context))
|
||||
for k, v := range context {
|
||||
result[k] = v
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return result, ok
|
||||
}
|
||||
|
||||
// Delete removes a value stored for a given key in a given request.
|
||||
func Delete(r *http.Request, key interface{}) {
|
||||
mutex.Lock()
|
||||
if data[r] != nil {
|
||||
delete(data[r], key)
|
||||
}
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
// Clear removes all values stored for a given request.
|
||||
//
|
||||
// This is usually called by a handler wrapper to clean up request
|
||||
// variables at the end of a request lifetime. See ClearHandler().
|
||||
func Clear(r *http.Request) {
|
||||
mutex.Lock()
|
||||
clear(r)
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
// clear is Clear without the lock.
|
||||
func clear(r *http.Request) {
|
||||
delete(data, r)
|
||||
delete(datat, r)
|
||||
}
|
||||
|
||||
// Purge removes request data stored for longer than maxAge, in seconds.
|
||||
// It returns the amount of requests removed.
|
||||
//
|
||||
// If maxAge <= 0, all request data is removed.
|
||||
//
|
||||
// This is only used for sanity check: in case context cleaning was not
|
||||
// properly set some request data can be kept forever, consuming an increasing
|
||||
// amount of memory. In case this is detected, Purge() must be called
|
||||
// periodically until the problem is fixed.
|
||||
func Purge(maxAge int) int {
|
||||
mutex.Lock()
|
||||
count := 0
|
||||
if maxAge <= 0 {
|
||||
count = len(data)
|
||||
data = make(map[*http.Request]map[interface{}]interface{})
|
||||
datat = make(map[*http.Request]int64)
|
||||
} else {
|
||||
min := time.Now().Unix() - int64(maxAge)
|
||||
for r := range data {
|
||||
if datat[r] < min {
|
||||
clear(r)
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
mutex.Unlock()
|
||||
return count
|
||||
}
|
||||
|
||||
// ClearHandler wraps an http.Handler and clears request values at the end
|
||||
// of a request lifetime.
|
||||
func ClearHandler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer Clear(r)
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,161 @@
|
|||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type keyType int
|
||||
|
||||
const (
|
||||
key1 keyType = iota
|
||||
key2
|
||||
)
|
||||
|
||||
func TestContext(t *testing.T) {
|
||||
assertEqual := func(val interface{}, exp interface{}) {
|
||||
if val != exp {
|
||||
t.Errorf("Expected %v, got %v.", exp, val)
|
||||
}
|
||||
}
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
|
||||
emptyR, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
|
||||
|
||||
// Get()
|
||||
assertEqual(Get(r, key1), nil)
|
||||
|
||||
// Set()
|
||||
Set(r, key1, "1")
|
||||
assertEqual(Get(r, key1), "1")
|
||||
assertEqual(len(data[r]), 1)
|
||||
|
||||
Set(r, key2, "2")
|
||||
assertEqual(Get(r, key2), "2")
|
||||
assertEqual(len(data[r]), 2)
|
||||
|
||||
//GetOk
|
||||
value, ok := GetOk(r, key1)
|
||||
assertEqual(value, "1")
|
||||
assertEqual(ok, true)
|
||||
|
||||
value, ok = GetOk(r, "not exists")
|
||||
assertEqual(value, nil)
|
||||
assertEqual(ok, false)
|
||||
|
||||
Set(r, "nil value", nil)
|
||||
value, ok = GetOk(r, "nil value")
|
||||
assertEqual(value, nil)
|
||||
assertEqual(ok, true)
|
||||
|
||||
// GetAll()
|
||||
values := GetAll(r)
|
||||
assertEqual(len(values), 3)
|
||||
|
||||
// GetAll() for empty request
|
||||
values = GetAll(emptyR)
|
||||
if values != nil {
|
||||
t.Error("GetAll didn't return nil value for invalid request")
|
||||
}
|
||||
|
||||
// GetAllOk()
|
||||
values, ok = GetAllOk(r)
|
||||
assertEqual(len(values), 3)
|
||||
assertEqual(ok, true)
|
||||
|
||||
// GetAllOk() for empty request
|
||||
values, ok = GetAllOk(emptyR)
|
||||
assertEqual(value, nil)
|
||||
assertEqual(ok, false)
|
||||
|
||||
// Delete()
|
||||
Delete(r, key1)
|
||||
assertEqual(Get(r, key1), nil)
|
||||
assertEqual(len(data[r]), 2)
|
||||
|
||||
Delete(r, key2)
|
||||
assertEqual(Get(r, key2), nil)
|
||||
assertEqual(len(data[r]), 1)
|
||||
|
||||
// Clear()
|
||||
Clear(r)
|
||||
assertEqual(len(data), 0)
|
||||
}
|
||||
|
||||
func parallelReader(r *http.Request, key string, iterations int, wait, done chan struct{}) {
|
||||
<-wait
|
||||
for i := 0; i < iterations; i++ {
|
||||
Get(r, key)
|
||||
}
|
||||
done <- struct{}{}
|
||||
|
||||
}
|
||||
|
||||
func parallelWriter(r *http.Request, key, value string, iterations int, wait, done chan struct{}) {
|
||||
<-wait
|
||||
for i := 0; i < iterations; i++ {
|
||||
Set(r, key, value)
|
||||
}
|
||||
done <- struct{}{}
|
||||
|
||||
}
|
||||
|
||||
func benchmarkMutex(b *testing.B, numReaders, numWriters, iterations int) {
|
||||
|
||||
b.StopTimer()
|
||||
r, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
|
||||
done := make(chan struct{})
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
wait := make(chan struct{})
|
||||
|
||||
for i := 0; i < numReaders; i++ {
|
||||
go parallelReader(r, "test", iterations, wait, done)
|
||||
}
|
||||
|
||||
for i := 0; i < numWriters; i++ {
|
||||
go parallelWriter(r, "test", "123", iterations, wait, done)
|
||||
}
|
||||
|
||||
close(wait)
|
||||
|
||||
for i := 0; i < numReaders+numWriters; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkMutexSameReadWrite1(b *testing.B) {
|
||||
benchmarkMutex(b, 1, 1, 32)
|
||||
}
|
||||
func BenchmarkMutexSameReadWrite2(b *testing.B) {
|
||||
benchmarkMutex(b, 2, 2, 32)
|
||||
}
|
||||
func BenchmarkMutexSameReadWrite4(b *testing.B) {
|
||||
benchmarkMutex(b, 4, 4, 32)
|
||||
}
|
||||
func BenchmarkMutex1(b *testing.B) {
|
||||
benchmarkMutex(b, 2, 8, 32)
|
||||
}
|
||||
func BenchmarkMutex2(b *testing.B) {
|
||||
benchmarkMutex(b, 16, 4, 64)
|
||||
}
|
||||
func BenchmarkMutex3(b *testing.B) {
|
||||
benchmarkMutex(b, 1, 2, 128)
|
||||
}
|
||||
func BenchmarkMutex4(b *testing.B) {
|
||||
benchmarkMutex(b, 128, 32, 256)
|
||||
}
|
||||
func BenchmarkMutex5(b *testing.B) {
|
||||
benchmarkMutex(b, 1024, 2048, 64)
|
||||
}
|
||||
func BenchmarkMutex6(b *testing.B) {
|
||||
benchmarkMutex(b, 2048, 1024, 512)
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package context stores values shared during a request lifetime.
|
||||
|
||||
For example, a router can set variables extracted from the URL and later
|
||||
application handlers can access those values, or it can be used to store
|
||||
sessions values to be saved at the end of a request. There are several
|
||||
others common uses.
|
||||
|
||||
The idea was posted by Brad Fitzpatrick to the go-nuts mailing list:
|
||||
|
||||
http://groups.google.com/group/golang-nuts/msg/e2d679d303aa5d53
|
||||
|
||||
Here's the basic usage: first define the keys that you will need. The key
|
||||
type is interface{} so a key can be of any type that supports equality.
|
||||
Here we define a key using a custom int type to avoid name collisions:
|
||||
|
||||
package foo
|
||||
|
||||
import (
|
||||
"github.com/gorilla/context"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const MyKey key = 0
|
||||
|
||||
Then set a variable. Variables are bound to an http.Request object, so you
|
||||
need a request instance to set a value:
|
||||
|
||||
context.Set(r, MyKey, "bar")
|
||||
|
||||
The application can later access the variable using the same key you provided:
|
||||
|
||||
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// val is "bar".
|
||||
val := context.Get(r, foo.MyKey)
|
||||
|
||||
// returns ("bar", true)
|
||||
val, ok := context.GetOk(r, foo.MyKey)
|
||||
// ...
|
||||
}
|
||||
|
||||
And that's all about the basic usage. We discuss some other ideas below.
|
||||
|
||||
Any type can be stored in the context. To enforce a given type, make the key
|
||||
private and wrap Get() and Set() to accept and return values of a specific
|
||||
type:
|
||||
|
||||
type key int
|
||||
|
||||
const mykey key = 0
|
||||
|
||||
// GetMyKey returns a value for this package from the request values.
|
||||
func GetMyKey(r *http.Request) SomeType {
|
||||
if rv := context.Get(r, mykey); rv != nil {
|
||||
return rv.(SomeType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMyKey sets a value for this package in the request values.
|
||||
func SetMyKey(r *http.Request, val SomeType) {
|
||||
context.Set(r, mykey, val)
|
||||
}
|
||||
|
||||
Variables must be cleared at the end of a request, to remove all values
|
||||
that were stored. This can be done in an http.Handler, after a request was
|
||||
served. Just call Clear() passing the request:
|
||||
|
||||
context.Clear(r)
|
||||
|
||||
...or use ClearHandler(), which conveniently wraps an http.Handler to clear
|
||||
variables at the end of a request lifetime.
|
||||
|
||||
The Routers from the packages gorilla/mux and gorilla/pat call Clear()
|
||||
so if you are using either of them you don't need to clear the context manually.
|
||||
*/
|
||||
package context
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,7 @@
|
|||
mux
|
||||
===
|
||||
[](https://travis-ci.org/gorilla/mux)
|
||||
|
||||
gorilla/mux is a powerful URL router and dispatcher.
|
||||
|
||||
Read the full documentation here: http://www.gorillatoolkit.org/pkg/mux
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue